xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/buffer_assignment_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
17 
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/service/async_op_canonicalizer.h"
28 #include "tensorflow/compiler/xla/service/buffer_value.h"
29 #include "tensorflow/compiler/xla/service/call_graph.h"
30 #include "tensorflow/compiler/xla/service/copy_insertion.h"
31 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
32 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_dce.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
39 #include "tensorflow/compiler/xla/service/hlo_parser.h"
40 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
41 #include "tensorflow/compiler/xla/shape_util.h"
42 #include "tensorflow/compiler/xla/test.h"
43 #include "tensorflow/compiler/xla/test_helpers.h"
44 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47 #include "tensorflow/core/lib/core/status_test_util.h"
48 
49 namespace xla {
50 namespace {
51 
52 using memory_space_assignment::PresetAssignments;
53 using ::testing::UnorderedElementsAre;
54 
55 // DFS visitor that collects the instructions referenced by a computation
56 // without descending into nested computations, i.e., only from the operands.
57 class InstructionListVisitor : public DfsHloVisitorWithDefault {
58  public:
InstructionListVisitor(const HloInstruction * root)59   explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
60 
DefaultAction(HloInstruction * hlo)61   Status DefaultAction(HloInstruction* hlo) override {
62     // For each instruction, just push it on the list after walking the
63     // operands.
64     instructions_.push_back(hlo);
65     VLOG(0) << "List instruction " << hlo->ToString();
66     return OkStatus();
67   }
68 
GetInstructions()69   std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
70 
71  private:
72   // The instruction root of the computation.
73   const HloInstruction* root_;
74 
75   // The full set of instructions found (may be duplicates, e.g., kParameter).
76   std::vector<const HloInstruction*> instructions_;
77 
78   InstructionListVisitor(const InstructionListVisitor&) = delete;
79   InstructionListVisitor& operator=(const InstructionListVisitor&) = delete;
80 };
81 
GetInstructions(HloInstruction * root)82 const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
83   InstructionListVisitor main_list(root);
84   TF_CHECK_OK(root->Accept(&main_list));
85   return main_list.GetInstructions();
86 }
87 
88 class BufferAssignmentTest : public HloTestBase {
89  protected:
~BufferAssignmentTest()90   ~BufferAssignmentTest() override {}
91 
RunBufferAssignment(HloModule * module,int64_t alignment=1)92   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
93                                                         int64_t alignment = 1) {
94     return BufferAssigner::Run(
95                module, std::make_unique<DependencyHloOrdering>(module),
96                backend().compiler()->BufferSizeBytesFunction(),
97                [alignment](LogicalBuffer::Color) { return alignment; },
98                /*allocate_buffers_for_constants=*/true)
99         .value();
100   }
101 
RunBufferAssignmentWithSequentialOrdering(HloModule * module,int64_t alignment=1)102   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithSequentialOrdering(
103       HloModule* module, int64_t alignment = 1) {
104     return BufferAssigner::Run(
105                module,
106                std::make_unique<SequentialHloOrdering>(module->schedule()),
107                backend().compiler()->BufferSizeBytesFunction(),
108                [alignment](LogicalBuffer::Color) { return alignment; },
109                /*allocate_buffers_for_constants=*/true)
110         .value();
111   }
112 
RunBufferAssignmentNoBuffersForConstants(HloModule * module,int64_t alignment=1)113   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
114       HloModule* module, int64_t alignment = 1) {
115     return BufferAssigner::Run(
116                module, std::make_unique<DependencyHloOrdering>(module),
117                backend().compiler()->BufferSizeBytesFunction(),
118                [alignment](LogicalBuffer::Color) { return alignment; },
119                /*allocate_buffers_for_constants=*/false)
120         .value();
121   }
122 
RunBufferAssignmentNoBuffersReuseForAdd(HloModule * module,int64_t alignment=1)123   std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
124       HloModule* module, int64_t alignment = 1) {
125     auto must_not_live_out = [](const HloInstruction* instruction,
126                                 const ShapeIndex&) {
127       return instruction->opcode() == HloOpcode::kAdd;
128     };
129 
130     return BufferAssigner::Run(
131                module, std::make_unique<DependencyHloOrdering>(module),
132                backend().compiler()->BufferSizeBytesFunction(),
133                [alignment](LogicalBuffer::Color) { return alignment; },
134                /*allocate_buffers_for_constants=*/false,
135                /*colorer=*/BufferAssigner::DefaultColorer(),
136                /*must_not_live_out=*/must_not_live_out)
137         .value();
138   }
139 
RunColoredBufferAssignment(HloModule * module,BufferAssigner::Colorer colorer,int64_t alignment=1)140   std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
141       HloModule* module, BufferAssigner::Colorer colorer,
142       int64_t alignment = 1) {
143     return BufferAssigner::Run(
144                module, std::make_unique<DependencyHloOrdering>(module),
145                backend().compiler()->BufferSizeBytesFunction(),
146                [alignment](LogicalBuffer::Color) { return alignment; },
147                /*allocate_buffers_for_constants=*/true, std::move(colorer))
148         .value();
149   }
150 
RunBufferAssignmentWithInstructionSequence(HloModule * module,absl::Span<HloInstruction * const> instruction_sequence,int64_t alignment=1)151   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
152       HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
153       int64_t alignment = 1) {
154     HloSchedule schedule(module);
155     schedule.set_sequence(module->entry_computation(), instruction_sequence);
156     return BufferAssigner::Run(
157                module, std::make_unique<SequentialHloOrdering>(schedule),
158                backend().compiler()->BufferSizeBytesFunction(),
159                [alignment](LogicalBuffer::Color) { return alignment; },
160                /*allocate_buffers_for_constants=*/true)
161         .value();
162   }
163 
RunBufferAssignmentWithPresetAssignments(HloModule * module,std::unique_ptr<PresetAssignments> preset_assignments,int64_t alignment=1)164   std::unique_ptr<BufferAssignment> RunBufferAssignmentWithPresetAssignments(
165       HloModule* module, std::unique_ptr<PresetAssignments> preset_assignments,
166       int64_t alignment = 1) {
167     return BufferAssigner::Run(
168                module, std::make_unique<DependencyHloOrdering>(module),
169                backend().compiler()->BufferSizeBytesFunction(),
170                [alignment](LogicalBuffer::Color) { return alignment; },
171                /*allocate_buffers_for_constants=*/true,
172                BufferAssigner::DefaultColorer(),
173                /*must_not_live_out=*/std::nullopt,
174                /*can_share_buffer=*/nullptr, std::move(preset_assignments))
175         .value();
176   }
177 
178   // Builds an x+1.0 computation to use in a Map.
BuildMapComputationPlus1(const std::string & name)179   std::unique_ptr<HloComputation> BuildMapComputationPlus1(
180       const std::string& name) {
181     auto builder = HloComputation::Builder(name);
182     auto param =
183         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
184     auto value = builder.AddInstruction(
185         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
186     builder.AddInstruction(
187         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
188     return builder.Build();
189   }
190 
BuildReduceComputation(const std::string & name)191   std::unique_ptr<HloComputation> BuildReduceComputation(
192       const std::string& name) {
193     auto builder = HloComputation::Builder(name);
194     auto param =
195         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
196     auto param2 =
197         builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
198     builder.AddInstruction(
199         HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
200     return builder.Build();
201   }
202 
203   // Builds a simple compare-to-limit (x < 4) computation for a While.
204   //
205   // condition:
206   //   const4[s32] -----------------------------------\
207   //                                                   \
208   //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
209   //
BuildWhileConditionComputation(const std::string & name)210   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
211       const std::string& name) {
212     auto builder = HloComputation::Builder(name);
213     auto const4 = builder.AddInstruction(
214         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
215     auto param = builder.AddInstruction(
216         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
217     auto index = builder.AddInstruction(
218         HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
219     builder.AddInstruction(
220         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
221                                       const4, ComparisonDirection::kLt));
222     return builder.Build();
223   }
224 
225   // Builds a simple body computation for a While.
226   //
227   // body:
228   //   constv[f32[4]] --------------------------------------\
229   //                                                         \
230   //                           /--- get-tuple-elementv[1] --- addv ---\
231   //   param[(s32,f32[4])] ---|                                    tuple
232   //                           \--- get-tuple-elementc[0] --- addc ---/
233   //                                                         /
234   //   const1[s32] -----------------------------------------/
235   //
BuildWhileBodyComputation(const std::string & name)236   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
237       const std::string& name) {
238     auto builder = HloComputation::Builder(name);
239     auto const1 = builder.AddInstruction(
240         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
241     auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
242         LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
243     auto param = builder.AddInstruction(
244         HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
245     auto indexc = builder.AddInstruction(
246         HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
247     auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
248         indexc->shape(), HloOpcode::kAdd, indexc, const1));
249     auto indexv = builder.AddInstruction(
250         HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
251     auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
252         constv->shape(), HloOpcode::kAdd, indexv, constv));
253     builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
254     return builder.Build();
255   }
256 
BuildR0F32UnaryOpComputation(HloOpcode opcode,const std::string & name)257   std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
258       HloOpcode opcode, const std::string& name) {
259     auto builder = HloComputation::Builder(name);
260     auto param =
261         builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
262     builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
263     return builder.Build();
264   }
265 
266   // Verifies that the given instruction hlo has a valid input buffer assigned,
267   // i.e., the parameter number matches the op's.
GetAssignedInputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)268   const BufferAllocation& GetAssignedInputAllocation(
269       const BufferAssignment& buffers, HloInstruction* hlo) {
270     LOG(INFO) << "Checking input: " << hlo->ToString();
271     const BufferAllocation& buffer =
272         *buffers.GetUniqueTopLevelSlice(hlo).value().allocation();
273     EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
274     return buffer;
275   }
276 
277   // Verifies that the given instruction hlo has a valid output buffer
278   // assigned, and returns it.
GetAssignedOutputAllocation(const BufferAssignment & buffers,HloInstruction * hlo)279   const BufferAllocation& GetAssignedOutputAllocation(
280       const BufferAssignment& buffers, HloInstruction* hlo) {
281     LOG(INFO) << "Checking output: " << hlo->ToString();
282     const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
283     return buffer;
284   }
285 
286   // Returns the allocation for the given instruction.
GetAllocation(const BufferAssignment & buffers,const HloInstruction * hlo,const ShapeIndex & index)287   const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
288                                         const HloInstruction* hlo,
289                                         const ShapeIndex& index) {
290     return *buffers.GetUniqueSlice(hlo, index).value().allocation();
291   }
GetTopLevelAllocation(const BufferAssignment & buffers,const HloInstruction * hlo)292   const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
293                                                 const HloInstruction* hlo) {
294     return *buffers.GetUniqueTopLevelSlice(hlo).value().allocation();
295   }
296 
297   // Verifies that all instructions in the given instruction list except
298   // kConstant have assigned buffers, and returns their total size. If min_index
299   // and max_index are not nullptr, the minimum and maximum buffer indices in
300   // the assignment are written into them.
ValidateBuffers(const std::vector<const HloInstruction * > & instructions,const BufferAssignment & buffers)301   int64_t ValidateBuffers(
302       const std::vector<const HloInstruction*>& instructions,
303       const BufferAssignment& buffers) {
304     // Verifies all instructions have buffers, and gets the index ranges.
305     for (const HloInstruction* hlo : instructions) {
306       if (!buffers.HasTopLevelAllocation(hlo)) {
307         // If `hlo` has no assigned buffer, it is either a constant or a nested
308         // parameter.
309         EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
310                     HloOpcode::kParameter == hlo->opcode());
311         continue;
312       }
313     }
314 
315     // Gets the total size of all buffers assigned.
316     int64_t total_size = 0;
317     for (auto& allocation : buffers.Allocations()) {
318       total_size += allocation.size();
319     }
320     return total_size;
321   }
322 
323   // Shapes for use in the examples.
324   Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
325   Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
326   Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
327   Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
328   Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
329   Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
330   Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
331   Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
332 };
333 
334 // Returns true if the buffers assigned to instructions in "a" are distinct
335 // from the buffers assigned to those in "b" (ie, intersection is empty).
BuffersDistinct(const std::vector<const HloInstruction * > & a,const std::vector<const HloInstruction * > & b,const BufferAssignment & assignment)336 static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
337                             const std::vector<const HloInstruction*>& b,
338                             const BufferAssignment& assignment) {
339   absl::flat_hash_set<BufferAllocation::Slice> a_slices;
340   for (const HloInstruction* instruction : a) {
341     if (assignment.HasTopLevelAllocation(instruction)) {
342       a_slices.insert(assignment.GetUniqueTopLevelSlice(instruction).value());
343     }
344   }
345 
346   for (const HloInstruction* instruction : b) {
347     if (assignment.HasTopLevelAllocation(instruction)) {
348       if (a_slices.contains(
349               assignment.GetUniqueTopLevelSlice(instruction).value())) {
350         return false;
351       }
352     }
353   }
354   return true;
355 }
356 
357 // Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest,ScalarConstant)358 TEST_F(BufferAssignmentTest, ScalarConstant) {
359   auto builder = HloComputation::Builder(TestName());
360   auto const0 = builder.AddInstruction(
361       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
362   auto module = CreateNewVerifiedModule();
363   module->AddEntryComputation(builder.Build());
364 
365   {
366     auto buffers = RunBufferAssignment(module.get());
367     EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
368   }
369 
370   {
371     auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
372     EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
373   }
374 }
375 
TEST_F(BufferAssignmentTest,BufferForConst)376 TEST_F(BufferAssignmentTest, BufferForConst) {
377   // Addition of two vector constants: checks that internal constant nodes have
378   // no buffers assigned, and their consumer has a buffer.
379   auto builder = HloComputation::Builder(TestName());
380   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
381       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
382   auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
383       LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
384   auto add = builder.AddInstruction(
385       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
386   auto module = CreateNewVerifiedModule();
387   module->AddEntryComputation(builder.Build());
388 
389   {
390     auto buffers = RunBufferAssignment(module.get());
391     EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
392     EXPECT_TRUE(buffers->HasTopLevelAllocation(const1));
393     GetAssignedOutputAllocation(*buffers, add);
394   }
395   {
396     auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
397     EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
398     EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
399     GetAssignedOutputAllocation(*buffers, add);
400   }
401 }
402 
TEST_F(BufferAssignmentTest,HasAllocationAt)403 TEST_F(BufferAssignmentTest, HasAllocationAt) {
404   // Create a tuple with non-const and const elements and check that
405   // HasAllocationAt works correctly.
406   auto builder = HloComputation::Builder(TestName());
407   auto param0 = builder.AddInstruction(
408       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
409   auto constant = builder.AddInstruction(
410       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
411   auto negate = builder.AddInstruction(
412       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
413   auto tuple = builder.AddInstruction(
414       HloInstruction::CreateTuple({negate, param0, constant}));
415   auto module = CreateNewVerifiedModule();
416   module->AddEntryComputation(builder.Build());
417 
418   auto buffers = RunBufferAssignment(module.get());
419   // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
420   // reports for the instruction directly.
421   EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
422             buffers->HasAllocationAt(tuple, /*index=*/{}));
423   EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
424             buffers->HasAllocationAt(tuple, /*index=*/{0}));
425   EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
426             buffers->HasAllocationAt(tuple, /*index=*/{1}));
427   EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
428             buffers->HasAllocationAt(tuple, /*index=*/{2}));
429 }
430 
TEST_F(BufferAssignmentTest,BufferForOutputConst)431 TEST_F(BufferAssignmentTest, BufferForOutputConst) {
432   // This computation copies a constant to output.
433   auto builder = HloComputation::Builder(TestName());
434   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
435       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
436   auto copy = builder.AddInstruction(
437       HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
438   auto module = CreateNewVerifiedModule();
439   module->AddEntryComputation(builder.Build());
440 
441   auto buffers = RunBufferAssignment(module.get());
442   // The copy node now has an output buffer.
443   GetAssignedOutputAllocation(*buffers, copy);
444 }
445 
TEST_F(BufferAssignmentTest,Basic)446 TEST_F(BufferAssignmentTest, Basic) {
447   // paramscalar ------- (mul) -- (add) -- (sub)
448   //                     /        /        /
449   // param0[100] -------/        /        /
450   //                            /        /
451   // param1[100] --------------/--------/
452   auto builder = HloComputation::Builder(TestName());
453   auto paramscalar =
454       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
455   auto broadcast = builder.AddInstruction(
456       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
457   auto param0 = builder.AddInstruction(
458       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
459   auto param1 = builder.AddInstruction(
460       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
461   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
462       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
463   auto add = builder.AddInstruction(
464       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
465   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
466       f32vec100_, HloOpcode::kSubtract, add, param1));
467   auto module = CreateNewVerifiedModule();
468   module->AddEntryComputation(builder.Build());
469 
470   auto buffers = RunBufferAssignment(module.get());
471 
472   // Distinct input buffers were assigned for parameters.
473   BufferAllocation paramscalar_buffer =
474       GetAssignedInputAllocation(*buffers, paramscalar);
475   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
476   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
477   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
478   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
479   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
480 
481   // The mul node has a valid buffer assigned, doesn't share with input.
482   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
483   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
484 
485   // The add node can reuse the mul node's buffer.
486   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
487   EXPECT_EQ(add_buffer.index(), mul_buffer.index());
488 
489   // The sub node has a valid output buffer assigned.
490   GetAssignedOutputAllocation(*buffers, sub);
491 }
492 
TEST_F(BufferAssignmentTest,AliasedParamCanBeReused)493 TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
494   // If an input buffer and output buffer aliases, the input buffer can be
495   // reused for other intermediate results.
496   //
497   // param0[100] ----- (neg1) -- (neg2)
498   //    |                           |
499   //    + -------- Aliased ---------+
500 
501   auto builder = HloComputation::Builder(TestName());
502 
503   auto param = builder.AddInstruction(
504       HloInstruction::CreateParameter(0, f32vec100_, "p0"));
505   auto neg_1 = builder.AddInstruction(
506       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
507   auto neg_2 = builder.AddInstruction(
508       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1));
509 
510   auto module = CreateNewVerifiedModule();
511   module->AddEntryComputation(builder.Build());
512 
513   TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));
514 
515   auto buffers = RunBufferAssignment(module.get());
516 
517   BufferAllocation param_buffer = GetAssignedInputAllocation(*buffers, param);
518   BufferAllocation neg_1_buffer = GetAllocation(*buffers, neg_1, {});
519   BufferAllocation neg_2_buffer = GetAllocation(*buffers, neg_2, {});
520 
521   // Everything use one buffer.
522   EXPECT_EQ(param_buffer.index(), neg_1_buffer.index());
523   EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index());
524 }
525 
TEST_F(BufferAssignmentTest,AddCannotReuse)526 TEST_F(BufferAssignmentTest, AddCannotReuse) {
527   // Pass in a special rule to indicate that "add" cannot be live out.
528   //
529   // paramscalar ------- (mul) -- (add) -- (sub)
530   //                     /        /        /
531   // param0[100] -------/        /        /
532   //                            /        /
533   // param1[100] --------------/--------/
534   auto builder = HloComputation::Builder(TestName());
535   auto paramscalar =
536       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
537   auto broadcast = builder.AddInstruction(
538       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
539   auto param0 = builder.AddInstruction(
540       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
541   auto param1 = builder.AddInstruction(
542       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
543   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
544       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
545   auto add = builder.AddInstruction(
546       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
547   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
548       f32vec100_, HloOpcode::kSubtract, add, param1));
549   auto module = CreateNewVerifiedModule();
550   module->AddEntryComputation(builder.Build());
551 
552   auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());
553 
554   // Distinct input buffers were assigned for parameters.
555   BufferAllocation paramscalar_buffer =
556       GetAssignedInputAllocation(*buffers, paramscalar);
557   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
558   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
559   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
560   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
561   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
562 
563   // The mul node has a valid buffer assigned, doesn't share with input.
564   const BufferAllocation& sub_buffer = GetTopLevelAllocation(*buffers, sub);
565   EXPECT_NE(sub_buffer.index(), param0_buffer.index());
566 
567   // The add node cannot reuse the mul node's buffer since we told buffer
568   // assignment so.
569   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
570   EXPECT_NE(add_buffer.index(), sub_buffer.index());
571 
572   // The sub node has a valid output buffer assigned.
573   GetAssignedOutputAllocation(*buffers, sub);
574 }
575 
TEST_F(BufferAssignmentTest,BasicUniquelyColored)576 TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
577   // paramscalar ------- (mul) -- (add) -- (sub)
578   //                     /        /        /
579   // param0[100] -------/        /        /
580   //                            /        /
581   // param1[100] --------------/--------/
582   // The output of each op is colored with a different color, so we can not
583   // share anything.
584   auto builder = HloComputation::Builder(TestName());
585   auto paramscalar =
586       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
587   auto broadcast = builder.AddInstruction(
588       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
589   auto param0 = builder.AddInstruction(
590       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
591   auto param1 = builder.AddInstruction(
592       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
593   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
594       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
595   auto add = builder.AddInstruction(
596       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
597   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
598       f32vec100_, HloOpcode::kSubtract, add, param1));
599   auto module = CreateNewVerifiedModule();
600   module->AddEntryComputation(builder.Build());
601 
602   absl::flat_hash_map<const HloInstruction*, int> color_map;
603   auto colorer = [&](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
604     int color = 0;
605     for (HloValue::Id id = 0;
606          id < alias_analysis->dataflow_analysis().values().size(); id++) {
607       auto& value = alias_analysis->dataflow_analysis().GetValue(id);
608       color_map[value.defining_instruction()] = color;
609       value.set_color(BufferValue::Color(color++));
610     }
611     return OkStatus();
612   };
613 
614   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
615 
616   // Distinct input buffers were assigned for parameters.
617   BufferAllocation paramscalar_buffer =
618       GetAssignedInputAllocation(*buffers, paramscalar);
619   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
620   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
621   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
622   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
623   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
624 
625   // The mul node has a valid buffer assigned, doesn't share with input.
626   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
627   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
628 
629   // The add node can not reuse the mul node's buffer due to coloring.
630   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
631   EXPECT_NE(add_buffer.index(), mul_buffer.index());
632 
633   // The sub node has a valid output buffer assigned.
634   GetAssignedOutputAllocation(*buffers, sub);
635 
636   // Check if the HLO instructions have the correct colors in the layout.
637   EXPECT_EQ(param0->shape().layout().memory_space(), color_map[param0]);
638   EXPECT_EQ(param1->shape().layout().memory_space(), color_map[param1]);
639   EXPECT_EQ(mul->shape().layout().memory_space(), color_map[mul]);
640   EXPECT_EQ(add->shape().layout().memory_space(), color_map[add]);
641   EXPECT_EQ(sub->shape().layout().memory_space(), color_map[sub]);
642 }
643 
TEST_F(BufferAssignmentTest,BasicPartiallyColored)644 TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
645   // paramscalar ------- (mul) -- (add) -- (sub)
646   //                     /        /        /
647   // param0[100] -------/        /        /
648   //                            /        /
649   // param1[100] --------------/--------/
650   // The output of the mul and the add have the color 1, and the other buffers
651   // have the color 0, which allows the mul and add to share buffers.
652   auto builder = HloComputation::Builder(TestName());
653   auto paramscalar =
654       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
655   auto broadcast = builder.AddInstruction(
656       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
657   auto param0 = builder.AddInstruction(
658       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
659   auto param1 = builder.AddInstruction(
660       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
661   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
662       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
663   auto add = builder.AddInstruction(
664       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
665   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
666       f32vec100_, HloOpcode::kSubtract, add, param1));
667   auto module = CreateNewVerifiedModule();
668   module->AddEntryComputation(builder.Build());
669 
670   auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
671     for (HloValue::Id id = 0;
672          id < alias_analysis->dataflow_analysis().values().size(); id++) {
673       auto& value = alias_analysis->dataflow_analysis().GetValue(id);
674       auto& buffer = alias_analysis->GetBufferContainingValue(value);
675       for (const auto& alias : buffer.values()) {
676         if (alias->instruction()->opcode() == HloOpcode::kAdd ||
677             alias->instruction()->opcode() == HloOpcode::kMultiply) {
678           value.set_color(LogicalBuffer::Color(1));
679         }
680       }
681       if (!value.has_color()) {
682         value.set_color(LogicalBuffer::Color(0));
683       }
684     }
685     return OkStatus();
686   };
687 
688   auto buffers = RunColoredBufferAssignment(module.get(), colorer);
689 
690   // Distinct input buffers were assigned for parameters.
691   BufferAllocation paramscalar_buffer =
692       GetAssignedInputAllocation(*buffers, paramscalar);
693   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
694   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
695   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
696   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
697   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
698 
699   // The mul node has a valid buffer assigned, doesn't share with input.
700   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
701   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
702 
703   // The add node can reuse the mul node's buffer.
704   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
705   EXPECT_EQ(add_buffer.index(), mul_buffer.index());
706 
707   // The sub node has a valid output buffer assigned.
708   GetAssignedOutputAllocation(*buffers, sub);
709 
710   // Check if the HLO instructions have the correct colors in the layout.
711   EXPECT_EQ(mul->shape().layout().memory_space(), 1);
712   EXPECT_EQ(add->shape().layout().memory_space(), 1);
713   EXPECT_EQ(sub->shape().layout().memory_space(), 0);
714   EXPECT_EQ(param0->shape().layout().memory_space(), 0);
715   EXPECT_EQ(param1->shape().layout().memory_space(), 0);
716 }
717 
TEST_F(BufferAssignmentTest,PresetAssignments)718 TEST_F(BufferAssignmentTest, PresetAssignments) {
719   // paramscalar ------- (mul) -- (add) -- (sub)
720   //                     /        /        /
721   // param0[100] -------/        /        /
722   //                            /        /
723   // param1[100] --------------/--------/
724   // Similar to BasicPartiallyColored, but the color is set in the layout.
725   // The output of the mul and the add have the color 1 and have preset
726   // assignments, and the other buffers have the color 0, which allows the mul
727   // and add to share buffers.
728   auto builder = HloComputation::Builder(TestName());
729   auto paramscalar =
730       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
731   auto broadcast = builder.AddInstruction(
732       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
733   auto param0 = builder.AddInstruction(
734       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
735   auto param1 = builder.AddInstruction(
736       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
737   Shape f32vec100_color1 =
738       ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, {}, 0, 1);
739   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
740       f32vec100_color1, HloOpcode::kMultiply, broadcast, param0));
741   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
742       f32vec100_color1, HloOpcode::kAdd, mul, param1));
743   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
744       f32vec100_, HloOpcode::kSubtract, add, param1));
745   auto module = CreateNewVerifiedModule();
746   module->AddEntryComputation(builder.Build());
747 
748   auto preset_assignments = std::make_unique<PresetAssignments>();
749   preset_assignments->add_chunk({mul, {}}, {/*offset=*/100, /*size=*/400});
750   preset_assignments->add_chunk({add, {}}, {/*offset=*/550, /*size=*/400});
751   preset_assignments->assignment_information_for_space(/*memory_space=*/1)
752       ->size = 950;
753 
754   auto buffers = RunBufferAssignmentWithPresetAssignments(
755       module.get(), std::move(preset_assignments));
756 
757   // Distinct input buffers were assigned for parameters.
758   BufferAllocation paramscalar_buffer =
759       GetAssignedInputAllocation(*buffers, paramscalar);
760   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
761   BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
762   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
763   EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
764   EXPECT_EQ(paramscalar_buffer.color(), LogicalBuffer::Color(0));
765   EXPECT_NE(param0_buffer.index(), param1_buffer.index());
766   EXPECT_EQ(param0_buffer.color(), LogicalBuffer::Color(0));
767 
768   // The mul and add use the same preset buffer. Ensure it has the correct color
769   // and offsets.
770   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
771   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
772   EXPECT_EQ(mul_buffer, add_buffer);
773   EXPECT_NE(mul_buffer.index(), param0_buffer.index());
774   EXPECT_EQ(mul_buffer.color(), LogicalBuffer::Color(1));
775 
776   EXPECT_EQ(mul_buffer.assigned_buffers().size(), 2);
777   for (const auto& value_and_offsetsize : mul_buffer.assigned_buffers()) {
778     if (value_and_offsetsize.first->instruction() == mul) {
779       EXPECT_EQ(value_and_offsetsize.second.offset, 100);
780       EXPECT_EQ(value_and_offsetsize.second.size, 400);
781     } else {
782       EXPECT_EQ(value_and_offsetsize.first->instruction(), add);
783       EXPECT_EQ(value_and_offsetsize.second.offset, 550);
784       EXPECT_EQ(value_and_offsetsize.second.size, 400);
785     }
786   }
787 
788   // The sub node has a valid output buffer assigned.
789   GetAssignedOutputAllocation(*buffers, sub);
790 }
791 
TEST_F(BufferAssignmentTest,PresetAssignmentsWhile)792 TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
793   // Tests preset assignments when there is no 1-to-1 correspondence between
794   // HloValue and HloBuffer (i.e., a while loop).
795   auto module = CreateNewVerifiedModule();
796   Shape f32vec10_color1 =
797       ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, {}, 0, 1);
798   Shape t_s32_f32v10_color1 =
799       ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});
800 
801   auto cond_builder = HloComputation::Builder("WhileCond");
802   HloInstruction* cond_param = cond_builder.AddInstruction(
803       HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
804   HloInstruction* cond_iter = cond_builder.AddInstruction(
805       HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
806   HloInstruction* cond_limit = cond_builder.AddInstruction(
807       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
808   cond_builder.AddInstruction(
809       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
810                                     cond_limit, ComparisonDirection::kLt));
811   HloComputation* cond_computation =
812       module->AddEmbeddedComputation(cond_builder.Build());
813 
814   auto body_builder = HloComputation::Builder("WhileBody");
815   HloInstruction* body_param = body_builder.AddInstruction(
816       HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
817   HloInstruction* body_iter = body_builder.AddInstruction(
818       HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
819   HloInstruction* body_data = body_builder.AddInstruction(
820       HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
821   HloInstruction* body_data_increment = body_builder.AddInstruction(
822       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
823           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
824   HloInstruction* body_data_next =
825       body_builder.AddInstruction(HloInstruction::CreateBinary(
826           f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
827   HloInstruction* body_iter_increment = body_builder.AddInstruction(
828       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
829   HloInstruction* body_iter_next =
830       body_builder.AddInstruction(HloInstruction::CreateBinary(
831           s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
832   body_builder.AddInstruction(
833       HloInstruction::CreateTuple({body_iter_next, body_data_next}));
834   HloComputation* body_computation =
835       module->AddEmbeddedComputation(body_builder.Build());
836 
837   auto builder = HloComputation::Builder(TestName());
838   HloInstruction* iter = builder.AddInstruction(
839       HloInstruction::CreateParameter(0, s32_, "param_iter"));
840   HloInstruction* data = builder.AddInstruction(
841       HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
842   HloInstruction* negate = builder.AddInstruction(
843       HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
844   HloInstruction* tuple =
845       builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
846   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
847       t_s32_f32v10_color1, cond_computation, body_computation, tuple));
848   HloInstruction* while_data = builder.AddInstruction(
849       HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
850   builder.AddInstruction(HloInstruction::CreateBinary(
851       f32vec10_, HloOpcode::kAdd, while_data, data));
852   module->AddEntryComputation(builder.Build());
853 
854   // Set only one preset assignment for while data and its aliases.
855   auto preset_assignments = std::make_unique<PresetAssignments>();
856   preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
857   preset_assignments->assignment_information_for_space(/*memory_space=*/1)
858       ->size = 140;
859 
860   auto buffers = RunBufferAssignmentWithPresetAssignments(
861       module.get(), std::move(preset_assignments));
862 
863   // All assigned buffers are aliased so they should have the same offset and
864   // size.
865   const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
866   EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
867   for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) {
868     EXPECT_EQ(value_and_offsetsize.second.offset, 100);
869     EXPECT_EQ(value_and_offsetsize.second.size, 40);
870     EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1));
871   }
872 }
873 
TEST_F(BufferAssignmentTest,MultipleUsersForNode)874 TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
875   // This is similar to the Basic test, with the difference that (sub) is
876   // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
877   // (add)'s output.
878   //
879   // paramscalar -------\     /-----------\
880   //                     \   /             \
881   // param0[100] ------- (mul) -- (add) -- (sub)
882   //                              /
883   // param1[100] ----------------/
884   //
885   auto builder = HloComputation::Builder(TestName());
886   auto paramscalar =
887       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
888   auto broadcast = builder.AddInstruction(
889       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
890   auto param0 = builder.AddInstruction(
891       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
892   auto param1 = builder.AddInstruction(
893       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
894   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
895       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
896   auto add = builder.AddInstruction(
897       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
898   auto sub = builder.AddInstruction(
899       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
900   auto module = CreateNewVerifiedModule();
901   module->AddEntryComputation(builder.Build());
902 
903   auto buffers = RunBufferAssignment(module.get());
904 
905   // Input buffers were assigned for parameters.
906   BufferAllocation paramscalar_buffer =
907       GetAssignedInputAllocation(*buffers, paramscalar);
908   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
909   BufferAllocation param1_index = GetAssignedInputAllocation(*buffers, param1);
910   EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
911   EXPECT_NE(paramscalar_buffer.index(), param1_index.index());
912   EXPECT_NE(param0_buffer.index(), param1_index.index());
913 
914   // The mul node had a buffer allocated.
915   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
916 
917   // Now the add node can't reuse the mul node's buffer.
918   const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
919   EXPECT_NE(add_buffer.index(), mul_buffer.index());
920 
921   // Log size information for inspection.
922   const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
923   int64_t size0 = ValidateBuffers(level0, *buffers);
924   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
925             << " for " << level0.size() << " instructions; "
926             << "total buffer size " << size0;
927 }
928 
TEST_F(BufferAssignmentTest,TrivialMap)929 TEST_F(BufferAssignmentTest, TrivialMap) {
930   // This tests a trivial x+1 map as the only operation.
931   //
932   // param0[100x10] ---> (map x+1)
933   //
934   // Builds the map function.
935   auto module = CreateNewVerifiedModule();
936   auto map_computation =
937       module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
938   auto inner_last = map_computation->root_instruction();
939 
940   // Creates the main kernel and verifies instruction counts.
941   auto builder = HloComputation::Builder(TestName());
942   auto param0 = builder.AddInstruction(
943       HloInstruction::CreateParameter(0, f32a100x10_, "p"));
944   auto map = builder.AddInstruction(
945       HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
946   module->AddEntryComputation(builder.Build());
947 
948   const std::vector<const HloInstruction*> level0 = GetInstructions(map);
949   EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
950   const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
951   EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";
952 
953   // Assigns buffers and fetches sizes.
954   auto buffers = RunBufferAssignment(module.get());
955   int64_t size0 = ValidateBuffers(level0, *buffers);
956   int64_t size1 = ValidateBuffers(level1, *buffers);
957 
958   // Both algorithms assign the map's buffer before processing the embedded
959   // computation, so we can verify that the buffers aren't shared between them
960   // by checking:
961   EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
962       << "Reuse between main kernel and embedded mapping.";
963 
964   // An input buffer was assigned for the parameter.
965   BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
966 
967   // An output buffer was assigned for the map.
968   BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map);
969   EXPECT_NE(param0_buffer.index(), map_buffer.index());
970 
971   // The final computation node of the map is an add of an f32 param and a
972   // constant.
973   EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
974   const BufferAllocation& inner_add_buffer =
975       GetTopLevelAllocation(*buffers, inner_last);
976   EXPECT_NE(inner_add_buffer.index(), map_buffer.index());
977 
978   // Log size information for inspection.
979   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
980             << " for " << level0.size() + level1.size() << " instructions; "
981             << "total buffer size " << size0 + size1;
982 }
983 
TEST_F(BufferAssignmentTest,CannotReuseInputBufferOfReduce)984 TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
985   // Make sure that the input buffer of a reduce cannot be reused for its
986   // output.  (Reuse is not safe in the general case, as it reshapes and some
987   // out-of-order reductions could overwrite an element before a use.)
988   //
989   // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
990   auto module = CreateNewVerifiedModule();
991   auto reduce_computation =
992       module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));
993 
994   auto builder = HloComputation::Builder(TestName());
995   auto param0 = builder.AddInstruction(
996       HloInstruction::CreateParameter(0, f32a100x10_, "p"));
997   auto exp1 = builder.AddInstruction(
998       HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
999   auto exp2 = builder.AddInstruction(
1000       HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
1001   auto const0 = builder.AddInstruction(
1002       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
1003   auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
1004       /*shape=*/f32vec10_,
1005       /*operand=*/exp2,
1006       /*init_value=*/const0,
1007       /*dimensions_to_reduce=*/{0}, reduce_computation));
1008   auto exp3 = builder.AddInstruction(
1009       HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));
1010 
1011   module->AddEntryComputation(builder.Build());
1012 
1013   auto buffers = RunBufferAssignment(module.get());
1014   const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
1015   ValidateBuffers(instrs, *buffers);
1016 
1017   const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1);
1018   const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2);
1019   const BufferAllocation& reduce_buffer =
1020       GetTopLevelAllocation(*buffers, reduce);
1021 
1022   // The buffer of exp1 is trivially reusable for exp2 - this is just for sanity
1023   // checking.
1024   EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index());
1025 
1026   // The buffer of exp2 cannot be used for reduce, even though it's the only
1027   // operand.
1028   EXPECT_NE(exp2_buffer.index(), reduce_buffer.index());
1029 }
1030 
TEST_F(BufferAssignmentTest,ExampleWhile)1031 TEST_F(BufferAssignmentTest, ExampleWhile) {
1032   // This tests a While loop example from the ir_semantics document.
1033   //
1034   // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
1035   // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
1036   //
1037   // const3[s32] -------\
1038   // const4[f32[4]] --- tuple --- while[condition, body]
1039   //
1040   // Builds the nested condition and body.
1041   auto module = CreateNewVerifiedModule();
1042   auto condition_computation =
1043       module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
1044   auto body_computation =
1045       module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));
1046 
1047   // Creates the main kernel and verifies instruction counts.
1048   auto builder = HloComputation::Builder(TestName());
1049   auto const3 = builder.AddInstruction(
1050       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
1051   auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
1052       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
1053   auto tuple =
1054       builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
1055   auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
1056       t_s32_f32v4_, condition_computation, body_computation, tuple));
1057   module->AddEntryComputation(builder.Build());
1058 
1059   const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
1060   EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
1061   const std::vector<const HloInstruction*> levelc =
1062       GetInstructions(condition_computation->root_instruction());
1063   EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
1064   const std::vector<const HloInstruction*> levelb =
1065       GetInstructions(body_computation->root_instruction());
1066   EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
1067 
1068   // Assigns buffers and fetches sizes.
1069   auto buffers = RunBufferAssignment(module.get());
1070   int64_t size0 = ValidateBuffers(level0, *buffers);
1071   int64_t sizec = ValidateBuffers(levelc, *buffers);
1072   int64_t sizeb = ValidateBuffers(levelb, *buffers);
1073 
1074   // BufferAssignment will assign a single allocation for the following
1075   // instructions: while, while.cond.param, while.body.param, while.body.result.
1076   EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
1077       << "Should be reuse between main kernel and embedded condition.";
1078   EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
1079       << "Should be reuse between embedded condition and body.";
1080   // Expect buffer reuse between main kernel and body computation.
1081   EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
1082       << "Should be reuse between main kernel and embedded body.";
1083 
1084   // The final computation node of the while body is a tuple of s32 and
1085   // f32[4] adds.
1086   HloInstruction* body_root = body_computation->root_instruction();
1087   EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());
1088 
1089   // Check that buffer for each subshape of 'while_op' shares allocation with
1090   // corresponding buffer from while body computation at same index.
1091   ShapeUtil::ForEachSubshape(
1092       while_op->shape(),
1093       [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
1094                                             const ShapeIndex& index) {
1095         auto while_op_allocation = GetAllocation(*buffers, while_op, index);
1096         auto body_root_allocation = GetAllocation(*buffers, body_root, index);
1097         EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
1098       });
1099 
1100   // Log size information for inspection.
1101   LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
1102             << " for " << level0.size() + levelc.size() + levelb.size()
1103             << " instructions; total buffer size " << size0 + sizec + sizeb;
1104 }
1105 
TEST_F(BufferAssignmentTest,ExampleConditional)1106 TEST_F(BufferAssignmentTest, ExampleConditional) {
1107   auto module = CreateNewVerifiedModule();
1108   auto true_computation = module->AddEmbeddedComputation(
1109       BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
1110   auto false_computation = module->AddEmbeddedComputation(
1111       BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));
1112 
1113   auto builder = HloComputation::Builder(TestName());
1114   auto pred = builder.AddInstruction(
1115       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1116   auto const1 = builder.AddInstruction(
1117       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
1118   auto const2 = builder.AddInstruction(
1119       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
1120   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1121       r0f32_, pred, const1, true_computation, const2, false_computation));
1122   module->AddEntryComputation(builder.Build());
1123 
1124   const std::vector<const HloInstruction*> conditional_instrs =
1125       GetInstructions(conditional);
1126   const std::vector<const HloInstruction*> true_instrs =
1127       GetInstructions(true_computation->root_instruction());
1128   const std::vector<const HloInstruction*> false_instrs =
1129       GetInstructions(false_computation->root_instruction());
1130   EXPECT_EQ(4, conditional_instrs.size());
1131   EXPECT_EQ(2, true_instrs.size());
1132   EXPECT_EQ(2, false_instrs.size());
1133 
1134   auto buffers = RunBufferAssignment(module.get());
1135   ValidateBuffers(conditional_instrs, *buffers);
1136   ValidateBuffers(true_instrs, *buffers);
1137   ValidateBuffers(false_instrs, *buffers);
1138 
1139   EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
1140       << "Should be reuse between conditional and true computation.";
1141   EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
1142       << "Should be reuse between conditional and false computation.";
1143   EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
1144       << "Should be reuse between true and false computations.";
1145 
1146   const BufferAllocation& conditional_buffer =
1147       GetTopLevelAllocation(*buffers, conditional);
1148   const BufferAllocation& true_buffer =
1149       GetTopLevelAllocation(*buffers, true_computation->root_instruction());
1150   const BufferAllocation& false_buffer =
1151       GetTopLevelAllocation(*buffers, false_computation->root_instruction());
1152   EXPECT_EQ(conditional_buffer.size(), true_buffer.size());
1153   EXPECT_EQ(conditional_buffer.size(), false_buffer.size());
1154 }
1155 
TEST_F(BufferAssignmentTest,UnaryOpReuseChain)1156 TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
1157   // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
1158   auto builder = HloComputation::Builder(TestName());
1159   auto param0 = builder.AddInstruction(
1160       HloInstruction::CreateParameter(0, f32vec100_, "p"));
1161   auto exp1 = builder.AddInstruction(
1162       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
1163   auto tanh = builder.AddInstruction(
1164       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
1165   auto exp2 = builder.AddInstruction(
1166       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
1167   auto neg = builder.AddInstruction(
1168       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));
1169 
1170   auto module = CreateNewVerifiedModule();
1171   module->AddEntryComputation(builder.Build());
1172   auto assignment = RunBufferAssignment(module.get());
1173 
1174   // tanh and exp2 can reuse exp1's buffer
1175   EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
1176   auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1);
1177   EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
1178   EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
1179   EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
1180 }
1181 
TEST_F(BufferAssignmentTest,ReuseNonOperandBuffer)1182 TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
1183   // This computation is a chain of operations which decreases in buffer size
1184   // (via slice) then increases in size (via broadcast):
1185   //
1186   // param ---> (negate) ---> (slice) ---> (broadcast)
1187   //
1188   // The negate should share a buffer with broadcast.
1189   auto builder = HloComputation::Builder(TestName());
1190   auto param0 = builder.AddInstruction(
1191       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1192   auto negate = builder.AddInstruction(
1193       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1194   auto slice = builder.AddInstruction(
1195       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1196   auto broadcast = builder.AddInstruction(
1197       HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
1198 
1199   auto module = CreateNewVerifiedModule();
1200   module->AddEntryComputation(builder.Build());
1201   auto assignment = RunBufferAssignment(module.get());
1202 
1203   // negate and broadcast should share a buffer.
1204   EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
1205   auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
1206   EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
1207 
1208   // Slice should have its own buffer.
1209   EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
1210 }
1211 
TEST_F(BufferAssignmentTest,NoReuseLiveBuffer)1212 TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
1213   // This computation is identical to that in ReuseNonOperandBuffer, but the
1214   // negate value is live until the end of the computation (due to it being an
1215   // operand of the output tuple) preventing reuse.
1216   //
1217   // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
1218   //                  \-----------------------------------/
1219   //
1220   // The negate should not share a buffer with broadcast.
1221   auto builder = HloComputation::Builder(TestName());
1222   auto param0 = builder.AddInstruction(
1223       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1224   auto negate = builder.AddInstruction(
1225       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1226   auto slice = builder.AddInstruction(
1227       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1228   auto broadcast = builder.AddInstruction(
1229       HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
1230   builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
1231 
1232   auto module = CreateNewVerifiedModule();
1233   module->AddEntryComputation(builder.Build());
1234   auto assignment = RunBufferAssignment(module.get());
1235 
1236   // The instructions should not share buffers.
1237   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1238             GetTopLevelAllocation(*assignment, negate));
1239   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1240             GetTopLevelAllocation(*assignment, slice));
1241   EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
1242             GetTopLevelAllocation(*assignment, slice));
1243 }
1244 
TEST_F(BufferAssignmentTest,NoReuseAliasedBuffer)1245 TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
1246   // This computation is identical to that in ReuseNonOperandBuffer, but the
1247   // negate value is placed into a tuple which lives to the end of the
1248   // computation. This extends the live range of negate's buffer preventing
1249   // reuse due to buffer aliasing.
1250   //
1251   // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
1252   //                              \-----------------------------------/
1253   //
1254   // The negate should not share a buffer with broadcast.
1255   auto builder = HloComputation::Builder(TestName());
1256   auto param0 = builder.AddInstruction(
1257       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1258   auto negate = builder.AddInstruction(
1259       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1260   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate}));
1261   auto tuple_element = builder.AddInstruction(
1262       HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
1263   auto slice = builder.AddInstruction(
1264       HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
1265   auto broadcast = builder.AddInstruction(
1266       HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
1267   builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
1268 
1269   auto module = CreateNewVerifiedModule();
1270   module->AddEntryComputation(builder.Build());
1271   auto assignment = RunBufferAssignment(module.get());
1272 
1273   // The instructions should not share buffers.
1274   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1275             GetTopLevelAllocation(*assignment, negate));
1276   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1277             GetTopLevelAllocation(*assignment, slice));
1278   EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
1279             GetTopLevelAllocation(*assignment, slice));
1280 }
1281 
TEST_F(BufferAssignmentTest,DoNotReuseOversizedOutputBuffer)1282 TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
1283   // This computation is very similar to ReuseNonOperandBuffer except the
1284   // broadcast has a smaller output than the negate. This should block reuse of
1285   // negate's buffer by broadcast because the output buffer(s) of a computation
1286   // should be exactly sized for the value.
1287   //
1288   // param ---> (negate) ---> (slice) ---> (broadcast)
1289   //
1290   // Neither negate nor slice may share a buffer with broadcast.
1291   auto builder = HloComputation::Builder(TestName());
1292   auto param0 = builder.AddInstruction(
1293       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1294   // Negate output is 100 elements.
1295   auto negate = builder.AddInstruction(
1296       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1297   // Slice output is 10 elements.
1298   auto slice = builder.AddInstruction(
1299       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1300   // Broadcast output is 40 elements.
1301   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1302       ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
1303 
1304   auto module = CreateNewVerifiedModule();
1305   module->AddEntryComputation(builder.Build());
1306   auto assignment = RunBufferAssignment(module.get());
1307 
1308   // The broadcast output buffer cannot be shared.
1309   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1310             GetTopLevelAllocation(*assignment, negate));
1311   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1312             GetTopLevelAllocation(*assignment, slice));
1313 }
1314 
TEST_F(BufferAssignmentTest,ReuseOutputBufferIfExactlySized)1315 TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
1316   // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
1317   // output is exactly the same size as the negate (rather than being
1318   // smaller). This enables reuse of negate's buffer by the broadcast because
1319   // the output buffer will be sized exactly to its value.
1320   //
1321   // param ---> (negate) ---> (slice) ---> (broadcast)
1322   //
1323   // The negate should *not* share a buffer with broadcast.
1324   auto builder = HloComputation::Builder(TestName());
1325   auto param0 = builder.AddInstruction(
1326       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1327   // Negate output is 100 elements.
1328   auto negate = builder.AddInstruction(
1329       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1330   auto slice = builder.AddInstruction(
1331       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1332   // Broadcast output is 40 elements.
1333   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1334       ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
1335 
1336   auto module = CreateNewVerifiedModule();
1337   module->AddEntryComputation(builder.Build());
1338   auto assignment = RunBufferAssignment(module.get());
1339 
1340   // negate and broadcast should share a buffer.
1341   EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
1342   auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
1343   EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
1344 
1345   // Slice should have its own buffer.
1346   EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
1347 }
1348 
TEST_F(BufferAssignmentTest,DoNotReuseOversizedOutputBufferInTuple)1349 TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
1350   // This computation is very similar to ReuseNonOperandBuffer except the
1351   // broadcast has a smaller output than the negate, and the broadcast is
1352   // contained in the computation output as a tuple element. This should block
1353   // reuse of the negate's buffer by the broadcast because the output buffer(s)
1354   // of a computation should be exactly sized for the value. This includes those
1355   // buffers aliased in the output (eg, contained as tuple elements).
1356   //
1357   // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
1358   //
1359   // Neither negate nor slice may share a buffer with broadcast.
1360   auto builder = HloComputation::Builder(TestName());
1361   auto param0 = builder.AddInstruction(
1362       HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1363   // Negate output is 100 elements.
1364   auto negate = builder.AddInstruction(
1365       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1366   // Slice output is 10 elements.
1367   auto slice = builder.AddInstruction(
1368       HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1369   // Broadcast output is 40 elements.
1370   auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1371       ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
1372   builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));
1373 
1374   auto module = CreateNewVerifiedModule();
1375   module->AddEntryComputation(builder.Build());
1376   auto assignment = RunBufferAssignment(module.get());
1377 
1378   // The broadcast output buffer cannot be shared.
1379   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1380             GetTopLevelAllocation(*assignment, negate));
1381   EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1382             GetTopLevelAllocation(*assignment, slice));
1383 }
1384 
TEST_F(BufferAssignmentTest,EmbeddedComputationBuffers)1385 TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
1386   // Verify that buffers for embedded computations are properly marked as
1387   // thread-local and that embedded parameters are not marked as
1388   // is_entry_computation_parameter.
1389   auto module = CreateNewVerifiedModule();
1390   auto vec_shape = ShapeUtil::MakeShape(F32, {42});
1391   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1392 
1393   // Create a scalar computation to use in a map.
1394   auto map_builder = HloComputation::Builder(TestName() + "_map");
1395   auto map_param = map_builder.AddInstruction(
1396       HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1397   auto map_root = map_builder.AddInstruction(
1398       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1399   auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1400 
1401   // Create a vector computation to use in a kCall.
1402   auto call_builder = HloComputation::Builder(TestName() + "_call");
1403   auto call_param = call_builder.AddInstruction(
1404       HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
1405   auto call_root = call_builder.AddInstruction(
1406       HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
1407   auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
1408 
1409   // Create entry computation which kCalls call_computation and then calls map
1410   // with map_computation on the result.
1411   auto builder = HloComputation::Builder(TestName());
1412   auto param = builder.AddInstruction(
1413       HloInstruction::CreateParameter(0, vec_shape, "param"));
1414   auto call = builder.AddInstruction(
1415       HloInstruction::CreateCall(vec_shape, {param}, call_computation));
1416   auto map = builder.AddInstruction(
1417       HloInstruction::CreateMap(vec_shape, {call}, map_computation));
1418   module->AddEntryComputation(builder.Build());
1419 
1420   auto assignment = RunBufferAssignment(module.get());
1421 
1422   // Allocations for the map computation should be thread-local and not
1423   // live-out.
1424   auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1425   EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1426   EXPECT_FALSE(map_param_alloc.maybe_live_out());
1427   EXPECT_TRUE(map_param_alloc.is_thread_local());
1428 
1429   auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1430   EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1431   EXPECT_FALSE(map_root_alloc.maybe_live_out());
1432   EXPECT_TRUE(map_root_alloc.is_thread_local());
1433 
1434   // Allocations for the call computation should not be thread-local.
1435   auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
1436   EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
1437   EXPECT_FALSE(call_param_alloc.maybe_live_out());
1438   EXPECT_FALSE(call_param_alloc.is_thread_local());
1439 
1440   auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
1441   EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
1442   EXPECT_FALSE(call_root_alloc.is_thread_local());
1443 
1444   // Entry computation allocations can be marked liveout and
1445   // is_entry_computation_parameter.
1446   auto& param_alloc = GetTopLevelAllocation(*assignment, param);
1447   EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
1448   EXPECT_FALSE(param_alloc.maybe_live_out());
1449   EXPECT_FALSE(param_alloc.is_thread_local());
1450 
1451   auto& map_alloc = GetTopLevelAllocation(*assignment, map);
1452   EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
1453   EXPECT_TRUE(map_alloc.maybe_live_out());
1454   EXPECT_FALSE(map_alloc.is_thread_local());
1455 }
1456 
TEST_F(BufferAssignmentTest,CustomCallEmbeddedComputationBuffers)1457 TEST_F(BufferAssignmentTest, CustomCallEmbeddedComputationBuffers) {
1458   // Verify that buffers for embedded computations in a custom call are properly
1459   // marked as thread-local.
1460   auto module = CreateNewVerifiedModule();
1461   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1462 
1463   // Create a scalar computation to use in a map.
1464   auto map_builder = HloComputation::Builder(TestName() + "_map");
1465   auto map_param = map_builder.AddInstruction(
1466       HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1467   auto map_root = map_builder.AddInstruction(
1468       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1469   auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1470 
1471   // Create entry computation with a custom call on map_computation.
1472   auto builder = HloComputation::Builder(TestName());
1473   auto param = builder.AddInstruction(
1474       HloInstruction::CreateParameter(0, scalar_shape, "param"));
1475   builder.AddInstruction(HloInstruction::CreateCustomCall(
1476       scalar_shape, {param}, map_computation, "call_name"));
1477   module->AddEntryComputation(builder.Build());
1478 
1479   auto assignment = RunBufferAssignment(module.get());
1480 
1481   // Allocations for the map computation should be thread-local and not
1482   // live-out.
1483   auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1484   EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1485   EXPECT_FALSE(map_param_alloc.maybe_live_out());
1486   EXPECT_TRUE(map_param_alloc.is_thread_local());
1487 
1488   auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1489   EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1490   EXPECT_FALSE(map_root_alloc.maybe_live_out());
1491   EXPECT_TRUE(map_root_alloc.is_thread_local());
1492 }
1493 
TEST_F(BufferAssignmentTest,TupleParameterAsOutput)1494 TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
1495   // Test a computation that returns a tuple parameter.
1496   auto builder = HloComputation::Builder(TestName());
1497   auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
1498       0,
1499       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1500                                  ShapeUtil::MakeShape(F32, {}),
1501                                  ShapeUtil::MakeShape(S32, {42})}),
1502       "param0"));
1503 
1504   auto module = CreateNewVerifiedModule();
1505   module->AddEntryComputation(builder.Build());
1506   auto assignment = RunBufferAssignment(module.get());
1507 
1508   // There should be four allocations: one for vector of pointers, and one for
1509   // each tuple element.
1510   EXPECT_EQ(4, assignment->Allocations().size());
1511 
1512   // Verify each buffer allocation is marked as an entry computation parameter
1513   // and is liveout.
1514   ShapeUtil::ForEachSubshape(
1515       tuple_param->shape(),
1516       [this, &assignment, tuple_param](const Shape& /*subshape*/,
1517                                        const ShapeIndex& index) {
1518         auto allocation = GetAllocation(*assignment, tuple_param, index);
1519         EXPECT_TRUE(allocation.is_entry_computation_parameter());
1520         EXPECT_EQ(0, allocation.parameter_number());
1521         EXPECT_TRUE(allocation.maybe_live_out());
1522       });
1523 }
1524 
TEST_F(BufferAssignmentTest,ElementOfNestedTupleParameterAsOutput)1525 TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
1526   // Test a computation which returns a GetElementTuple of a nested tuple
1527   // parameter.
1528   auto builder = HloComputation::Builder(TestName());
1529   auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
1530       0,
1531       ShapeUtil::MakeTupleShape(
1532           {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1533            ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
1534                                       ShapeUtil::MakeShape(S32, {101})})}),
1535       "param0"));
1536   auto tuple_element =
1537       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1538           ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));
1539 
1540   auto module = CreateNewVerifiedModule();
1541   module->AddEntryComputation(builder.Build());
1542   auto assignment = RunBufferAssignment(module.get());
1543 
1544   // Only some of the elements of the input param are liveout.
1545   EXPECT_FALSE(
1546       GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
1547   // Tuple element at index={1} is live out because GetTupleElement({1})
1548   // forwards a pointer to this allocation (instead of defining its own buffer).
1549   EXPECT_TRUE(
1550       GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
1551   EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
1552                   .maybe_live_out());
1553   EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
1554                   .maybe_live_out());
1555 
1556   // The GetTupleElement output is liveout.
1557   EXPECT_TRUE(
1558       GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());
1559 
1560   // Verify that the GetTupleElement allocations of its elements match the
1561   // corresponding tuple parameter allocations because they alias.
1562   EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
1563             GetAllocation(*assignment, tuple_element, /*index=*/{0}));
1564   EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
1565             GetAllocation(*assignment, tuple_element, /*index=*/{1}));
1566 
1567   // GetTupleElement forwards a pointer to its underlying buffer, so verify
1568   // that it has the same allocation than the corresponding parameter element.
1569   EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
1570             GetTopLevelAllocation(*assignment, tuple_element));
1571 }
1572 
1573 // TODO(b/32248867): Enable when buffer assignment gives allocations to
1574 // constants.
TEST_F(BufferAssignmentTest,TupleConstantAsOutput)1575 TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
1576   // Test that a tuple constant which is forwarded to the computation output
1577   // is properly handled.
1578   auto builder = HloComputation::Builder(TestName());
1579   Literal elements[] = {LiteralUtil::CreateR0<int64_t>(0),
1580                         LiteralUtil::CreateR0<int64_t>(1)};
1581   builder.AddInstruction(HloInstruction::CreateConstant(
1582       LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
1583 
1584   auto module = CreateNewVerifiedModule();
1585   module->AddEntryComputation(builder.Build());
1586   auto assignment = RunBufferAssignment(module.get());
1587 
1588   EXPECT_EQ(3, assignment->Allocations().size());
1589 }
1590 
TEST_F(BufferAssignmentTest,TupleCustomCallAsOutput)1591 TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
1592   // Test a computation which returns a tuple custom call value.
1593   auto builder = HloComputation::Builder(TestName());
1594   auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall(
1595       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1596                                  ShapeUtil::MakeShape(S32, {101})}),
1597       /*operands=*/{}, /*custom_call_target=*/"foo_function"));
1598   auto module = CreateNewVerifiedModule();
1599   module->AddEntryComputation(builder.Build());
1600   auto assignment = RunBufferAssignment(module.get());
1601 
1602   EXPECT_EQ(3, assignment->Allocations().size());
1603   EXPECT_TRUE(
1604       GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
1605   EXPECT_TRUE(
1606       GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
1607   EXPECT_TRUE(
1608       GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
1609 }
1610 
TEST_F(BufferAssignmentTest,CustomCallAliasedBuffer)1611 TEST_F(BufferAssignmentTest, CustomCallAliasedBuffer) {
1612   // Test a computation with custom call aliasing.
1613   const char* const kModuleString = R"(
1614     HloModule xla_computation_f
1615     ENTRY xla_computation_f {
1616       parameter.1 = f32[2,3,4,5] parameter(0)
1617       parameter.2 = f32[2,3,4,5] parameter(1)
1618       add = f32[2,3,4,5] add(parameter.1, parameter.2)
1619       ROOT custom-call = f32[2,3,4,5] custom-call(add, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
1620     }
1621   )";
1622 
1623   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
1624                           ParseAndReturnUnverifiedModule(kModuleString));
1625   std::unique_ptr<BufferAssignment> assignment =
1626       RunBufferAssignment(module.get());
1627   HloInstruction* custom_call = module->entry_computation()->root_instruction();
1628   EXPECT_TRUE(
1629       assignment->SharesTopLevelSlice(custom_call, custom_call->operand(0)));
1630 }
1631 
TEST_F(BufferAssignmentTest,TupleCallAsOutput)1632 TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
1633   // Test a computation which returns a tuple call value.
1634   auto module = CreateNewVerifiedModule();
1635   auto elem_shape = f32vec4_;
1636   auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1637 
1638   auto sub_builder = HloComputation::Builder(TestName() + "_sub");
1639   auto sub_param = sub_builder.AddInstruction(
1640       HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
1641   auto sub_tuple =
1642       sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
1643   auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());
1644 
1645   auto builder = HloComputation::Builder(TestName());
1646   auto param = builder.AddInstruction(
1647       HloInstruction::CreateParameter(0, elem_shape, "param"));
1648   auto call = builder.AddInstruction(
1649       HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
1650   module->AddEntryComputation(builder.Build());
1651 
1652   auto assignment = RunBufferAssignment(module.get());
1653 
1654   EXPECT_EQ(2, assignment->Allocations().size());
1655   // Buffers for call are colocated with the sub-computation.
1656   EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
1657             GetAllocation(*assignment, sub_tuple, /*index=*/{}));
1658   EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
1659             GetAllocation(*assignment, sub_param, /*index=*/{}));
1660 
1661   // The parameter isn't aliased with the result tuple, but it is aliased with
1662   // the call operand.
1663   EXPECT_NE(GetTopLevelAllocation(*assignment, param),
1664             GetTopLevelAllocation(*assignment, sub_tuple));
1665   EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
1666             GetTopLevelAllocation(*assignment, sub_param));
1667 }
1668 
TEST_F(BufferAssignmentTest,TupleChainedCallAsOutput)1669 TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
1670   // Test a chain of calls with tuple output. The chain looks like:
1671   // A: call(B, tuple(param))
1672   // B: call(C, param)
1673   // C: call(D, param)
1674   // D: param
1675   auto module = CreateNewVerifiedModule();
1676   auto elem_shape = f32vec4_;
1677   auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1678 
1679   auto d_builder = HloComputation::Builder(TestName() + "_d");
1680   auto d_param = d_builder.AddInstruction(
1681       HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
1682   auto d_computation = d_builder.Build();
1683 
1684   auto c_builder = HloComputation::Builder(TestName() + "_c");
1685   auto c_param = c_builder.AddInstruction(
1686       HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
1687   auto c_call = c_builder.AddInstruction(
1688       HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
1689   auto c_computation = c_builder.Build();
1690 
1691   auto b_builder = HloComputation::Builder(TestName() + "_b");
1692   auto b_param = b_builder.AddInstruction(
1693       HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
1694   auto b_call = b_builder.AddInstruction(
1695       HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
1696   auto b_computation = b_builder.Build();
1697 
1698   auto a_builder = HloComputation::Builder(TestName());
1699   auto a_param = a_builder.AddInstruction(
1700       HloInstruction::CreateParameter(0, elem_shape, "param"));
1701   auto a_tuple =
1702       a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
1703   auto a_call = a_builder.AddInstruction(
1704       HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
1705   auto a_computation = a_builder.Build();
1706 
1707   // Add the computations in an order that doesn't match the dependency
1708   // post-order, to shake out more possible bugs.
1709   module->AddEmbeddedComputation(std::move(d_computation));
1710   module->AddEmbeddedComputation(std::move(c_computation));
1711   module->AddEntryComputation(std::move(a_computation));
1712   module->AddEmbeddedComputation(std::move(b_computation));
1713 
1714   auto assignment = RunBufferAssignment(module.get());
1715 
1716   // Buffers for call are colocated with the sub-computations.
1717   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
1718             GetAllocation(*assignment, b_call, /*index=*/{}));
1719   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
1720             GetAllocation(*assignment, c_call, /*index=*/{}));
1721   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
1722             GetAllocation(*assignment, d_param, /*index=*/{}));
1723   EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
1724             GetAllocation(*assignment, b_call, /*index=*/{0}));
1725   EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
1726             GetAllocation(*assignment, c_call, /*index=*/{0}));
1727   EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
1728             GetAllocation(*assignment, d_param, /*index=*/{0}));
1729 
1730   EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
1731   EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
1732   EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
1733 
1734   EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
1735             GetAllocation(*assignment, c_param, /*index=*/{0}));
1736   EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
1737             GetAllocation(*assignment, d_param, /*index=*/{0}));
1738 }
1739 
TEST_F(BufferAssignmentTest,BitcastAsOutput)1740 TEST_F(BufferAssignmentTest, BitcastAsOutput) {
1741   // Test a computation which returns a bitcast value.
1742   auto builder = HloComputation::Builder(TestName());
1743   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
1744       0, ShapeUtil::MakeShape(F32, {42}), "param"));
1745   auto bitcast = builder.AddInstruction(
1746       HloInstruction::CreateBitcast(param->shape(), param));
1747 
1748   auto module = CreateNewVerifiedModule();
1749   module->AddEntryComputation(builder.Build());
1750   auto assignment = RunBufferAssignment(module.get());
1751 
1752   // Bitcast should get the same allocation as the param.
1753   EXPECT_EQ(1, assignment->Allocations().size());
1754   EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
1755             GetTopLevelAllocation(*assignment, bitcast));
1756 }
1757 
1758 // TODO(b/34669761): Remove this test when buffers are allowed to share
1759 // allocations.
TEST_F(BufferAssignmentTest,TupleBufferNotReused)1760 TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
1761   // Test a computation that returns a tuple parameter.
1762   auto builder = HloComputation::Builder(TestName());
1763   auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1764   auto param = builder.AddInstruction(
1765       HloInstruction::CreateParameter(0, scalar_shape, "param0"));
1766   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({param}));
1767   auto tuple_element = builder.AddInstruction(
1768       HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
1769   auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
1770       scalar_shape, HloOpcode::kCopy, tuple_element));
1771 
1772   auto module = CreateNewVerifiedModule();
1773   module->AddEntryComputation(builder.Build());
1774   auto assignment = RunBufferAssignment(module.get());
1775 
1776   // There should be no buffer reuse. The copy should not reuse the tuple
1777   // buffer.
1778   EXPECT_EQ(3, assignment->Allocations().size());
1779   EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
1780             GetTopLevelAllocation(*assignment, copy));
1781 }
1782 
TEST_F(BufferAssignmentTest,OneTempAllocation)1783 TEST_F(BufferAssignmentTest, OneTempAllocation) {
1784   // Test a computation that requires multiple temp buffers, and ensure they
1785   // are combined into a single allocation.
1786   auto builder = HloComputation::Builder(TestName());
1787   Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
1788   Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
1789   Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
1790   Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
1791   Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});
1792 
1793   // There should be separate temp buffers for dot_ab and dot_bc.
1794   auto param_a = builder.AddInstruction(
1795       HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
1796   auto param_b = builder.AddInstruction(
1797       HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
1798   auto param_c = builder.AddInstruction(
1799       HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
1800   DotDimensionNumbers dot_dnums;
1801   dot_dnums.add_lhs_contracting_dimensions(1);
1802   dot_dnums.add_rhs_contracting_dimensions(0);
1803   PrecisionConfig precision_config;
1804   precision_config.mutable_operand_precision()->Resize(
1805       2, PrecisionConfig::DEFAULT);
1806   auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
1807       shape_2x4, param_a, param_b, dot_dnums, precision_config));
1808   auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
1809       shape_3x4, param_b, param_c, dot_dnums, precision_config));
1810   builder.AddInstruction(
1811       HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
1812 
1813   // Run buffer assignment with alignment=1.
1814   auto module = CreateNewVerifiedModule();
1815   module->AddEntryComputation(builder.Build());
1816   auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
1817 
1818   // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
1819   EXPECT_EQ(5, assignment->Allocations().size());
1820 
1821   // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
1822   // and each occupies different slices of that allocation.
1823   BufferAllocation::Slice slice_ab =
1824       assignment->GetUniqueTopLevelSlice(dot_ab).value();
1825   BufferAllocation::Slice slice_bc =
1826       assignment->GetUniqueTopLevelSlice(dot_bc).value();
1827   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1828   EXPECT_NE(slice_ab, slice_bc);
1829   EXPECT_EQ(32, slice_ab.size());
1830   EXPECT_EQ(48, slice_bc.size());
1831   EXPECT_EQ(80, slice_ab.allocation()->size());
1832   EXPECT_EQ(80, slice_bc.allocation()->size());
1833 
1834   // Re-run buffer assignment with alignment=64.
1835   assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
1836   EXPECT_EQ(5, assignment->Allocations().size());
1837   slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).value();
1838   slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).value();
1839   EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1840   EXPECT_NE(slice_ab, slice_bc);
1841   EXPECT_EQ(32, slice_ab.size());
1842   EXPECT_EQ(48, slice_bc.size());
1843   // Ensure the offsets and allocation size account for the alignment, without
1844   // assuming which buffer gets assigned first.
1845   if (slice_ab.offset() == 0) {
1846     EXPECT_EQ(64, slice_bc.offset());
1847     EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
1848     EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
1849   } else {
1850     EXPECT_EQ(64, slice_ab.offset());
1851     EXPECT_EQ(0, slice_bc.offset());
1852     EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
1853     EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
1854   }
1855 }
1856 
TEST_F(BufferAssignmentTest,TrivialPeakBuffers)1857 TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
1858   // paramscalar -(bc)- (mul) -- (add) -- (sub)
1859   //                     /        /        /
1860   // param0[100] -------/        /        /
1861   //                            /        /
1862   // param1[100] --------------/--------/
1863   auto builder = HloComputation::Builder(TestName());
1864   auto paramscalar =
1865       builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
1866   auto broadcast = builder.AddInstruction(
1867       HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
1868   auto param0 = builder.AddInstruction(
1869       HloInstruction::CreateParameter(1, f32vec100_, "p1"));
1870   auto param1 = builder.AddInstruction(
1871       HloInstruction::CreateParameter(2, f32vec100_, "p2"));
1872   auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
1873       f32vec100_, HloOpcode::kMultiply, broadcast, param0));
1874   auto add = builder.AddInstruction(
1875       HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
1876   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
1877       f32vec100_, HloOpcode::kSubtract, add, param1));
1878   auto module = CreateNewVerifiedModule();
1879   module->AddEntryComputation(builder.Build());
1880 
1881   auto buffers = RunBufferAssignment(module.get());
1882 
1883   const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
1884   const std::vector<const HloValue*>& peak_buffers =
1885       mul_buffer.PeakMemoryLogicalBuffers();
1886   ASSERT_EQ(peak_buffers.size(), 1);
1887   EXPECT_EQ(peak_buffers[0]->instruction(), sub);
1888 }
1889 
TEST_F(BufferAssignmentTest,PeakBuffers)1890 TEST_F(BufferAssignmentTest, PeakBuffers) {
1891   // Compute the peak liveness buffers of the following sequence:
1892   //
1893   //   %param = ...
1894   //   %log = log(%param)
1895   //   %rev = reverse(%log)
1896   //   %neg = neg(%param)
1897   //   %concat = concat(%rev, %neg)
1898   //   ROOT %root = slice(concat)
1899   //
1900   // In the temporary block, the set of live buffers at peak memory use should
1901   // be {%rev, %neg, %concat}. This occurs right at the concat itself.
1902   auto builder = HloComputation::Builder(TestName());
1903   auto param = builder.AddInstruction(
1904       HloInstruction::CreateParameter(0, f32vec100_, "p"));
1905   auto log = builder.AddInstruction(
1906       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
1907   auto rev = builder.AddInstruction(
1908       HloInstruction::CreateReverse(f32vec100_, log, {0}));
1909   auto neg = builder.AddInstruction(
1910       HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
1911   const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
1912   auto concat = builder.AddInstruction(
1913       HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
1914   // Make the root tiny so no interior nodes can share its buffer.
1915   auto root = builder.AddInstruction(HloInstruction::CreateSlice(
1916 
1917       ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1}));
1918 
1919   auto module = CreateNewVerifiedModule();
1920   module->AddEntryComputation(builder.Build());
1921 
1922   auto buffers = RunBufferAssignmentWithInstructionSequence(
1923       module.get(), {param, log, rev, neg, concat, root});
1924 
1925   // The temporary buffer should hold the 4 interior instructions.
1926   const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat);
1927   EXPECT_FALSE(buffer.IsInputOrOutput());
1928   EXPECT_TRUE(buffer.IsPreallocatedTempBuffer());
1929   ASSERT_EQ(buffer.assigned_buffers().size(), 4);
1930 
1931   const std::vector<const HloValue*>& peak_buffers =
1932       buffer.PeakMemoryLogicalBuffers();
1933 
1934   // The peak live set should be concat and its inputs.
1935   ASSERT_EQ(peak_buffers.size(), 3);
1936   std::vector<const HloInstruction*> peak_instructions;
1937   for (const HloValue* logical_buffer : peak_buffers) {
1938     peak_instructions.push_back(logical_buffer->instruction());
1939   }
1940   EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
1941 }
1942 
TEST_F(BufferAssignmentTest,AliasedBuffersShouldntCoexistInPeakBuffers)1943 TEST_F(BufferAssignmentTest, AliasedBuffersShouldntCoexistInPeakBuffers) {
1944   std::string hlo_text = R"(
1945 HloModule test_module, is_scheduled=true
1946 
1947 cond {
1948   param = (s32[], s32[]) parameter(0)
1949   ROOT constant = pred[] constant(true)
1950 }
1951 
1952 body {
1953   param.0 = (s32[], s32[]) parameter(0)
1954   gte = s32[] get-tuple-element(param.0), index=0
1955   add = s32[] add(gte, gte)
1956   ROOT tuple = (s32[], s32[]) tuple(add, add)
1957 }
1958 
1959 ENTRY test_module {
1960   param.3 = s32[] parameter(0)
1961   copy = s32[] copy(param.3)
1962   tuple = (s32[], s32[]) tuple(copy, copy)
1963   while = (s32[], s32[]) while(tuple), condition=cond, body=body
1964   gte = s32[] get-tuple-element(while), index=0
1965   ROOT negate = s32[] negate(gte)
1966 })";
1967 
1968   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
1969   auto assignment = RunBufferAssignmentWithSequentialOrdering(module.get());
1970   const BufferAllocation& buffer =
1971       GetTopLevelAllocation(*assignment, FindInstruction(module.get(), "copy"));
1972   const std::vector<const HloValue*>& peak_buffers =
1973       buffer.PeakMemoryLogicalBuffers();
1974 
1975   // Since the same aliased buffer (copy) is passed into while, we expect the
1976   // number of peak array buffers to be one.
1977   int num_peak_buffers = 0;
1978   for (const HloValue* peak_buffer : peak_buffers) {
1979     if (peak_buffer->shape().IsArray()) {
1980       ++num_peak_buffers;
1981     }
1982   }
1983   EXPECT_EQ(num_peak_buffers, 1);
1984 }
1985 
TEST_F(BufferAssignmentTest,InPlaceBuffer)1986 TEST_F(BufferAssignmentTest, InPlaceBuffer) {
1987   const char* hlo_text = R"(
1988 HloModule Module
1989 
1990 ENTRY main {
1991   state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
1992   constant.1 = f32[] constant(0)
1993   broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
1994   get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
1995   get-tuple-element.3 = s32[] get-tuple-element(state), index=0
1996   constant.2 = s32[] constant(128)
1997   add.5 = s32[] add(get-tuple-element.3, constant.2)
1998   constant.3 = s32[] constant(0)
1999   dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
2000   dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
2001   ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
2002 }
2003 )";
2004 
2005   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
2006   HloInstruction* parameter =
2007       m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
2008   HloInstruction* dus1 =
2009       m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
2010   HloInstruction* dus2 =
2011       m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
2012 
2013   auto buffers = RunBufferAssignment(m.get());
2014 
2015   {
2016     const BufferAllocation& parameter_alloc =
2017         GetTopLevelAllocation(*buffers, parameter);
2018 
2019     const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
2020     EXPECT_EQ(parameter_alloc, dus1_alloc);
2021     const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
2022     EXPECT_EQ(parameter_alloc, dus2_alloc);
2023   }
2024 }
2025 
TEST_F(BufferAssignmentTest,ConstantBuffersAreNotReused)2026 TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
2027   const char* hlo_text = R"(
2028 HloModule Module
2029 
2030 True {
2031   ROOT x.0.1 = f32[] parameter(0)
2032 }
2033 
2034 False {
2035   x.0.0 = f32[] parameter(0)
2036   ROOT copy.1 = f32[] copy(x.0.0)
2037 }
2038 
2039 ENTRY main {
2040   pred.1.0 = pred[] parameter(0)
2041   constant.1.1 = f32[] constant(56)
2042   copy.2 = f32[] copy(constant.1.1)
2043   constant.1.2 = f32[] constant(12)
2044   ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
2045       true_computation=True, false_computation=False
2046 }
2047 )";
2048 
2049   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
2050   HloInstruction* constant_1 =
2051       m->entry_computation()->GetInstructionWithName("constant.1.1");
2052   HloInstruction* constant_2 =
2053       m->entry_computation()->GetInstructionWithName("constant.1.2");
2054 
2055   auto buffers = RunBufferAssignment(m.get());
2056 
2057   {
2058     const BufferAllocation& allocation_for_const_1 =
2059         GetTopLevelAllocation(*buffers, constant_1);
2060     EXPECT_TRUE(allocation_for_const_1.is_constant());
2061     for (const auto& buffer_offset_pair :
2062          allocation_for_const_1.assigned_buffers()) {
2063       EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2064                 HloOpcode::kCopy);
2065       EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2066                 HloOpcode::kConditional);
2067     }
2068   }
2069 
2070   {
2071     const BufferAllocation& allocation_for_const_2 =
2072         GetTopLevelAllocation(*buffers, constant_2);
2073     EXPECT_TRUE(allocation_for_const_2.is_constant());
2074     for (const auto& buffer_offset_pair :
2075          allocation_for_const_2.assigned_buffers()) {
2076       EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2077                 HloOpcode::kCopy);
2078       EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2079                 HloOpcode::kConditional);
2080     }
2081   }
2082 }
2083 
2084 class WhileBufferAssignmentTest : public HloTestBase {
2085  protected:
BuildWhileConditionComputation(const std::string & name)2086   std::unique_ptr<HloComputation> BuildWhileConditionComputation(
2087       const std::string& name) {
2088     auto builder = HloComputation::Builder(name);
2089     builder.AddInstruction(
2090         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2091     auto zero = builder.AddInstruction(
2092         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
2093     auto ten = builder.AddInstruction(
2094         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
2095     builder.AddInstruction(HloInstruction::CreateCompare(
2096         ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
2097     return builder.Build();
2098   }
2099 
BuildWhileBodyComputation(const std::string & name)2100   std::unique_ptr<HloComputation> BuildWhileBodyComputation(
2101       const std::string& name) {
2102     auto builder = HloComputation::Builder(name);
2103     auto loop_state = builder.AddInstruction(
2104         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2105     auto input = builder.AddInstruction(
2106         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
2107     auto weights = builder.AddInstruction(
2108         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
2109     auto output = builder.AddInstruction(HloInstruction::CreateBinary(
2110         data_shape_, HloOpcode::kMultiply, input, weights));
2111     builder.AddInstruction(
2112         HloInstruction::CreateTuple({input, weights, output}));
2113     return builder.Build();
2114   }
2115 
RunBufferAssignment(HloModule * module,int64_t alignment=1)2116   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
2117                                                         int64_t alignment = 1) {
2118     HloSchedule schedule = ScheduleModule(module, ByteSizeOf).value();
2119     return BufferAssigner::Run(
2120                module, std::make_unique<SequentialHloOrdering>(schedule),
2121                ByteSizeOf,
2122                [alignment](LogicalBuffer::Color) { return alignment; },
2123                /*allocate_buffers_for_constants=*/true)
2124         .value();
2125   }
2126 
ByteSizeOf(const BufferValue & buffer)2127   static int64_t ByteSizeOf(const BufferValue& buffer) {
2128     return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
2129   }
2130 
2131   Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
2132   Shape loop_state_shape_ =
2133       ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
2134 };
2135 
RunCopyInsertion(HloModule * module)2136 static void RunCopyInsertion(HloModule* module) {
2137   CopyInsertion copy_insertion;
2138   EXPECT_IS_OK(copy_insertion.Run(module).status());
2139 }
2140 
TEST_F(WhileBufferAssignmentTest,TwoForwardWhileLoops)2141 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
2142   auto module = CreateNewVerifiedModule();
2143   auto builder = HloComputation::Builder("entry");
2144 
2145   auto input0 = builder.AddInstruction(
2146       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2147   auto weights0 = builder.AddInstruction(
2148       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2149   auto weights1 = builder.AddInstruction(
2150       HloInstruction::CreateParameter(2, data_shape_, "weights1"));
2151 
2152   auto zero = builder.AddInstruction(
2153       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2154   auto output0 = builder.AddInstruction(
2155       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2156   auto output1 = builder.AddInstruction(
2157       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2158 
2159   auto cond0 =
2160       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2161   auto body0 =
2162       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2163 
2164   auto tuple0 = builder.AddInstruction(
2165       HloInstruction::CreateTuple({input0, weights0, output0}));
2166   auto while0 = builder.AddInstruction(
2167       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2168 
2169   auto cond1 =
2170       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2171   auto body1 =
2172       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2173   auto input1 = builder.AddInstruction(
2174       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2175   auto tuple1 = builder.AddInstruction(
2176       HloInstruction::CreateTuple({input1, weights1, output1}));
2177   auto while1 = builder.AddInstruction(
2178       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2179 
2180   module->AddEntryComputation(builder.Build());
2181   RunCopyInsertion(module.get());
2182   auto assignment = RunBufferAssignment(module.get());
2183 
2184   // Verify 'input0' and read-only use while0{0} alias.
2185   EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).value(),
2186             assignment->GetUniqueSlice(while0, {0}).value());
2187   // Verify 'weights0' and read-only use while0{1} alias.
2188   EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).value(),
2189             assignment->GetUniqueSlice(while0, {1}).value());
2190   // Verify 'while0{2}' and read-only use while1{0} alias.
2191   EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).value(),
2192             assignment->GetUniqueSlice(while1, {0}).value());
2193   // Verify 'weights1' and read-only use while1{1} alias.
2194   EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).value(),
2195             assignment->GetUniqueSlice(while1, {1}).value());
2196 }
2197 
2198 // Tests that two colocated buffer sets are not merged if an entry parameter
2199 // buffer belongs to either of the colocation sets (b/73267882).
2200 //
2201 // %param --> %while.0 --> %mul --> %while.1 --> %broadcast
2202 //
2203 // %while.0 body just forwards the init value, so the loop carried variable
2204 // remains the constant, whereas %while.1 changes the loop carried variable.
TEST_F(WhileBufferAssignmentTest,ColocatedBufferWithEntryParameter)2205 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
2206   const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2207 
2208   const char* module_str = R"(
2209 HloModule test_module
2210 
2211 %cond.v0 {
2212   %param = s32[] parameter(0)
2213   ROOT %constant = pred[] constant(true)
2214 }
2215 
2216 %cond.v1 {
2217   %param.0 = s32[] parameter(0)
2218   ROOT %constant.0 = pred[] constant(true)
2219 }
2220 
2221 %body.v0 {
2222   ROOT %param.1 = s32[] parameter(0)
2223 }
2224 
2225 %body.v1 {
2226   %param.2 = s32[] parameter(0)
2227   ROOT add = s32[] add(%param.2, %param.2)
2228 }
2229 
2230 ENTRY %test_module {
2231   %param.3 = s32[] parameter(0)
2232   %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
2233   %mul = s32[] multiply(%while.0, %while.0)
2234   %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2235   ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2236 })";
2237 
2238   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2239 
2240   // Run CopyInsertion and check if the graph constructed above doesn't need
2241   // any copies inserted for BufferAssignment to run.
2242   int64_t instruction_count = m->instruction_count();
2243   CopyInsertion copy_insertion;
2244   ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2245   ASSERT_EQ(instruction_count, m->instruction_count());
2246 
2247   // Get the instructions in the module.
2248   const HloInstruction* bcast = m->entry_computation()->root_instruction();
2249   const HloInstruction* param =
2250       m->entry_computation()->parameter_instruction(0);
2251   ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2252   const HloInstruction* while1 = bcast->operand(0);
2253   ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2254   const HloInstruction* while0 = while1->operand(0)->operand(0);
2255   ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2256 
2257   // Run buffer assignment.
2258   auto assignment = RunBufferAssignment(m.get());
2259   TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
2260                           assignment->GetUniqueSlice(param, {}));
2261   TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2262                           assignment->GetUniqueSlice(while0, {}));
2263   TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2264                           assignment->GetUniqueSlice(while1, {}));
2265 
2266   // The parameter slice is part of the while0's colocation set (init value),
2267   // but not merged into the while1's colocation set.
2268   EXPECT_EQ(slice_param, slice_while0);
2269   EXPECT_NE(slice_param, slice_while1);
2270 }
2271 
TEST_F(WhileBufferAssignmentTest,ColocatedBufferWithConstant)2272 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
2273   const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2274 
2275   const char* module_str = R"(
2276 HloModule test_module
2277 
2278 %cond.v0 {
2279   %param = s32[] parameter(0)
2280   ROOT %constant = pred[] constant(true)
2281 }
2282 
2283 %cond.v1 {
2284   %param.0 = s32[] parameter(0)
2285   ROOT %constant.0 = pred[] constant(true)
2286 }
2287 
2288 %body.v0 {
2289   ROOT %param.1 = s32[] parameter(0)
2290 }
2291 
2292 %body.v1 {
2293   %param.2 = s32[] parameter(0)
2294   ROOT add = s32[] add(%param.2, %param.2)
2295 }
2296 
2297 ENTRY %test_module {
2298   %constant.42 = s32[] constant(42)
2299   %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
2300   %mul = s32[] multiply(%while.0, %while.0)
2301   %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2302   ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2303 })";
2304 
2305   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2306 
2307   // Run CopyInsertion and check if the graph constructed above doesn't need
2308   // any copies inserted for BufferAssignment to run.
2309   int64_t instruction_count = m->instruction_count();
2310   CopyInsertion copy_insertion;
2311   ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2312   ASSERT_EQ(instruction_count, m->instruction_count());
2313 
2314   // Get the instructions in the module.
2315   const HloInstruction* bcast = m->entry_computation()->root_instruction();
2316   const HloInstruction* constant =
2317       m->entry_computation()->GetInstructionWithName("constant.42");
2318   ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2319   const HloInstruction* while1 = bcast->operand(0);
2320   ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2321   const HloInstruction* while0 = while1->operand(0)->operand(0);
2322   ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2323 
2324   // Run buffer assignment.
2325   auto assignment = RunBufferAssignment(m.get());
2326   TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
2327                           assignment->GetUniqueSlice(constant, {}));
2328   TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2329                           assignment->GetUniqueSlice(while0, {}));
2330   TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2331                           assignment->GetUniqueSlice(while1, {}));
2332 
2333   // The constant slice is part of the while0's colocation set (init value), but
2334   // not merged into the while1's colocation set.
2335   EXPECT_EQ(slice_constant, slice_while0);
2336   EXPECT_NE(slice_constant, slice_while1);
2337 }
2338 
2339 // Tests that the colocated buffers for while instructions are properly assigned
2340 // during buffer assignment such that the result tuple elements are not assigned
2341 // to the same buffer.
2342 //
2343 // %infeed --> %while.0 --> %while.1 --+
2344 //                                     +-- %tuple
2345 //   %zero -->   %add   --> %while.2 --+
2346 //
2347 // Execution Order:
2348 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
2349 //
2350 // The HLO computation used in this test requires specific ordering to expose
2351 // the bug (b/72496031). During buffer assignment, the visitation order of
2352 // colocated buffers is %while.2 -> while.0 -> while.1, and the buffer
2353 // assignment was coalescing the colocated buffers for all 3 while instructions,
2354 // therefore assigning the same buffer to the two result tuple elements.
TEST_F(WhileBufferAssignmentTest,ColocatedBuffers)2355 TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
2356   const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2357 
2358   // Builds a condition computation: x -> x < 4
2359   auto build_cond = [&]() {
2360     auto builder = HloComputation::Builder("cond");
2361     auto const4 = builder.AddInstruction(
2362         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
2363     auto param =
2364         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2365     builder.AddInstruction(
2366         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
2367                                       const4, ComparisonDirection::kLt));
2368     return builder.Build();
2369   };
2370 
2371   // Builds a body computation: x -> x + 9
2372   auto build_body = [&]() {
2373     auto builder = HloComputation::Builder("body");
2374     auto const9 = builder.AddInstruction(
2375         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
2376     auto param =
2377         builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2378     builder.AddInstruction(
2379         HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
2380     return builder.Build();
2381   };
2382 
2383   // Build the entry computation as described in the comment above.
2384   auto module = CreateNewVerifiedModule();
2385   auto builder = HloComputation::Builder("entry");
2386 
2387   auto token = builder.AddInstruction(HloInstruction::CreateToken());
2388   auto infeed =
2389       builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
2390   auto infeed_data = builder.AddInstruction(
2391       HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
2392   auto cond0 = module->AddEmbeddedComputation(build_cond());
2393   auto body0 = module->AddEmbeddedComputation(build_body());
2394   auto while0 = builder.AddInstruction(
2395       HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));
2396 
2397   auto cond1 = module->AddEmbeddedComputation(build_cond());
2398   auto body1 = module->AddEmbeddedComputation(build_body());
2399   auto while1 = builder.AddInstruction(
2400       HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
2401 
2402   auto zero = builder.AddInstruction(
2403       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
2404   auto add = builder.AddInstruction(
2405       HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
2406   auto cond2 = module->AddEmbeddedComputation(build_cond());
2407   auto body2 = module->AddEmbeddedComputation(build_body());
2408   auto while2 = builder.AddInstruction(
2409       HloInstruction::CreateWhile(r0s32, cond2, body2, add));
2410 
2411   auto tuple =
2412       builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
2413   module->AddEntryComputation(builder.Build());
2414 
2415   // Run CopyInsertion and check if the graph constructed above doesn't need
2416   // any copies inserted for BufferAssignment to run.
2417   int64_t instruction_count = module->instruction_count();
2418   CopyInsertion copy_insertion;
2419   ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
2420   ASSERT_EQ(instruction_count, module->instruction_count());
2421 
2422   // Create a sequential order among all the instructions in the entry
2423   // computation, since the issue this test stresses depends on the order the
2424   // nodes are traversed during BufferAssignment.
2425   TF_ASSERT_OK_AND_ASSIGN(
2426       HloSchedule schedule,
2427       ScheduleModule(module.get(), [](const BufferValue& buffer) {
2428         return ShapeUtil::ByteSizeOf(buffer.shape(),
2429                                      /*pointer_size=*/sizeof(void*));
2430       }));
2431   schedule.set_sequence(
2432       module->entry_computation(),
2433       {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
2434   TF_ASSERT_OK(schedule.Verify());
2435 
2436   TF_ASSERT_OK_AND_ASSIGN(
2437       auto assignment,
2438       BufferAssigner::Run(
2439           module.get(), std::make_unique<SequentialHloOrdering>(schedule),
2440           backend().compiler()->BufferSizeBytesFunction(),
2441           [](LogicalBuffer::Color) { return 1; },
2442           /*allocate_buffers_for_constants=*/true));
2443 
2444   // The result tuple elements must be assigned with different buffers.
2445   TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
2446   TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
2447   EXPECT_NE(slice0, slice1);
2448 
2449   // while0 and while1 result buffers must be equal to slice1.
2450   TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2451                           assignment->GetUniqueSlice(while0, {}));
2452   TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2453                           assignment->GetUniqueSlice(while1, {}));
2454   EXPECT_EQ(slice1, slice_while0);
2455   EXPECT_EQ(slice1, slice_while1);
2456 
2457   // while2 result buffer must be equal to slice0.
2458   TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
2459                           assignment->GetUniqueSlice(while2, {}));
2460   EXPECT_EQ(slice0, slice_while2);
2461 }
2462 
TEST_F(WhileBufferAssignmentTest,OneForwardBackwardWhileLoopSet)2463 TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
2464   auto module = CreateNewVerifiedModule();
2465   auto builder = HloComputation::Builder("entry");
2466 
2467   auto input0 = builder.AddInstruction(
2468       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2469   auto weights0 = builder.AddInstruction(
2470       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2471 
2472   auto zero = builder.AddInstruction(
2473       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2474   auto output0 = builder.AddInstruction(
2475       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2476 
2477   auto cond0 =
2478       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2479   auto body0 =
2480       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2481 
2482   auto tuple0 = builder.AddInstruction(
2483       HloInstruction::CreateTuple({input0, weights0, output0}));
2484   auto while0 = builder.AddInstruction(
2485       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2486 
2487   auto cond1 =
2488       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2489   auto body1 =
2490       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2491 
2492   auto while1 = builder.AddInstruction(
2493       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
2494 
2495   module->AddEntryComputation(builder.Build());
2496   RunCopyInsertion(module.get());
2497   auto assignment = RunBufferAssignment(module.get());
2498 
2499   // while0 and while1 buffers should be completely aligned.
2500   EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).value(),
2501             assignment->GetUniqueSlice(while1, {0}).value());
2502   EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).value(),
2503             assignment->GetUniqueSlice(while1, {1}).value());
2504   EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).value(),
2505             assignment->GetUniqueSlice(while1, {2}).value());
2506 }
2507 
TEST_F(BufferAssignmentTest,TwoCalls)2508 TEST_F(BufferAssignmentTest, TwoCalls) {
2509   auto module = CreateNewVerifiedModule();
2510   Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
2511   HloComputation* sub_computation;
2512   {
2513     auto builder = HloComputation::Builder(TestName() + "_sub_comp");
2514     auto param = builder.AddInstruction(
2515         HloInstruction::CreateParameter(0, r0f32, "param"));
2516     auto constant1 = builder.AddInstruction(
2517         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2518     auto add = builder.AddInstruction(
2519         HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
2520     sub_computation = module->AddEmbeddedComputation(builder.Build(add));
2521   }
2522   auto builder = HloComputation::Builder(TestName());
2523   auto constant2 = builder.AddInstruction(
2524       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
2525   auto constant3 = builder.AddInstruction(
2526       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
2527   auto call1 = builder.AddInstruction(
2528       HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
2529   auto call2 = builder.AddInstruction(
2530       HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
2531   auto add1 = builder.AddInstruction(
2532       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
2533   auto add2 = builder.AddInstruction(
2534       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
2535   module->AddEntryComputation(builder.Build(add2));
2536 
2537   {
2538     FlattenCallGraph flatten;
2539     TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
2540     EXPECT_TRUE(result);
2541     std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
2542   }
2543 
2544   RunCopyInsertion(module.get());
2545   auto assignment = RunBufferAssignment(module.get());
2546 
2547   EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
2548 }
2549 
TEST_F(BufferAssignmentTest,CallParamCoAllocation)2550 TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
2551   const char* hlo_text = R"(
2552 HloModule CallParamCoAllocation
2553 
2554 Callee {
2555   param0 = (f32[100],(f32[200],f32[300])) parameter(0)
2556   param1 = s32[20] parameter(1)
2557   ROOT constant = f32[] constant(1)
2558 }
2559 
2560 ENTRY Main {
2561   entry_param0 = f32[100] parameter(0)
2562   entry_param1 = s32[20]  parameter(1)
2563   custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
2564   call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
2565   ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
2566 }
2567 )";
2568 
2569   HloModuleConfig config;
2570   config.set_debug_options(GetDebugOptionsFromFlags());
2571   TF_ASSERT_OK_AND_ASSIGN(auto m,
2572                           ParseAndReturnVerifiedModule(hlo_text, config));
2573 
2574   auto buffers = RunBufferAssignment(m.get());
2575 
2576   HloComputation* main = m->entry_computation();
2577   HloComputation* callee = m->GetComputationWithName("Callee");
2578   EXPECT_NE(callee, nullptr);
2579 
2580   HloInstruction* param0 = callee->parameter_instruction(0);
2581   HloInstruction* param1 = callee->parameter_instruction(1);
2582 
2583   HloInstruction* entry_param0 = main->parameter_instruction(0);
2584   HloInstruction* entry_param1 = main->parameter_instruction(1);
2585   HloInstruction* custom_call = main->GetInstructionWithName("custom_call");
2586 
2587   EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
2588             GetAllocation(*buffers, param0, {0}));
2589   EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
2590             GetAllocation(*buffers, param1, {}));
2591 
2592   EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
2593             GetAllocation(*buffers, param0, {1}));
2594   EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
2595             GetAllocation(*buffers, param0, {1, 0}));
2596   EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
2597             GetAllocation(*buffers, param0, {1, 1}));
2598 }
2599 
TEST_F(BufferAssignmentTest,AsyncCall)2600 TEST_F(BufferAssignmentTest, AsyncCall) {
2601   const char* hlo_text = R"(
2602 HloModule AsyncCall, is_scheduled=true
2603 
2604 %called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] {
2605   %param_0 = f32[4096]{0} parameter(0)
2606   %param_1 = f32[4096]{0} parameter(1)
2607   %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0)
2608   %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1)
2609   %negate_2 = f32[4096]{0} negate(f32[4096]{0} %negate_1)
2610   %negate_3 = f32[4096]{0} negate(f32[4096]{0} %negate_2)
2611   ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_3)
2612 }
2613 
2614 ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
2615   %a = f32[4096]{0} parameter(0)
2616   %b = f32[4096]{0} parameter(1)
2617   %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-start(f32[4096]{0} %a, f32[4096]{0} %b), to_apply=%called_computation
2618   %negate_4 = f32[4096]{0} negate(f32[4096]{0} %a)
2619   %negate_5 = f32[4096]{0} negate(f32[4096]{0} %b)
2620   %negate_6 = f32[4096]{0} negate(f32[4096]{0} %negate_5)
2621   %negate_7 = f32[4096]{0} negate(f32[4096]{0} %negate_6)
2622   %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_4, f32[4096]{0} %negate_7)
2623   %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), to_apply=%called_computation
2624   ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done)
2625 }
2626 )";
2627 
2628   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
2629   AsyncOpCanonicalizer async_op_canonicalizer;
2630   EXPECT_TRUE(async_op_canonicalizer.Run(m.get()).ok());
2631   HloDCE dce;
2632   EXPECT_TRUE(dce.Run(m.get()).ok());
2633 
2634   auto buffers = RunBufferAssignmentWithSequentialOrdering(m.get());
2635 
2636   LOG(INFO) << buffers->ToString();
2637 
2638   auto get_slice = [&](std::string_view hlo_name, const ShapeIndex& index) {
2639     return buffers->GetUniqueSlice(FindInstruction(m.get(), hlo_name), index)
2640         .ValueOrDie();
2641   };
2642 
2643   // Make sure the parameters and root of the async called computation has the
2644   // same slice as the async call operands/output.
2645   EXPECT_EQ(get_slice("param_0", {}), get_slice("a", {}));
2646   EXPECT_EQ(get_slice("param_1", {}), get_slice("b", {}));
2647   EXPECT_EQ(get_slice("result.1", {}), get_slice("async-done", {}));
2648 
2649   // Make sure the intermediate values in the async called computation have
2650   // different allocated slices than the values that overlap it.
2651   for (const auto& hlo_name :
2652        {"negate_0", "negate_1", "negate_2", "negate_3"}) {
2653     EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_4", {}));
2654     EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_5", {}));
2655     EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_6", {}));
2656     EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_7", {}));
2657     EXPECT_NE(get_slice(hlo_name, {}), get_slice("add_0", {}));
2658   }
2659 }
2660 
TEST_F(BufferAssignmentTest,BufferInfoStringTest)2661 TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
2662   absl::string_view module_str = R"(
2663 HloModule test_module
2664 
2665 ENTRY %test_module {
2666   %param.0 = s32[1024]{0} parameter(0)
2667   %param.1 = s32[1024]{0} parameter(1)
2668   %mul = s32[1024]{0} multiply(%param.0, %param.1)
2669   %add = s32[1024]{0} add(%mul, %param.0)
2670   ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
2671 })";
2672 
2673   absl::string_view reference_str =
2674       R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
2675 0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
2676 1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
2677 2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
2678 3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
2679 4,"<4 bcast @0>",0,4194304,4,5,0,"",""
2680 )";
2681 
2682   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2683   HloInstruction* const param0 = FindInstruction(m.get(), "param.0");
2684   HloInstruction* const param1 = FindInstruction(m.get(), "param.1");
2685   HloInstruction* const mul = FindInstruction(m.get(), "mul");
2686   HloInstruction* const add = FindInstruction(m.get(), "add");
2687   HloInstruction* const bcast = FindInstruction(m.get(), "bcast");
2688   // Run buffer assignment.
2689   auto assignment = RunBufferAssignmentWithInstructionSequence(
2690       m.get(), {param0, param1, mul, add, bcast});
2691   const std::string buffer_info_str = assignment->BufferInfoString();
2692 
2693   EXPECT_EQ(buffer_info_str, reference_str);
2694 }
2695 
TEST_F(WhileBufferAssignmentTest,WhileLoopsInterferingResultRange)2696 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
2697   auto module = CreateNewVerifiedModule();
2698   auto builder = HloComputation::Builder(TestName());
2699 
2700   auto zero = builder.AddInstruction(
2701       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2702   auto one = builder.AddInstruction(
2703       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2704 
2705   auto input0 = builder.AddInstruction(
2706       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2707   auto weights0 = builder.AddInstruction(
2708       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2709   auto output0 = builder.AddInstruction(
2710       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2711 
2712   auto input1 = builder.AddInstruction(
2713       HloInstruction::CreateParameter(2, data_shape_, "input1"));
2714   auto weights1 = builder.AddInstruction(
2715       HloInstruction::CreateParameter(3, data_shape_, "weights1"));
2716   auto output1 = builder.AddInstruction(
2717       HloInstruction::CreateBroadcast(data_shape_, one, {}));
2718 
2719   auto cond =
2720       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2721   auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2722 
2723   auto tuple0 = builder.AddInstruction(
2724       HloInstruction::CreateTuple({input0, weights0, output0}));
2725   auto tuple1 = builder.AddInstruction(
2726       HloInstruction::CreateTuple({input1, weights1, output1}));
2727 
2728   auto while0 = builder.AddInstruction(
2729       HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
2730   auto while1 = builder.AddInstruction(
2731       HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
2732 
2733   auto gte0 = builder.AddInstruction(
2734       HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
2735   auto gte1 = builder.AddInstruction(
2736       HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
2737   auto root_add = builder.AddInstruction(
2738       HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));
2739 
2740   module->AddEntryComputation(builder.Build());
2741 
2742   {
2743     FlattenCallGraph flatten;
2744     TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
2745     EXPECT_TRUE(result);
2746   }
2747 
2748   RunCopyInsertion(module.get());
2749 
2750   HloSchedule schedule = ScheduleModule(module.get(), ByteSizeOf).value();
2751 
2752   // To trigger b/38494731, we want a specific Hlo schedule for the
2753   // root computation, so we overwrite that entry with a manually
2754   // crafted sequence.
2755   schedule.set_sequence(
2756       module->entry_computation(),
2757       {input1, weights1, one, output1, while1->mutable_operand(0), while1,
2758        input0, weights0, zero, output0, while0->mutable_operand(0), while0,
2759        gte0, gte1, root_add});
2760 
2761   // If this ASSERT fails, we constructed a bogus sequence above and this test
2762   // itself is buggy.
2763   TF_ASSERT_OK(schedule.Verify());
2764 
2765   auto assignment =
2766       BufferAssigner::Run(
2767           module.get(), std::make_unique<SequentialHloOrdering>(schedule),
2768           ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
2769           /*allocate_buffers_for_constants=*/true)
2770           .value();
2771 
2772   EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
2773 }
2774 
TEST_F(WhileBufferAssignmentTest,WhilesDontShareEntryParamIfLiveOut)2775 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
2776   auto module = CreateNewVerifiedModule();
2777   auto builder = HloComputation::Builder("entry");
2778 
2779   auto input0 = builder.AddInstruction(
2780       HloInstruction::CreateParameter(0, data_shape_, "input0"));
2781   auto weights0 = builder.AddInstruction(
2782       HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2783 
2784   auto zero = builder.AddInstruction(
2785       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2786   auto output0 = builder.AddInstruction(
2787       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2788   auto output1 = builder.AddInstruction(
2789       HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2790 
2791   auto cond0 =
2792       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2793   auto body0 =
2794       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2795 
2796   auto tuple0 = builder.AddInstruction(
2797       HloInstruction::CreateTuple({input0, weights0, output0}));
2798   auto while0 = builder.AddInstruction(
2799       HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2800 
2801   // Get output of 'while0' and feed as input to 'while1'.
2802   auto while0_out = builder.AddInstruction(
2803       HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2804 
2805   auto cond1 =
2806       module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2807   auto body1 =
2808       module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2809 
2810   auto tuple1 = builder.AddInstruction(
2811       HloInstruction::CreateTuple({while0_out, weights0, output1}));
2812   auto while1 = builder.AddInstruction(
2813       HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2814 
2815   // Get output of 'while1' so that it is live out of computation.
2816   auto while1_out = builder.AddInstruction(
2817       HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
2818 
2819   module->AddEntryComputation(builder.Build());
2820   RunCopyInsertion(module.get());
2821   auto assignment = RunBufferAssignment(module.get());
2822   // Get BufferAllocation for root instruction.
2823   auto* root_alloc =
2824       assignment->GetUniqueTopLevelSlice(while1_out).value().allocation();
2825   // Test that root instruction allocation is live out.
2826   EXPECT_TRUE(root_alloc->maybe_live_out());
2827   // Test that root instruction allocation is not an entry parameter.
2828   EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
2829 }
2830 
TEST_F(WhileBufferAssignmentTest,WhileWithDynamicUpdateSliceShare)2831 TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
2832   const char* const hlo_string = R"(
2833 HloModule test
2834 
2835 while_body {
2836   state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
2837   constant.1 = f32[] constant(0)
2838   broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
2839   get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
2840   get-tuple-element.3 = s32[] get-tuple-element(state), index=0
2841   constant.2 = s32[] constant(128)
2842   add.5 = s32[] add(get-tuple-element.3, constant.2)
2843   constant.3 = s32[] constant(0)
2844   dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
2845   dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
2846   ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
2847 }
2848 
2849 while_condition {
2850   state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
2851   get-tuple-element = s32[] get-tuple-element(state), index=0
2852   get-tuple-element.1 = s32[] constant(3)
2853   ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
2854 }
2855 
2856 ENTRY entry_computation {
2857   constant.7 = s32[] constant(0)
2858   copy.1 = s32[] copy(constant.7)
2859   constant.6 = f32[] constant(0)
2860   broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
2861   tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
2862   while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
2863   ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
2864 }
2865 
2866 )";
2867   auto module = ParseAndReturnVerifiedModule(hlo_string).value();
2868 
2869   RunCopyInsertion(module.get());
2870   auto assignment = RunBufferAssignment(module.get());
2871   // Get BufferAllocation for root instruction.
2872   auto dus9 = FindInstruction(module.get(), "dynamic-update-slice.9");
2873   auto dus9_alloc_slice = assignment->GetUniqueTopLevelSlice(dus9).value();
2874   auto dus5 = FindInstruction(module.get(), "dynamic-update-slice.5");
2875   auto dus5_alloc_slice = assignment->GetUniqueTopLevelSlice(dus5).value();
2876   // Test that the two dynamic-update-slice ops share the same allocation slice.
2877   EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation());
2878   EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice);
2879 }
2880 }  // namespace
2881 }  // namespace xla
2882