/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/async_op_canonicalizer.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

using memory_space_assignment::PresetAssignments;
using ::testing::UnorderedElementsAre;

// DFS visitor that collects the instructions referenced by a computation
// without descending into nested computations, i.e., only from the operands.
class InstructionListVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}

  Status DefaultAction(HloInstruction* hlo) override {
    // For each instruction, just push it on the list after walking the
    // operands.
    instructions_.push_back(hlo);
    VLOG(0) << "List instruction " << hlo->ToString();
    return OkStatus();
  }

  std::vector<const HloInstruction*> GetInstructions() { return instructions_; }

 private:
  // The instruction root of the computation.
  const HloInstruction* root_;

  // The full set of instructions found (may be duplicates, e.g., kParameter).
  std::vector<const HloInstruction*> instructions_;

  InstructionListVisitor(const InstructionListVisitor&) = delete;
  InstructionListVisitor& operator=(const InstructionListVisitor&) = delete;
};

const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
  InstructionListVisitor main_list(root);
  TF_CHECK_OK(root->Accept(&main_list));
  return main_list.GetInstructions();
}

class BufferAssignmentTest : public HloTestBase {
 protected:
  ~BufferAssignmentTest() override {}

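  // Runs buffer assignment on `module` using a dependency-based HLO ordering.
  // Buffers are allocated for constants, and `alignment` is used for every
  // buffer color.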
  std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
                                                        int64_t alignment = 1) {
    return BufferAssigner::Run(
               module, std::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allocate_buffers_for_constants=*/true)
        .value();
  }

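  // Same as RunBufferAssignment, but orders instructions with the module's own
  // schedule (SequentialHloOrdering) instead of a dependency-based ordering.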
  std::unique_ptr<BufferAssignment> RunBufferAssignmentWithSequentialOrdering(
      HloModule* module, int64_t alignment = 1) {
    return BufferAssigner::Run(
               module,
               std::make_unique<SequentialHloOrdering>(module->schedule()),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allocate_buffers_for_constants=*/true)
        .value();
  }

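  // Same as RunBufferAssignment, except that constants do not get their own
  // allocations (allocate_buffers_for_constants=false).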
  std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants(
      HloModule* module, int64_t alignment = 1) {
    return BufferAssigner::Run(
               module, std::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allocate_buffers_for_constants=*/false)
        .value();
  }

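  // Runs buffer assignment with a must_not_live_out predicate that flags every
  // kAdd output, so an add's buffer may not be live out of the computation
  // (used by the AddCannotReuse test).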
  std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersReuseForAdd(
      HloModule* module, int64_t alignment = 1) {
    auto must_not_live_out = [](const HloInstruction* instruction,
                                const ShapeIndex&) {
      return instruction->opcode() == HloOpcode::kAdd;
    };

    return BufferAssigner::Run(
               module, std::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allocate_buffers_for_constants=*/false,
               /*colorer=*/BufferAssigner::DefaultColorer(),
               /*must_not_live_out=*/must_not_live_out)
        .value();
  }

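  // Runs buffer assignment with a caller-provided colorer that assigns
  // LogicalBuffer colors to HLO values before allocation.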
  std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
      HloModule* module, BufferAssigner::Colorer colorer,
      int64_t alignment = 1) {
    return BufferAssigner::Run(
               module, std::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allocate_buffers_for_constants=*/true, std::move(colorer))
        .value();
  }

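  // Runs buffer assignment with an explicitly provided instruction sequence
  // for the entry computation, wrapped in a SequentialHloOrdering.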
  std::unique_ptr<BufferAssignment> RunBufferAssignmentWithInstructionSequence(
      HloModule* module, absl::Span<HloInstruction* const> instruction_sequence,
      int64_t alignment = 1) {
    HloSchedule schedule(module);
    schedule.set_sequence(module->entry_computation(), instruction_sequence);
    return BufferAssigner::Run(
               module, std::make_unique<SequentialHloOrdering>(schedule),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allocate_buffers_for_constants=*/true)
        .value();
  }

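  // Runs buffer assignment with preset assignments, which pre-place selected
  // buffers (e.g., those colored into an alternate memory space by memory
  // space assignment) at fixed offsets.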
  std::unique_ptr<BufferAssignment> RunBufferAssignmentWithPresetAssignments(
      HloModule* module, std::unique_ptr<PresetAssignments> preset_assignments,
      int64_t alignment = 1) {
    return BufferAssigner::Run(
               module, std::make_unique<DependencyHloOrdering>(module),
               backend().compiler()->BufferSizeBytesFunction(),
               [alignment](LogicalBuffer::Color) { return alignment; },
               /*allocate_buffers_for_constants=*/true,
               BufferAssigner::DefaultColorer(),
               /*must_not_live_out=*/std::nullopt,
               /*can_share_buffer=*/nullptr, std::move(preset_assignments))
        .value();
  }

  // Builds an x+1.0 computation to use in a Map.
  std::unique_ptr<HloComputation> BuildMapComputationPlus1(
      const std::string& name) {
    auto builder = HloComputation::Builder(name);
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    auto value = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
    return builder.Build();
  }

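  // Builds an x+y computation, suitable as the reducer of a Reduce.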
  std::unique_ptr<HloComputation> BuildReduceComputation(
      const std::string& name) {
    auto builder = HloComputation::Builder(name);
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    auto param2 =
        builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2));
    return builder.Build();
  }

  // Builds a simple compare-to-limit (x < 4) computation for a While.
  //
  // condition:
  //   const4[s32] -----------------------------------\
  //                                                    \
  //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
  //
  std::unique_ptr<HloComputation> BuildWhileConditionComputation(
      const std::string& name) {
    auto builder = HloComputation::Builder(name);
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    auto index = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
    builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
                                      const4, ComparisonDirection::kLt));
    return builder.Build();
  }

  // Builds a simple body computation for a While.
  //
  // body:
  //   constv[f32[4]] --------------------------------------\
  //                                                          \
  //                           /--- get-tuple-elementv[1] --- addv ---\
  //   param[(s32,f32[4])] ---|                                    tuple
  //                           \--- get-tuple-elementc[0] --- addc ---/
  //                                                          /
  //   const1[s32] -----------------------------------------/
  //
  std::unique_ptr<HloComputation> BuildWhileBodyComputation(
      const std::string& name) {
    auto builder = HloComputation::Builder(name);
    auto const1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
    auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
    auto indexc = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
    auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
        indexc->shape(), HloOpcode::kAdd, indexc, const1));
    auto indexv = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
    auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
        constv->shape(), HloOpcode::kAdd, indexv, constv));
    builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
    return builder.Build();
  }

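  // Builds a computation that applies the given unary opcode to a scalar f32
  // parameter.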
  std::unique_ptr<HloComputation> BuildR0F32UnaryOpComputation(
      HloOpcode opcode, const std::string& name) {
    auto builder = HloComputation::Builder(name);
    auto param =
        builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
    builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param));
    return builder.Build();
  }

  // Verifies that the given instruction hlo has a valid input buffer assigned,
  // i.e., the parameter number matches the op's.
  const BufferAllocation& GetAssignedInputAllocation(
      const BufferAssignment& buffers, HloInstruction* hlo) {
    LOG(INFO) << "Checking input: " << hlo->ToString();
    const BufferAllocation& buffer =
        *buffers.GetUniqueTopLevelSlice(hlo).value().allocation();
    EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
    return buffer;
  }

  // Verifies that the given instruction hlo has a valid output buffer
  // assigned, and returns it.
  const BufferAllocation& GetAssignedOutputAllocation(
      const BufferAssignment& buffers, HloInstruction* hlo) {
    LOG(INFO) << "Checking output: " << hlo->ToString();
    const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
    return buffer;
  }

  // Returns the allocation for the given instruction.
  const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
                                        const HloInstruction* hlo,
                                        const ShapeIndex& index) {
    return *buffers.GetUniqueSlice(hlo, index).value().allocation();
  }
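  // Returns the top-level allocation for the given instruction.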
  const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
                                                const HloInstruction* hlo) {
    return *buffers.GetUniqueTopLevelSlice(hlo).value().allocation();
  }

  // Verifies that all instructions in the given instruction list except
  // kConstant (and parameters of nested computations) have assigned buffers,
  // and returns the total size of all allocations in the assignment.
  int64_t ValidateBuffers(
      const std::vector<const HloInstruction*>& instructions,
      const BufferAssignment& buffers) {
    // Verifies that every instruction has a buffer, with the exceptions noted
    // above.
    for (const HloInstruction* hlo : instructions) {
      if (!buffers.HasTopLevelAllocation(hlo)) {
        // If `hlo` has no assigned buffer, it is either a constant or a nested
        // parameter.
        EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
                    HloOpcode::kParameter == hlo->opcode());
        continue;
      }
    }

    // Gets the total size of all buffers assigned.
    int64_t total_size = 0;
    for (auto& allocation : buffers.Allocations()) {
      total_size += allocation.size();
    }
    return total_size;
  }

  // Shapes for use in the examples.
  Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
  Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
  Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
  Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
  Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
  Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
  Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
  Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
};

// Returns true if the buffers assigned to instructions in "a" are distinct
// from the buffers assigned to those in "b" (i.e., the intersection is empty).
static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
                            const std::vector<const HloInstruction*>& b,
                            const BufferAssignment& assignment) {
  absl::flat_hash_set<BufferAllocation::Slice> a_slices;
  for (const HloInstruction* instruction : a) {
    if (assignment.HasTopLevelAllocation(instruction)) {
      a_slices.insert(assignment.GetUniqueTopLevelSlice(instruction).value());
    }
  }

  for (const HloInstruction* instruction : b) {
    if (assignment.HasTopLevelAllocation(instruction)) {
      if (a_slices.contains(
              assignment.GetUniqueTopLevelSlice(instruction).value())) {
        return false;
      }
    }
  }
  return true;
}

// Tests a computation consisting of a single scalar constant node.
TEST_F(BufferAssignmentTest, ScalarConstant) {
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    auto buffers = RunBufferAssignment(module.get());
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
  }

  {
    auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
  }
}

TEST_F(BufferAssignmentTest, BufferForConst) {
  // Addition of two vector constants: checks that internal constant nodes have
  // no buffers assigned, and their consumer has a buffer.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  {
    auto buffers = RunBufferAssignment(module.get());
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const0));
    EXPECT_TRUE(buffers->HasTopLevelAllocation(const1));
    GetAssignedOutputAllocation(*buffers, add);
  }
  {
    auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get());
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
    EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
    GetAssignedOutputAllocation(*buffers, add);
  }
}

TEST_F(BufferAssignmentTest, HasAllocationAt) {
  // Create a tuple with non-const and const elements and check that
  // HasAllocationAt works correctly.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "param0"));
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate, param0, constant}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
  // reports for the instruction directly.
  EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
            buffers->HasAllocationAt(tuple, /*index=*/{}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
            buffers->HasAllocationAt(tuple, /*index=*/{0}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
            buffers->HasAllocationAt(tuple, /*index=*/{1}));
  EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
            buffers->HasAllocationAt(tuple, /*index=*/{2}));
}

TEST_F(BufferAssignmentTest, BufferForOutputConst) {
  // This computation copies a constant to output.
  auto builder = HloComputation::Builder(TestName());
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto copy = builder.AddInstruction(
      HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  // The copy node now has an output buffer.
  GetAssignedOutputAllocation(*buffers, copy);
}

TEST_F(BufferAssignmentTest, Basic) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
  // If an input buffer and an output buffer alias, the input buffer can be
  // reused for other intermediate results.
  //
  // param0[100] ----- (neg1) -- (neg2)
  //    |                          |
  //    + -------- Aliased --------+

  auto builder = HloComputation::Builder(TestName());

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32vec100_, "p0"));
  auto neg_1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
  auto neg_2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, neg_1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));

  auto buffers = RunBufferAssignment(module.get());

  BufferAllocation param_buffer = GetAssignedInputAllocation(*buffers, param);
  BufferAllocation neg_1_buffer = GetAllocation(*buffers, neg_1, {});
  BufferAllocation neg_2_buffer = GetAllocation(*buffers, neg_2, {});

  // Everything uses one buffer.
  EXPECT_EQ(param_buffer.index(), neg_1_buffer.index());
  EXPECT_EQ(neg_2_buffer.index(), neg_1_buffer.index());
}

TEST_F(BufferAssignmentTest, AddCannotReuse) {
  // Pass in a special rule to indicate that "add" cannot be live out.
  //
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get());

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The sub node has a valid buffer assigned; it doesn't share with an input.
  const BufferAllocation& sub_buffer = GetTopLevelAllocation(*buffers, sub);
  EXPECT_NE(sub_buffer.index(), param0_buffer.index());

  // The sub node (the output) cannot reuse the add node's buffer since we told
  // buffer assignment that the add must not be live out.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), sub_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  // The output of each op is colored with a different color, so we can not
  // share anything.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  absl::flat_hash_map<const HloInstruction*, int> color_map;
  auto colorer = [&](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
    int color = 0;
    for (HloValue::Id id = 0;
         id < alias_analysis->dataflow_analysis().values().size(); id++) {
      auto& value = alias_analysis->dataflow_analysis().GetValue(id);
      color_map[value.defining_instruction()] = color;
      value.set_color(BufferValue::Color(color++));
    }
    return OkStatus();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can not reuse the mul node's buffer due to coloring.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);

  // Check if the HLO instructions have the correct colors in the layout.
  EXPECT_EQ(param0->shape().layout().memory_space(), color_map[param0]);
  EXPECT_EQ(param1->shape().layout().memory_space(), color_map[param1]);
  EXPECT_EQ(mul->shape().layout().memory_space(), color_map[mul]);
  EXPECT_EQ(add->shape().layout().memory_space(), color_map[add]);
  EXPECT_EQ(sub->shape().layout().memory_space(), color_map[sub]);
}

TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  // The output of the mul and the add have the color 1, and the other buffers
  // have the color 0, which allows the mul and add to share buffers.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto colorer = [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
    for (HloValue::Id id = 0;
         id < alias_analysis->dataflow_analysis().values().size(); id++) {
      auto& value = alias_analysis->dataflow_analysis().GetValue(id);
      auto& buffer = alias_analysis->GetBufferContainingValue(value);
      for (const auto& alias : buffer.values()) {
        if (alias->instruction()->opcode() == HloOpcode::kAdd ||
            alias->instruction()->opcode() == HloOpcode::kMultiply) {
          value.set_color(LogicalBuffer::Color(1));
        }
      }
      if (!value.has_color()) {
        value.set_color(LogicalBuffer::Color(0));
      }
    }
    return OkStatus();
  };

  auto buffers = RunColoredBufferAssignment(module.get(), colorer);

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());

  // The mul node has a valid buffer assigned, doesn't share with input.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());

  // The add node can reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(add_buffer.index(), mul_buffer.index());

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);

  // Check if the HLO instructions have the correct colors in the layout.
  EXPECT_EQ(mul->shape().layout().memory_space(), 1);
  EXPECT_EQ(add->shape().layout().memory_space(), 1);
  EXPECT_EQ(sub->shape().layout().memory_space(), 0);
  EXPECT_EQ(param0->shape().layout().memory_space(), 0);
  EXPECT_EQ(param1->shape().layout().memory_space(), 0);
}

TEST_F(BufferAssignmentTest, PresetAssignments) {
  // paramscalar ------- (mul) -- (add) -- (sub)
  //                     /        /        /
  // param0[100] -------/        /        /
  //                            /        /
  // param1[100] --------------/--------/
  // Similar to BasicPartiallyColored, but the color is set in the layout.
  // The output of the mul and the add have the color 1 and have preset
  // assignments, and the other buffers have the color 0, which allows the mul
  // and add to share buffers.
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  Shape f32vec100_color1 =
      ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, {}, 0, 1);
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_color1, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_color1, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kSubtract, add, param1));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto preset_assignments = std::make_unique<PresetAssignments>();
  preset_assignments->add_chunk({mul, {}}, {/*offset=*/100, /*size=*/400});
  preset_assignments->add_chunk({add, {}}, {/*offset=*/550, /*size=*/400});
  preset_assignments->assignment_information_for_space(/*memory_space=*/1)
      ->size = 950;

  auto buffers = RunBufferAssignmentWithPresetAssignments(
      module.get(), std::move(preset_assignments));

  // Distinct input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
  EXPECT_EQ(paramscalar_buffer.color(), LogicalBuffer::Color(0));
  EXPECT_NE(param0_buffer.index(), param1_buffer.index());
  EXPECT_EQ(param0_buffer.color(), LogicalBuffer::Color(0));

  // The mul and add use the same preset buffer. Ensure it has the correct
  // color and offsets.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_EQ(mul_buffer, add_buffer);
  EXPECT_NE(mul_buffer.index(), param0_buffer.index());
  EXPECT_EQ(mul_buffer.color(), LogicalBuffer::Color(1));

  EXPECT_EQ(mul_buffer.assigned_buffers().size(), 2);
  for (const auto& value_and_offsetsize : mul_buffer.assigned_buffers()) {
    if (value_and_offsetsize.first->instruction() == mul) {
      EXPECT_EQ(value_and_offsetsize.second.offset, 100);
      EXPECT_EQ(value_and_offsetsize.second.size, 400);
    } else {
      EXPECT_EQ(value_and_offsetsize.first->instruction(), add);
      EXPECT_EQ(value_and_offsetsize.second.offset, 550);
      EXPECT_EQ(value_and_offsetsize.second.size, 400);
    }
  }

  // The sub node has a valid output buffer assigned.
  GetAssignedOutputAllocation(*buffers, sub);
}

TEST_F(BufferAssignmentTest, PresetAssignmentsWhile) {
  // Tests preset assignments when there is no 1-to-1 correspondence between
  // HloValue and HloBuffer (i.e., a while loop).
  auto module = CreateNewVerifiedModule();
  Shape f32vec10_color1 =
      ShapeUtil::MakeShapeWithLayout(F32, {10}, {0}, {}, {}, 0, 1);
  Shape t_s32_f32v10_color1 =
      ShapeUtil::MakeTupleShape({s32_, f32vec10_color1});

  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32_, cond_param, 0));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(50)));
  cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, t_s32_f32v10_color1, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32_, body_param, 0));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec10_color1, body_param, 1));
  HloInstruction* body_data_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
          {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f})));
  HloInstruction* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          f32vec10_color1, HloOpcode::kAdd, body_data, body_data_increment));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          s32_, HloOpcode::kAdd, body_iter, body_iter_increment));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_iter_next, body_data_next}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, s32_, "param_iter"));
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec10_, "param_data"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_color1, HloOpcode::kNegate, data));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({iter, negate}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v10_color1, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32vec10_color1, while_op, 1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec10_, HloOpcode::kAdd, while_data, data));
  module->AddEntryComputation(builder.Build());

  // Set only one preset assignment for while data and its aliases.
  auto preset_assignments = std::make_unique<PresetAssignments>();
  preset_assignments->add_chunk({negate, {}}, {/*offset=*/100, /*size=*/40});
  preset_assignments->assignment_information_for_space(/*memory_space=*/1)
      ->size = 140;

  auto buffers = RunBufferAssignmentWithPresetAssignments(
      module.get(), std::move(preset_assignments));

  // All assigned buffers are aliased so they should have the same offset and
  // size.
  const BufferAllocation& data_buffer = GetTopLevelAllocation(*buffers, negate);
  EXPECT_EQ(data_buffer.assigned_buffers().size(), 5);
  for (const auto& value_and_offsetsize : data_buffer.assigned_buffers()) {
    EXPECT_EQ(value_and_offsetsize.second.offset, 100);
    EXPECT_EQ(value_and_offsetsize.second.size, 40);
    EXPECT_EQ(value_and_offsetsize.first->color(), LogicalBuffer::Color(1));
  }
}

TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
  // This is similar to the Basic test, with the difference that (sub) is
  // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
  // (add)'s output.
  //
  // paramscalar -------\     /-----------\
  //                     \   /             \
  // param0[100] ------- (mul) -- (add) -- (sub)
  //                              /
  // param1[100] ----------------/
  //
  auto builder = HloComputation::Builder(TestName());
  auto paramscalar =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32vec100_, "p1"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32vec100_, "p2"));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      f32vec100_, HloOpcode::kMultiply, broadcast, param0));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());

  // Input buffers were assigned for parameters.
  BufferAllocation paramscalar_buffer =
      GetAssignedInputAllocation(*buffers, paramscalar);
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
  BufferAllocation param1_index = GetAssignedInputAllocation(*buffers, param1);
  EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
  EXPECT_NE(paramscalar_buffer.index(), param1_index.index());
  EXPECT_NE(param0_buffer.index(), param1_index.index());

  // The mul node had a buffer allocated.
  const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);

  // Now the add node can't reuse the mul node's buffer.
  const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
  EXPECT_NE(add_buffer.index(), mul_buffer.index());

  // Log size information for inspection.
  const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
  int64_t size0 = ValidateBuffers(level0, *buffers);
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() << " instructions; "
            << "total buffer size " << size0;
}

TEST_F(BufferAssignmentTest, TrivialMap) {
  // This tests a trivial x+1 map as the only operation.
  //
  // param0[100x10] ---> (map x+1)
  //
  // Builds the map function.
  auto module = CreateNewVerifiedModule();
  auto map_computation =
      module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
  auto inner_last = map_computation->root_instruction();

  // Creates the main kernel and verifies instruction counts.
  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  auto map = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> level0 = GetInstructions(map);
  EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
  const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
  EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";

  // Assigns buffers and fetches sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64_t size0 = ValidateBuffers(level0, *buffers);
  int64_t size1 = ValidateBuffers(level1, *buffers);

  // Both algorithms assign the map's buffer before processing the embedded
  // computation, so we can verify that the buffers aren't shared between them
  // by checking:
  EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
      << "Reuse between main kernel and embedded mapping.";

  // An input buffer was assigned for the parameter.
  BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);

  // An output buffer was assigned for the map.
  BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map);
  EXPECT_NE(param0_buffer.index(), map_buffer.index());

  // The final computation node of the map is an add of an f32 param and a
  // constant.
  EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
  const BufferAllocation& inner_add_buffer =
      GetTopLevelAllocation(*buffers, inner_last);
  EXPECT_NE(inner_add_buffer.index(), map_buffer.index());

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + level1.size() << " instructions; "
            << "total buffer size " << size0 + size1;
}

TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
  // Make sure that the input buffer of a reduce cannot be reused for its
  // output. (Reuse is not safe in the general case: a reduce reshapes its
  // input, and an out-of-order reduction could overwrite an element before it
  // has been read.)
  //
  // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3)
  auto module = CreateNewVerifiedModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(BuildReduceComputation("f32+f32"));

  auto builder = HloComputation::Builder(TestName());
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10_, "p"));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
  auto const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      /*shape=*/f32vec10_,
      /*operand=*/exp2,
      /*init_value=*/const0,
      /*dimensions_to_reduce=*/{0}, reduce_computation));
  auto exp3 = builder.AddInstruction(
      HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));

  module->AddEntryComputation(builder.Build());

  auto buffers = RunBufferAssignment(module.get());
  const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
  ValidateBuffers(instrs, *buffers);

  const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1);
  const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2);
  const BufferAllocation& reduce_buffer =
      GetTopLevelAllocation(*buffers, reduce);

  // The buffer of exp1 is trivially reusable for exp2 - this is just for
  // sanity checking.
  EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index());

  // The buffer of exp2 cannot be used for reduce, even though it's the only
  // operand.
  EXPECT_NE(exp2_buffer.index(), reduce_buffer.index());
}

TEST_F(BufferAssignmentTest, ExampleWhile) {
  // This tests a While loop example from the ir_semantics document.
  //
  // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
  // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
  //
  // const3[s32] -------\
  // const4[f32[4]] --- tuple --- while[condition, body]
  //
  // Builds the nested condition and body.
  auto module = CreateNewVerifiedModule();
  auto condition_computation =
      module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
  auto body_computation =
      module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));

  // Creates the main kernel and verifies instruction counts.
  auto builder = HloComputation::Builder(TestName());
  auto const3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
  auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
  auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v4_, condition_computation, body_computation, tuple));
  module->AddEntryComputation(builder.Build());

  const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
  EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
  const std::vector<const HloInstruction*> levelc =
      GetInstructions(condition_computation->root_instruction());
  EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
  const std::vector<const HloInstruction*> levelb =
      GetInstructions(body_computation->root_instruction());
  EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";

  // Assigns buffers and fetches sizes.
  auto buffers = RunBufferAssignment(module.get());
  int64_t size0 = ValidateBuffers(level0, *buffers);
  int64_t sizec = ValidateBuffers(levelc, *buffers);
  int64_t sizeb = ValidateBuffers(levelb, *buffers);

  // BufferAssignment will assign a single allocation for the following
  // instructions: while, while.cond.param, while.body.param, while.body.result.
  EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
      << "Should be reuse between main kernel and embedded condition.";
  EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
      << "Should be reuse between embedded condition and body.";
  // Expect buffer reuse between main kernel and body computation.
  EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
      << "Should be reuse between main kernel and embedded body.";

  // The final computation node of the while body is a tuple of s32 and
  // f32[4] adds.
  HloInstruction* body_root = body_computation->root_instruction();
  EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());

  // Check that buffer for each subshape of 'while_op' shares allocation with
  // corresponding buffer from while body computation at same index.
  ShapeUtil::ForEachSubshape(
      while_op->shape(),
      [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
                                            const ShapeIndex& index) {
        auto while_op_allocation = GetAllocation(*buffers, while_op, index);
        auto body_root_allocation = GetAllocation(*buffers, body_root, index);
        EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
      });

  // Log size information for inspection.
  LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
            << " for " << level0.size() + levelc.size() + levelb.size()
            << " instructions; total buffer size " << size0 + sizec + sizeb;
}

TEST_F(BufferAssignmentTest,ExampleConditional)1106 TEST_F(BufferAssignmentTest, ExampleConditional) {
1107 auto module = CreateNewVerifiedModule();
1108 auto true_computation = module->AddEmbeddedComputation(
1109 BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil"));
1110 auto false_computation = module->AddEmbeddedComputation(
1111 BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor"));
1112
1113 auto builder = HloComputation::Builder(TestName());
1114 auto pred = builder.AddInstruction(
1115 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1116 auto const1 = builder.AddInstruction(
1117 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
1118 auto const2 = builder.AddInstruction(
1119 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
1120 auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1121 r0f32_, pred, const1, true_computation, const2, false_computation));
1122 module->AddEntryComputation(builder.Build());
1123
1124 const std::vector<const HloInstruction*> conditional_instrs =
1125 GetInstructions(conditional);
1126 const std::vector<const HloInstruction*> true_instrs =
1127 GetInstructions(true_computation->root_instruction());
1128 const std::vector<const HloInstruction*> false_instrs =
1129 GetInstructions(false_computation->root_instruction());
1130 EXPECT_EQ(4, conditional_instrs.size());
1131 EXPECT_EQ(2, true_instrs.size());
1132 EXPECT_EQ(2, false_instrs.size());
1133
1134 auto buffers = RunBufferAssignment(module.get());
1135 ValidateBuffers(conditional_instrs, *buffers);
1136 ValidateBuffers(true_instrs, *buffers);
1137 ValidateBuffers(false_instrs, *buffers);
1138
1139 EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers))
1140 << "Should be reuse between conditional and true computation.";
1141 EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers))
1142 << "Should be reuse between conditional and false computation.";
1143 EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers))
1144 << "Should be reuse between true and false computations.";
1145
1146 const BufferAllocation& conditional_buffer =
1147 GetTopLevelAllocation(*buffers, conditional);
1148 const BufferAllocation& true_buffer =
1149 GetTopLevelAllocation(*buffers, true_computation->root_instruction());
1150 const BufferAllocation& false_buffer =
1151 GetTopLevelAllocation(*buffers, false_computation->root_instruction());
1152 EXPECT_EQ(conditional_buffer.size(), true_buffer.size());
1153 EXPECT_EQ(conditional_buffer.size(), false_buffer.size());
1154 }
1155
TEST_F(BufferAssignmentTest,UnaryOpReuseChain)1156 TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
1157 // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
1158 auto builder = HloComputation::Builder(TestName());
1159 auto param0 = builder.AddInstruction(
1160 HloInstruction::CreateParameter(0, f32vec100_, "p"));
1161 auto exp1 = builder.AddInstruction(
1162 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
1163 auto tanh = builder.AddInstruction(
1164 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
1165 auto exp2 = builder.AddInstruction(
1166 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
1167 auto neg = builder.AddInstruction(
1168 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));
1169
1170 auto module = CreateNewVerifiedModule();
1171 module->AddEntryComputation(builder.Build());
1172 auto assignment = RunBufferAssignment(module.get());
1173
1174 // tanh and exp2 can reuse exp1's buffer
1175 EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
1176 auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1);
1177 EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
1178 EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
1179 EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
1180 }
1181
TEST_F(BufferAssignmentTest,ReuseNonOperandBuffer)1182 TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
1183 // This computation is a chain of operations which decreases in buffer size
1184 // (via slice) then increases in size (via broadcast):
1185 //
1186 // param ---> (negate) ---> (slice) ---> (broadcast)
1187 //
1188 // The negate should share a buffer with broadcast.
1189 auto builder = HloComputation::Builder(TestName());
1190 auto param0 = builder.AddInstruction(
1191 HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1192 auto negate = builder.AddInstruction(
1193 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1194 auto slice = builder.AddInstruction(
1195 HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1196 auto broadcast = builder.AddInstruction(
1197 HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
1198
1199 auto module = CreateNewVerifiedModule();
1200 module->AddEntryComputation(builder.Build());
1201 auto assignment = RunBufferAssignment(module.get());
1202
1203 // negate and broadcast should share a buffer.
1204 EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
1205 auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
1206 EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
1207
1208 // Slice should have its own buffer.
1209 EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
1210 }
1211
1212 TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
1213 // This computation is identical to that in ReuseNonOperandBuffer, but the
1214 // negate value is live until the end of the computation (because it is an
1215 // operand of the output tuple), preventing reuse.
1216 //
1217 // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
1218 // \-----------------------------------/
1219 //
1220 // The negate should not share a buffer with broadcast.
1221 auto builder = HloComputation::Builder(TestName());
1222 auto param0 = builder.AddInstruction(
1223 HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1224 auto negate = builder.AddInstruction(
1225 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1226 auto slice = builder.AddInstruction(
1227 HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1228 auto broadcast = builder.AddInstruction(
1229 HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
1230 builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
1231
1232 auto module = CreateNewVerifiedModule();
1233 module->AddEntryComputation(builder.Build());
1234 auto assignment = RunBufferAssignment(module.get());
1235
1236 // The instructions should not share buffers.
1237 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1238 GetTopLevelAllocation(*assignment, negate));
1239 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1240 GetTopLevelAllocation(*assignment, slice));
1241 EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
1242 GetTopLevelAllocation(*assignment, slice));
1243 }
1244
1245 TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
1246 // This computation is identical to that in ReuseNonOperandBuffer, but the
1247 // negate value is placed into a tuple which lives to the end of the
1248 // computation. This extends the live range of negate's buffer, preventing
1249 // reuse due to buffer aliasing.
1250 //
1251 // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
1252 // \-----------------------------------/
1253 //
1254 // The negate should not share a buffer with broadcast.
1255 auto builder = HloComputation::Builder(TestName());
1256 auto param0 = builder.AddInstruction(
1257 HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1258 auto negate = builder.AddInstruction(
1259 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1260 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate}));
1261 auto tuple_element = builder.AddInstruction(
1262 HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
1263 auto slice = builder.AddInstruction(
1264 HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
1265 auto broadcast = builder.AddInstruction(
1266 HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
1267 builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
1268
1269 auto module = CreateNewVerifiedModule();
1270 module->AddEntryComputation(builder.Build());
1271 auto assignment = RunBufferAssignment(module.get());
1272
1273 // The instructions should not share buffers.
1274 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1275 GetTopLevelAllocation(*assignment, negate));
1276 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1277 GetTopLevelAllocation(*assignment, slice));
1278 EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
1279 GetTopLevelAllocation(*assignment, slice));
1280 }
1281
1282 TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
1283 // This computation is very similar to ReuseNonOperandBuffer except the
1284 // broadcast has a smaller output than the negate. This should block reuse of
1285 // negate's buffer by broadcast because the output buffer(s) of a computation
1286 // should be exactly sized for the value.
1287 //
1288 // param ---> (negate) ---> (slice) ---> (broadcast)
1289 //
1290 // Neither negate nor slice may share a buffer with broadcast.
1291 auto builder = HloComputation::Builder(TestName());
1292 auto param0 = builder.AddInstruction(
1293 HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1294 // Negate output is 100 elements.
1295 auto negate = builder.AddInstruction(
1296 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1297 // Slice output is 10 elements.
1298 auto slice = builder.AddInstruction(
1299 HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1300 // Broadcast output is 40 elements.
1301 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1302 ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
1303
1304 auto module = CreateNewVerifiedModule();
1305 module->AddEntryComputation(builder.Build());
1306 auto assignment = RunBufferAssignment(module.get());
1307
1308 // The broadcast output buffer cannot be shared.
1309 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1310 GetTopLevelAllocation(*assignment, negate));
1311 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1312 GetTopLevelAllocation(*assignment, slice));
1313 }
1314
1315 TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
1316 // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
1317 // output is exactly the same size as the negate (rather than being
1318 // smaller). This enables reuse of negate's buffer by the broadcast because
1319 // the output buffer will be sized exactly to its value.
1320 //
1321 // param ---> (negate) ---> (slice) ---> (broadcast)
1322 //
1323 // The negate should share a buffer with broadcast.
1324 auto builder = HloComputation::Builder(TestName());
1325 auto param0 = builder.AddInstruction(
1326 HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1327 // Negate output is 100 elements.
1328 auto negate = builder.AddInstruction(
1329 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1330 auto slice = builder.AddInstruction(
1331 HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1332 // Broadcast output is 100 elements, matching the negate.
1333 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1334 ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
1335
1336 auto module = CreateNewVerifiedModule();
1337 module->AddEntryComputation(builder.Build());
1338 auto assignment = RunBufferAssignment(module.get());
1339
1340 // negate and broadcast should share a buffer.
1341 EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
1342 auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
1343 EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
1344
1345 // Slice should have its own buffer.
1346 EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
1347 }
1348
1349 TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
1350 // This computation is very similar to ReuseNonOperandBuffer except the
1351 // broadcast has a smaller output than the negate, and the broadcast is
1352 // contained in the computation output as a tuple element. This should block
1353 // reuse of the negate's buffer by the broadcast because the output buffer(s)
1354 // of a computation should be exactly sized for the value. This includes those
1355 // buffers aliased in the output (e.g., contained as tuple elements).
1356 //
1357 // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
1358 //
1359 // Neither negate nor slice may share a buffer with broadcast.
1360 auto builder = HloComputation::Builder(TestName());
1361 auto param0 = builder.AddInstruction(
1362 HloInstruction::CreateParameter(0, f32vec100_, "param0"));
1363 // Negate output is 100 elements.
1364 auto negate = builder.AddInstruction(
1365 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
1366 // Slice output is 10 elements.
1367 auto slice = builder.AddInstruction(
1368 HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
1369 // Broadcast output is 40 elements.
1370 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
1371 ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
1372 builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));
1373
1374 auto module = CreateNewVerifiedModule();
1375 module->AddEntryComputation(builder.Build());
1376 auto assignment = RunBufferAssignment(module.get());
1377
1378 // The broadcast output buffer cannot be shared.
1379 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1380 GetTopLevelAllocation(*assignment, negate));
1381 EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
1382 GetTopLevelAllocation(*assignment, slice));
1383 }
1384
1385 TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
1386 // Verify that buffers for embedded computations are properly marked as
1387 // thread-local and that embedded parameters are not marked as
1388 // is_entry_computation_parameter.
1389 auto module = CreateNewVerifiedModule();
1390 auto vec_shape = ShapeUtil::MakeShape(F32, {42});
1391 auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1392
1393 // Create a scalar computation to use in a map.
1394 auto map_builder = HloComputation::Builder(TestName() + "_map");
1395 auto map_param = map_builder.AddInstruction(
1396 HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1397 auto map_root = map_builder.AddInstruction(
1398 HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1399 auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1400
1401 // Create a vector computation to use in a kCall.
1402 auto call_builder = HloComputation::Builder(TestName() + "_call");
1403 auto call_param = call_builder.AddInstruction(
1404 HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
1405 auto call_root = call_builder.AddInstruction(
1406 HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
1407 auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
1408
1409 // Create entry computation which kCalls call_computation and then calls map
1410 // with map_computation on the result.
1411 auto builder = HloComputation::Builder(TestName());
1412 auto param = builder.AddInstruction(
1413 HloInstruction::CreateParameter(0, vec_shape, "param"));
1414 auto call = builder.AddInstruction(
1415 HloInstruction::CreateCall(vec_shape, {param}, call_computation));
1416 auto map = builder.AddInstruction(
1417 HloInstruction::CreateMap(vec_shape, {call}, map_computation));
1418 module->AddEntryComputation(builder.Build());
1419
1420 auto assignment = RunBufferAssignment(module.get());
1421
1422 // Allocations for the map computation should be thread-local and not
1423 // live-out.
1424 auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1425 EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1426 EXPECT_FALSE(map_param_alloc.maybe_live_out());
1427 EXPECT_TRUE(map_param_alloc.is_thread_local());
1428
1429 auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1430 EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1431 EXPECT_FALSE(map_root_alloc.maybe_live_out());
1432 EXPECT_TRUE(map_root_alloc.is_thread_local());
1433
1434 // Allocations for the call computation should not be thread-local.
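  // Note: the kCall parameter is colocated with the call operand, which here
  // is the entry computation's parameter, so its allocation is also marked as
  // an entry computation parameter.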
1435 auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
1436 EXPECT_TRUE(call_param_alloc.is_entry_computation_parameter());
1437 EXPECT_FALSE(call_param_alloc.maybe_live_out());
1438 EXPECT_FALSE(call_param_alloc.is_thread_local());
1439
1440 auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
1441 EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
1442 EXPECT_FALSE(call_root_alloc.is_thread_local());
1443
1444 // Entry computation allocations can be marked as live-out and as
1445 // is_entry_computation_parameter.
1446 auto& param_alloc = GetTopLevelAllocation(*assignment, param);
1447 EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
1448 EXPECT_FALSE(param_alloc.maybe_live_out());
1449 EXPECT_FALSE(param_alloc.is_thread_local());
1450
1451 auto& map_alloc = GetTopLevelAllocation(*assignment, map);
1452 EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
1453 EXPECT_TRUE(map_alloc.maybe_live_out());
1454 EXPECT_FALSE(map_alloc.is_thread_local());
1455 }
1456
1457 TEST_F(BufferAssignmentTest, CustomCallEmbeddedComputationBuffers) {
1458 // Verify that buffers for embedded computations in a custom call are properly
1459 // marked as thread-local.
1460 auto module = CreateNewVerifiedModule();
1461 auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1462
1463 // Create a scalar computation to use in a map.
1464 auto map_builder = HloComputation::Builder(TestName() + "_map");
1465 auto map_param = map_builder.AddInstruction(
1466 HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
1467 auto map_root = map_builder.AddInstruction(
1468 HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
1469 auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
1470
1471 // Create entry computation with a custom call on map_computation.
1472 auto builder = HloComputation::Builder(TestName());
1473 auto param = builder.AddInstruction(
1474 HloInstruction::CreateParameter(0, scalar_shape, "param"));
1475 builder.AddInstruction(HloInstruction::CreateCustomCall(
1476 scalar_shape, {param}, map_computation, "call_name"));
1477 module->AddEntryComputation(builder.Build());
1478
1479 auto assignment = RunBufferAssignment(module.get());
1480
1481 // Allocations for the map computation should be thread-local and not
1482 // live-out.
1483 auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
1484 EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
1485 EXPECT_FALSE(map_param_alloc.maybe_live_out());
1486 EXPECT_TRUE(map_param_alloc.is_thread_local());
1487
1488 auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
1489 EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
1490 EXPECT_FALSE(map_root_alloc.maybe_live_out());
1491 EXPECT_TRUE(map_root_alloc.is_thread_local());
1492 }
1493
1494 TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
1495 // Test a computation that returns a tuple parameter.
1496 auto builder = HloComputation::Builder(TestName());
1497 auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
1498 0,
1499 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1500 ShapeUtil::MakeShape(F32, {}),
1501 ShapeUtil::MakeShape(S32, {42})}),
1502 "param0"));
1503
1504 auto module = CreateNewVerifiedModule();
1505 module->AddEntryComputation(builder.Build());
1506 auto assignment = RunBufferAssignment(module.get());
1507
1508 // There should be four allocations: one for the tuple itself (the vector of
1509 // pointers), and one for each of the three tuple elements.
1510 EXPECT_EQ(4, assignment->Allocations().size());
1511
1512 // Verify each buffer allocation is marked as an entry computation parameter
1513 // and is liveout.
1514 ShapeUtil::ForEachSubshape(
1515 tuple_param->shape(),
1516 [this, &assignment, tuple_param](const Shape& /*subshape*/,
1517 const ShapeIndex& index) {
1518 auto allocation = GetAllocation(*assignment, tuple_param, index);
1519 EXPECT_TRUE(allocation.is_entry_computation_parameter());
1520 EXPECT_EQ(0, allocation.parameter_number());
1521 EXPECT_TRUE(allocation.maybe_live_out());
1522 });
1523 }
1524
1525 TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
1526 // Test a computation which returns a GetTupleElement of a nested tuple
1527 // parameter.
1528 auto builder = HloComputation::Builder(TestName());
1529 auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
1530 0,
1531 ShapeUtil::MakeTupleShape(
1532 {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1533 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
1534 ShapeUtil::MakeShape(S32, {101})})}),
1535 "param0"));
1536 auto tuple_element =
1537 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1538 ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));
1539
1540 auto module = CreateNewVerifiedModule();
1541 module->AddEntryComputation(builder.Build());
1542 auto assignment = RunBufferAssignment(module.get());
1543
1544 // Only some of the elements of the input param are liveout.
1545 EXPECT_FALSE(
1546 GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
1547 // Tuple element at index={1} is live out because GetTupleElement({1})
1548 // forwards a pointer to this allocation (instead of defining its own buffer).
1549 EXPECT_TRUE(
1550 GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
1551 EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
1552 .maybe_live_out());
1553 EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
1554 .maybe_live_out());
1555
1556 // The GetTupleElement output is liveout.
1557 EXPECT_TRUE(
1558 GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());
1559
1560 // Verify that the allocations for the GetTupleElement's elements match the
1561 // corresponding tuple parameter allocations, because they alias.
1562 EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
1563 GetAllocation(*assignment, tuple_element, /*index=*/{0}));
1564 EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
1565 GetAllocation(*assignment, tuple_element, /*index=*/{1}));
1566
1567 // GetTupleElement forwards a pointer to its underlying buffer, so verify
1568 // that it has the same allocation as the corresponding parameter element.
1569 EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
1570 GetTopLevelAllocation(*assignment, tuple_element));
1571 }
1572
1573 // TODO(b/32248867): Enable when buffer assignment gives allocations to
1574 // constants.
1575 TEST_F(BufferAssignmentTest, TupleConstantAsOutput) {
1576 // Test that a tuple constant which is forwarded to the computation output
1577 // is properly handled.
1578 auto builder = HloComputation::Builder(TestName());
1579 Literal elements[] = {LiteralUtil::CreateR0<int64_t>(0),
1580 LiteralUtil::CreateR0<int64_t>(1)};
1581 builder.AddInstruction(HloInstruction::CreateConstant(
1582 LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
1583
1584 auto module = CreateNewVerifiedModule();
1585 module->AddEntryComputation(builder.Build());
1586 auto assignment = RunBufferAssignment(module.get());
1587
1588 EXPECT_EQ(3, assignment->Allocations().size());
1589 }
1590
1591 TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
1592 // Test a computation which returns a tuple custom call value.
1593 auto builder = HloComputation::Builder(TestName());
1594 auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall(
1595 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
1596 ShapeUtil::MakeShape(S32, {101})}),
1597 /*operands=*/{}, /*custom_call_target=*/"foo_function"));
1598 auto module = CreateNewVerifiedModule();
1599 module->AddEntryComputation(builder.Build());
1600 auto assignment = RunBufferAssignment(module.get());
1601
1602 EXPECT_EQ(3, assignment->Allocations().size());
1603 EXPECT_TRUE(
1604 GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
1605 EXPECT_TRUE(
1606 GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
1607 EXPECT_TRUE(
1608 GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
1609 }
1610
1611 TEST_F(BufferAssignmentTest, CustomCallAliasedBuffer) {
1612 // Test a computation with custom call aliasing.
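  // The output_to_operand_aliasing={{}: (0, {})} attribute declares that the
  // custom call's output (shape index {}) aliases operand 0 at shape index {},
  // so buffer assignment must give the output and `add` the same top-level
  // slice.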
1613 const char* const kModuleString = R"(
1614 HloModule xla_computation_f
1615 ENTRY xla_computation_f {
1616 parameter.1 = f32[2,3,4,5] parameter(0)
1617 parameter.2 = f32[2,3,4,5] parameter(1)
1618 add = f32[2,3,4,5] add(parameter.1, parameter.2)
1619 ROOT custom-call = f32[2,3,4,5] custom-call(add, parameter.2), custom_call_target="dm_softmax", operand_layout_constraints={f32[2,3,4,5], f32[2,3,4,5]}, output_to_operand_aliasing={{}: (0, {})}
1620 }
1621 )";
1622
1623 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
1624 ParseAndReturnUnverifiedModule(kModuleString));
1625 std::unique_ptr<BufferAssignment> assignment =
1626 RunBufferAssignment(module.get());
1627 HloInstruction* custom_call = module->entry_computation()->root_instruction();
1628 EXPECT_TRUE(
1629 assignment->SharesTopLevelSlice(custom_call, custom_call->operand(0)));
1630 }
1631
1632 TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
1633 // Test a computation which returns a tuple call value.
1634 auto module = CreateNewVerifiedModule();
1635 auto elem_shape = f32vec4_;
1636 auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1637
1638 auto sub_builder = HloComputation::Builder(TestName() + "_sub");
1639 auto sub_param = sub_builder.AddInstruction(
1640 HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
1641 auto sub_tuple =
1642 sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
1643 auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());
1644
1645 auto builder = HloComputation::Builder(TestName());
1646 auto param = builder.AddInstruction(
1647 HloInstruction::CreateParameter(0, elem_shape, "param"));
1648 auto call = builder.AddInstruction(
1649 HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
1650 module->AddEntryComputation(builder.Build());
1651
1652 auto assignment = RunBufferAssignment(module.get());
1653
1654 EXPECT_EQ(2, assignment->Allocations().size());
1655 // Buffers for call are colocated with the sub-computation.
1656 EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
1657 GetAllocation(*assignment, sub_tuple, /*index=*/{}));
1658 EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
1659 GetAllocation(*assignment, sub_param, /*index=*/{}));
1660
1661 // The parameter isn't aliased with the result tuple, but it is aliased with
1662 // the call operand.
1663 EXPECT_NE(GetTopLevelAllocation(*assignment, param),
1664 GetTopLevelAllocation(*assignment, sub_tuple));
1665 EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
1666 GetTopLevelAllocation(*assignment, sub_param));
1667 }
1668
1669 TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) {
1670 // Test a chain of calls with tuple output. The chain looks like:
1671 // A: call(B, tuple(param))
1672 // B: call(C, param)
1673 // C: call(D, param)
1674 // D: param
1675 auto module = CreateNewVerifiedModule();
1676 auto elem_shape = f32vec4_;
1677 auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
1678
1679 auto d_builder = HloComputation::Builder(TestName() + "_d");
1680 auto d_param = d_builder.AddInstruction(
1681 HloInstruction::CreateParameter(0, tuple_shape, "d_param"));
1682 auto d_computation = d_builder.Build();
1683
1684 auto c_builder = HloComputation::Builder(TestName() + "_c");
1685 auto c_param = c_builder.AddInstruction(
1686 HloInstruction::CreateParameter(0, tuple_shape, "c_param"));
1687 auto c_call = c_builder.AddInstruction(
1688 HloInstruction::CreateCall(tuple_shape, {c_param}, d_computation.get()));
1689 auto c_computation = c_builder.Build();
1690
1691 auto b_builder = HloComputation::Builder(TestName() + "_b");
1692 auto b_param = b_builder.AddInstruction(
1693 HloInstruction::CreateParameter(0, tuple_shape, "b_param"));
1694 auto b_call = b_builder.AddInstruction(
1695 HloInstruction::CreateCall(tuple_shape, {b_param}, c_computation.get()));
1696 auto b_computation = b_builder.Build();
1697
1698 auto a_builder = HloComputation::Builder(TestName());
1699 auto a_param = a_builder.AddInstruction(
1700 HloInstruction::CreateParameter(0, elem_shape, "param"));
1701 auto a_tuple =
1702 a_builder.AddInstruction(HloInstruction::CreateTuple({a_param}));
1703 auto a_call = a_builder.AddInstruction(
1704 HloInstruction::CreateCall(tuple_shape, {a_tuple}, b_computation.get()));
1705 auto a_computation = a_builder.Build();
1706
1707 // Add the computations in an order that doesn't match the dependency
1708 // post-order, to shake out more possible bugs.
1709 module->AddEmbeddedComputation(std::move(d_computation));
1710 module->AddEmbeddedComputation(std::move(c_computation));
1711 module->AddEntryComputation(std::move(a_computation));
1712 module->AddEmbeddedComputation(std::move(b_computation));
1713
1714 auto assignment = RunBufferAssignment(module.get());
1715
1716 // Buffers for call are colocated with the sub-computations.
1717 EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}),
1718 GetAllocation(*assignment, b_call, /*index=*/{}));
1719 EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}),
1720 GetAllocation(*assignment, c_call, /*index=*/{}));
1721 EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{}),
1722 GetAllocation(*assignment, d_param, /*index=*/{}));
1723 EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{0}),
1724 GetAllocation(*assignment, b_call, /*index=*/{0}));
1725 EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{0}),
1726 GetAllocation(*assignment, c_call, /*index=*/{0}));
1727 EXPECT_EQ(GetAllocation(*assignment, c_call, /*index=*/{0}),
1728 GetAllocation(*assignment, d_param, /*index=*/{0}));
1729
1730 EXPECT_TRUE(BuffersDistinct({a_param}, {b_param}, *assignment));
1731 EXPECT_TRUE(BuffersDistinct({a_param}, {c_param}, *assignment));
1732 EXPECT_TRUE(BuffersDistinct({a_param}, {d_param}, *assignment));
1733
1734 EXPECT_EQ(GetAllocation(*assignment, b_param, /*index=*/{0}),
1735 GetAllocation(*assignment, c_param, /*index=*/{0}));
1736 EXPECT_EQ(GetAllocation(*assignment, c_param, /*index=*/{0}),
1737 GetAllocation(*assignment, d_param, /*index=*/{0}));
1738 }
1739
1740 TEST_F(BufferAssignmentTest, BitcastAsOutput) {
1741 // Test a computation which returns a bitcast value.
1742 auto builder = HloComputation::Builder(TestName());
1743 auto param = builder.AddInstruction(HloInstruction::CreateParameter(
1744 0, ShapeUtil::MakeShape(F32, {42}), "param"));
1745 auto bitcast = builder.AddInstruction(
1746 HloInstruction::CreateBitcast(param->shape(), param));
1747
1748 auto module = CreateNewVerifiedModule();
1749 module->AddEntryComputation(builder.Build());
1750 auto assignment = RunBufferAssignment(module.get());
1751
1752 // Bitcast should get the same allocation as the param.
1753 EXPECT_EQ(1, assignment->Allocations().size());
1754 EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
1755 GetTopLevelAllocation(*assignment, bitcast));
1756 }
1757
1758 // TODO(b/34669761): Remove this test when buffers are allowed to share
1759 // allocations.
1760 TEST_F(BufferAssignmentTest, TupleBufferNotReused) {
1761 // Test that a copy of a tuple element does not reuse the tuple's buffer.
1762 auto builder = HloComputation::Builder(TestName());
1763 auto scalar_shape = ShapeUtil::MakeShape(F32, {});
1764 auto param = builder.AddInstruction(
1765 HloInstruction::CreateParameter(0, scalar_shape, "param0"));
1766 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({param}));
1767 auto tuple_element = builder.AddInstruction(
1768 HloInstruction::CreateGetTupleElement(scalar_shape, tuple, 0));
1769 auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
1770 scalar_shape, HloOpcode::kCopy, tuple_element));
1771
1772 auto module = CreateNewVerifiedModule();
1773 module->AddEntryComputation(builder.Build());
1774 auto assignment = RunBufferAssignment(module.get());
1775
1776 // There should be no buffer reuse. The copy should not reuse the tuple
1777 // buffer.
1778 EXPECT_EQ(3, assignment->Allocations().size());
1779 EXPECT_NE(GetTopLevelAllocation(*assignment, tuple),
1780 GetTopLevelAllocation(*assignment, copy));
1781 }
1782
1783 TEST_F(BufferAssignmentTest, OneTempAllocation) {
1784 // Test a computation that requires multiple temp buffers, and ensure they
1785 // are combined into a single allocation.
1786 auto builder = HloComputation::Builder(TestName());
1787 Shape shape_2x3 = ShapeUtil::MakeShape(F32, {2, 3});
1788 Shape shape_2x4 = ShapeUtil::MakeShape(F32, {2, 4});
1789 Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
1790 Shape shape_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
1791 Shape shape_5x4 = ShapeUtil::MakeShape(F32, {5, 4});
1792
1793 // There should be separate temp buffers for dot_ab and dot_bc.
1794 auto param_a = builder.AddInstruction(
1795 HloInstruction::CreateParameter(0, shape_2x3, "param_a"));
1796 auto param_b = builder.AddInstruction(
1797 HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
1798 auto param_c = builder.AddInstruction(
1799 HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
1800 DotDimensionNumbers dot_dnums;
1801 dot_dnums.add_lhs_contracting_dimensions(1);
1802 dot_dnums.add_rhs_contracting_dimensions(0);
1803 PrecisionConfig precision_config;
1804 precision_config.mutable_operand_precision()->Resize(
1805 2, PrecisionConfig::DEFAULT);
1806 auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
1807 shape_2x4, param_a, param_b, dot_dnums, precision_config));
1808 auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
1809 shape_3x4, param_b, param_c, dot_dnums, precision_config));
1810 builder.AddInstruction(
1811 HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
1812
1813 // Run buffer assignment with alignment=1.
1814 auto module = CreateNewVerifiedModule();
1815 module->AddEntryComputation(builder.Build());
1816 auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1);
1817
1818 // There are 5 allocations: 3 parameters, 1 output, and 1 temp.
1819 EXPECT_EQ(5, assignment->Allocations().size());
1820
1821 // Ensure the temp buffers for dot_ab and dot_bc share a single allocation,
1822 // and each occupies different slices of that allocation.
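  // dot_ab is f32[2,4] (32 bytes) and dot_bc is f32[3,4] (48 bytes), so with
  // alignment 1 the shared temp allocation should be 32 + 48 = 80 bytes.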
1823 BufferAllocation::Slice slice_ab =
1824 assignment->GetUniqueTopLevelSlice(dot_ab).value();
1825 BufferAllocation::Slice slice_bc =
1826 assignment->GetUniqueTopLevelSlice(dot_bc).value();
1827 EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1828 EXPECT_NE(slice_ab, slice_bc);
1829 EXPECT_EQ(32, slice_ab.size());
1830 EXPECT_EQ(48, slice_bc.size());
1831 EXPECT_EQ(80, slice_ab.allocation()->size());
1832 EXPECT_EQ(80, slice_bc.allocation()->size());
1833
1834 // Re-run buffer assignment with alignment=64.
1835 assignment = RunBufferAssignment(module.get(), /*alignment=*/64);
1836 EXPECT_EQ(5, assignment->Allocations().size());
1837 slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).value();
1838 slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).value();
1839 EXPECT_EQ(slice_ab.allocation(), slice_bc.allocation());
1840 EXPECT_NE(slice_ab, slice_bc);
1841 EXPECT_EQ(32, slice_ab.size());
1842 EXPECT_EQ(48, slice_bc.size());
1843 // Ensure the offsets and allocation size account for the alignment, without
1844 // assuming which buffer gets assigned first.
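  // With 64-byte alignment the second slice starts at offset 64, so the
  // allocation is 64 + 48 = 112 bytes when dot_ab is placed first, or
  // 64 + 32 = 96 bytes when dot_bc is placed first.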
1845 if (slice_ab.offset() == 0) {
1846 EXPECT_EQ(64, slice_bc.offset());
1847 EXPECT_EQ(64 + 48, slice_ab.allocation()->size());
1848 EXPECT_EQ(64 + 48, slice_bc.allocation()->size());
1849 } else {
1850 EXPECT_EQ(64, slice_ab.offset());
1851 EXPECT_EQ(0, slice_bc.offset());
1852 EXPECT_EQ(64 + 32, slice_ab.allocation()->size());
1853 EXPECT_EQ(64 + 32, slice_bc.allocation()->size());
1854 }
1855 }
1856
1857 TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
1858 // paramscalar -(bc)- (mul) -- (add) -- (sub)
1859 // / / /
1860 // param0[100] -------/ / /
1861 // / /
1862 // param1[100] --------------/--------/
1863 auto builder = HloComputation::Builder(TestName());
1864 auto paramscalar =
1865 builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
1866 auto broadcast = builder.AddInstruction(
1867 HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {}));
1868 auto param0 = builder.AddInstruction(
1869 HloInstruction::CreateParameter(1, f32vec100_, "p1"));
1870 auto param1 = builder.AddInstruction(
1871 HloInstruction::CreateParameter(2, f32vec100_, "p2"));
1872 auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
1873 f32vec100_, HloOpcode::kMultiply, broadcast, param0));
1874 auto add = builder.AddInstruction(
1875 HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
1876 auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
1877 f32vec100_, HloOpcode::kSubtract, add, param1));
1878 auto module = CreateNewVerifiedModule();
1879 module->AddEntryComputation(builder.Build());
1880
1881 auto buffers = RunBufferAssignment(module.get());
1882
1883 const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
1884 const std::vector<const HloValue*>& peak_buffers =
1885 mul_buffer.PeakMemoryLogicalBuffers();
1886 ASSERT_EQ(peak_buffers.size(), 1);
1887 EXPECT_EQ(peak_buffers[0]->instruction(), sub);
1888 }
1889
1890 TEST_F(BufferAssignmentTest, PeakBuffers) {
1891 // Compute the peak liveness buffers of the following sequence:
1892 //
1893 // %param = ...
1894 // %log = log(%param)
1895 // %rev = reverse(%log)
1896 // %neg = neg(%param)
1897 // %concat = concat(%rev, %neg)
1898 // ROOT %root = slice(concat)
1899 //
1900 // In the temporary block, the set of live buffers at peak memory use should
1901 // be {%rev, %neg, %concat}. This occurs right at the concat itself.
1902 auto builder = HloComputation::Builder(TestName());
1903 auto param = builder.AddInstruction(
1904 HloInstruction::CreateParameter(0, f32vec100_, "p"));
1905 auto log = builder.AddInstruction(
1906 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
1907 auto rev = builder.AddInstruction(
1908 HloInstruction::CreateReverse(f32vec100_, log, {0}));
1909 auto neg = builder.AddInstruction(
1910 HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param));
1911 const Shape concat_shape = ShapeUtil::MakeShape(F32, {200});
1912 auto concat = builder.AddInstruction(
1913 HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0));
1914 // Make the root tiny so no interior nodes can share its buffer.
1915 auto root = builder.AddInstruction(HloInstruction::CreateSlice(
1917 ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1}));
1918
1919 auto module = CreateNewVerifiedModule();
1920 module->AddEntryComputation(builder.Build());
1921
1922 auto buffers = RunBufferAssignmentWithInstructionSequence(
1923 module.get(), {param, log, rev, neg, concat, root});
1924
1925 // The temporary buffer should hold the 4 interior instructions.
1926 const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat);
1927 EXPECT_FALSE(buffer.IsInputOrOutput());
1928 EXPECT_TRUE(buffer.IsPreallocatedTempBuffer());
1929 ASSERT_EQ(buffer.assigned_buffers().size(), 4);
1930
1931 const std::vector<const HloValue*>& peak_buffers =
1932 buffer.PeakMemoryLogicalBuffers();
1933
1934 // The peak live set should be concat and its inputs.
1935 ASSERT_EQ(peak_buffers.size(), 3);
1936 std::vector<const HloInstruction*> peak_instructions;
1937 for (const HloValue* logical_buffer : peak_buffers) {
1938 peak_instructions.push_back(logical_buffer->instruction());
1939 }
1940 EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
1941 }
1942
1943 TEST_F(BufferAssignmentTest, AliasedBuffersShouldntCoexistInPeakBuffers) {
1944 std::string hlo_text = R"(
1945 HloModule test_module, is_scheduled=true
1946
1947 cond {
1948 param = (s32[], s32[]) parameter(0)
1949 ROOT constant = pred[] constant(true)
1950 }
1951
1952 body {
1953 param.0 = (s32[], s32[]) parameter(0)
1954 gte = s32[] get-tuple-element(param.0), index=0
1955 add = s32[] add(gte, gte)
1956 ROOT tuple = (s32[], s32[]) tuple(add, add)
1957 }
1958
1959 ENTRY test_module {
1960 param.3 = s32[] parameter(0)
1961 copy = s32[] copy(param.3)
1962 tuple = (s32[], s32[]) tuple(copy, copy)
1963 while = (s32[], s32[]) while(tuple), condition=cond, body=body
1964 gte = s32[] get-tuple-element(while), index=0
1965 ROOT negate = s32[] negate(gte)
1966 })";
1967
1968 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
1969 auto assignment = RunBufferAssignmentWithSequentialOrdering(module.get());
1970 const BufferAllocation& buffer =
1971 GetTopLevelAllocation(*assignment, FindInstruction(module.get(), "copy"));
1972 const std::vector<const HloValue*>& peak_buffers =
1973 buffer.PeakMemoryLogicalBuffers();
1974
1975 // Since the same aliased buffer (copy) is passed into while, we expect the
1976 // number of peak array buffers to be one.
1977 int num_peak_buffers = 0;
1978 for (const HloValue* peak_buffer : peak_buffers) {
1979 if (peak_buffer->shape().IsArray()) {
1980 ++num_peak_buffers;
1981 }
1982 }
1983 EXPECT_EQ(num_peak_buffers, 1);
1984 }
1985
1986 TEST_F(BufferAssignmentTest, InPlaceBuffer) {
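  // dynamic-update-slice updates its operand in place, so both
  // dynamic-update-slice instructions should share the allocation of the tuple
  // element (get-tuple-element.4) that they update.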
1987 const char* hlo_text = R"(
1988 HloModule Module
1989
1990 ENTRY main {
1991 state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
1992 constant.1 = f32[] constant(0)
1993 broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
1994 get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
1995 get-tuple-element.3 = s32[] get-tuple-element(state), index=0
1996 constant.2 = s32[] constant(128)
1997 add.5 = s32[] add(get-tuple-element.3, constant.2)
1998 constant.3 = s32[] constant(0)
1999 dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
2000 dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
2001 ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
2002 }
2003 )";
2004
2005 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
2006 HloInstruction* parameter =
2007 m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
2008 HloInstruction* dus1 =
2009 m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
2010 HloInstruction* dus2 =
2011 m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
2012
2013 auto buffers = RunBufferAssignment(m.get());
2014
2015 {
2016 const BufferAllocation& parameter_alloc =
2017 GetTopLevelAllocation(*buffers, parameter);
2018
2019 const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
2020 EXPECT_EQ(parameter_alloc, dus1_alloc);
2021 const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
2022 EXPECT_EQ(parameter_alloc, dus2_alloc);
2023 }
2024 }
2025
2026 TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
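  // The allocations for the constants must hold only the constant values
  // themselves: neither the copy nor the conditional result may be assigned
  // into a constant allocation, since constant buffers must not be
  // overwritten.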
2027 const char* hlo_text = R"(
2028 HloModule Module
2029
2030 True {
2031 ROOT x.0.1 = f32[] parameter(0)
2032 }
2033
2034 False {
2035 x.0.0 = f32[] parameter(0)
2036 ROOT copy.1 = f32[] copy(x.0.0)
2037 }
2038
2039 ENTRY main {
2040 pred.1.0 = pred[] parameter(0)
2041 constant.1.1 = f32[] constant(56)
2042 copy.2 = f32[] copy(constant.1.1)
2043 constant.1.2 = f32[] constant(12)
2044 ROOT conditional.1.3 = f32[] conditional(pred.1.0, copy.2, constant.1.2),
2045 true_computation=True, false_computation=False
2046 }
2047 )";
2048
2049 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
2050 HloInstruction* constant_1 =
2051 m->entry_computation()->GetInstructionWithName("constant.1.1");
2052 HloInstruction* constant_2 =
2053 m->entry_computation()->GetInstructionWithName("constant.1.2");
2054
2055 auto buffers = RunBufferAssignment(m.get());
2056
2057 {
2058 const BufferAllocation& allocation_for_const_1 =
2059 GetTopLevelAllocation(*buffers, constant_1);
2060 EXPECT_TRUE(allocation_for_const_1.is_constant());
2061 for (const auto& buffer_offset_pair :
2062 allocation_for_const_1.assigned_buffers()) {
2063 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2064 HloOpcode::kCopy);
2065 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2066 HloOpcode::kConditional);
2067 }
2068 }
2069
2070 {
2071 const BufferAllocation& allocation_for_const_2 =
2072 GetTopLevelAllocation(*buffers, constant_2);
2073 EXPECT_TRUE(allocation_for_const_2.is_constant());
2074 for (const auto& buffer_offset_pair :
2075 allocation_for_const_2.assigned_buffers()) {
2076 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2077 HloOpcode::kCopy);
2078 EXPECT_NE(buffer_offset_pair.first->instruction()->opcode(),
2079 HloOpcode::kConditional);
2080 }
2081 }
2082 }
2083
2084 class WhileBufferAssignmentTest : public HloTestBase {
2085 protected:
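  // Builds a while-loop condition computation. The loop-state parameter is
  // declared but unused; the root simply compares the constants 0 and 10, so
  // the condition's value is irrelevant to these buffer-assignment tests.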
2086 std::unique_ptr<HloComputation> BuildWhileConditionComputation(
2087 const std::string& name) {
2088 auto builder = HloComputation::Builder(name);
2089 builder.AddInstruction(
2090 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2091 auto zero = builder.AddInstruction(
2092 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
2093 auto ten = builder.AddInstruction(
2094 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
2095 builder.AddInstruction(HloInstruction::CreateCompare(
2096 ShapeUtil::MakeShape(PRED, {}), zero, ten, ComparisonDirection::kLt));
2097 return builder.Build();
2098 }
2099
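  // Builds a while-loop body computation that reads (input, weights) from the
  // loop state, computes output = input * weights, and returns the tuple
  // (input, weights, output).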
2100 std::unique_ptr<HloComputation> BuildWhileBodyComputation(
2101 const std::string& name) {
2102 auto builder = HloComputation::Builder(name);
2103 auto loop_state = builder.AddInstruction(
2104 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
2105 auto input = builder.AddInstruction(
2106 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0));
2107 auto weights = builder.AddInstruction(
2108 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
2109 auto output = builder.AddInstruction(HloInstruction::CreateBinary(
2110 data_shape_, HloOpcode::kMultiply, input, weights));
2111 builder.AddInstruction(
2112 HloInstruction::CreateTuple({input, weights, output}));
2113 return builder.Build();
2114 }
2115
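  // Schedules the module and runs buffer assignment with a sequential HLO
  // ordering, a fixed alignment for every buffer color, and allocations
  // created for constants.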
2116 std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
2117 int64_t alignment = 1) {
2118 HloSchedule schedule = ScheduleModule(module, ByteSizeOf).value();
2119 return BufferAssigner::Run(
2120 module, std::make_unique<SequentialHloOrdering>(schedule),
2121 ByteSizeOf,
2122 [alignment](LogicalBuffer::Color) { return alignment; },
2123 /*allocate_buffers_for_constants=*/true)
2124 .value();
2125 }
2126
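  // Size function shared by scheduling and buffer assignment; tuple pointers
  // are counted as sizeof(void*).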
2127 static int64_t ByteSizeOf(const BufferValue& buffer) {
2128 return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
2129 }
2130
2131 Shape data_shape_ = ShapeUtil::MakeShape(F32, {4});
2132 Shape loop_state_shape_ =
2133 ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_});
2134 };
2135
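// Runs copy insertion so the module's aliasing (e.g. around while loops) is in
// the form buffer assignment expects, asserting that the pass succeeds.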
2136 static void RunCopyInsertion(HloModule* module) {
2137 CopyInsertion copy_insertion;
2138 EXPECT_IS_OK(copy_insertion.Run(module).status());
2139 }
2140
2141 TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
2142 auto module = CreateNewVerifiedModule();
2143 auto builder = HloComputation::Builder("entry");
2144
2145 auto input0 = builder.AddInstruction(
2146 HloInstruction::CreateParameter(0, data_shape_, "input0"));
2147 auto weights0 = builder.AddInstruction(
2148 HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2149 auto weights1 = builder.AddInstruction(
2150 HloInstruction::CreateParameter(2, data_shape_, "weights1"));
2151
2152 auto zero = builder.AddInstruction(
2153 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2154 auto output0 = builder.AddInstruction(
2155 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2156 auto output1 = builder.AddInstruction(
2157 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2158
2159 auto cond0 =
2160 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2161 auto body0 =
2162 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2163
2164 auto tuple0 = builder.AddInstruction(
2165 HloInstruction::CreateTuple({input0, weights0, output0}));
2166 auto while0 = builder.AddInstruction(
2167 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2168
2169 auto cond1 =
2170 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2171 auto body1 =
2172 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2173 auto input1 = builder.AddInstruction(
2174 HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2175 auto tuple1 = builder.AddInstruction(
2176 HloInstruction::CreateTuple({input1, weights1, output1}));
2177 auto while1 = builder.AddInstruction(
2178 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2179
2180 module->AddEntryComputation(builder.Build());
2181 RunCopyInsertion(module.get());
2182 auto assignment = RunBufferAssignment(module.get());
2183
2184 // Verify 'input0' and read-only use while0{0} alias.
2185 EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).value(),
2186 assignment->GetUniqueSlice(while0, {0}).value());
2187 // Verify 'weights0' and read-only use while0{1} alias.
2188 EXPECT_EQ(assignment->GetUniqueSlice(weights0, {}).value(),
2189 assignment->GetUniqueSlice(while0, {1}).value());
2190 // Verify 'while0{2}' and read-only use while1{0} alias.
2191 EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).value(),
2192 assignment->GetUniqueSlice(while1, {0}).value());
2193 // Verify 'weights1' and read-only use while1{1} alias.
2194 EXPECT_EQ(assignment->GetUniqueSlice(weights1, {}).value(),
2195 assignment->GetUniqueSlice(while1, {1}).value());
2196 }
2197
2198 // Tests that two colocated buffer sets are not merged if an entry parameter
2199 // buffer belongs to either of the colocation sets (b/73267882).
2200 //
2201 // %param --> %while.0 --> %mul --> %while.1 --> %broadcast
2202 //
2203 // %while.0 body just forwards the init value, so the loop-carried variable
2204 // remains unchanged, whereas %while.1 changes the loop-carried variable.
2205 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) {
2206 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2207
2208 const char* module_str = R"(
2209 HloModule test_module
2210
2211 %cond.v0 {
2212 %param = s32[] parameter(0)
2213 ROOT %constant = pred[] constant(true)
2214 }
2215
2216 %cond.v1 {
2217 %param.0 = s32[] parameter(0)
2218 ROOT %constant.0 = pred[] constant(true)
2219 }
2220
2221 %body.v0 {
2222 ROOT %param.1 = s32[] parameter(0)
2223 }
2224
2225 %body.v1 {
2226 %param.2 = s32[] parameter(0)
2227 ROOT add = s32[] add(%param.2, %param.2)
2228 }
2229
2230 ENTRY %test_module {
2231 %param.3 = s32[] parameter(0)
2232 %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0
2233 %mul = s32[] multiply(%while.0, %while.0)
2234 %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2235 ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2236 })";
2237
2238 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2239
2240 // Run CopyInsertion and check if the graph constructed above doesn't need
2241 // any copies inserted for BufferAssignment to run.
2242 int64_t instruction_count = m->instruction_count();
2243 CopyInsertion copy_insertion;
2244 ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2245 ASSERT_EQ(instruction_count, m->instruction_count());
2246
2247 // Get the instructions in the module.
2248 const HloInstruction* bcast = m->entry_computation()->root_instruction();
2249 const HloInstruction* param =
2250 m->entry_computation()->parameter_instruction(0);
2251 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2252 const HloInstruction* while1 = bcast->operand(0);
2253 ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2254 const HloInstruction* while0 = while1->operand(0)->operand(0);
2255 ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2256
2257 // Run buffer assignment.
2258 auto assignment = RunBufferAssignment(m.get());
2259 TF_ASSERT_OK_AND_ASSIGN(auto slice_param,
2260 assignment->GetUniqueSlice(param, {}));
2261 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2262 assignment->GetUniqueSlice(while0, {}));
2263 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2264 assignment->GetUniqueSlice(while1, {}));
2265
2266 // The parameter slice is part of while0's colocation set (as the init
2267 // value), but is not merged into while1's colocation set.
2268 EXPECT_EQ(slice_param, slice_while0);
2269 EXPECT_NE(slice_param, slice_while1);
2270 }
2271
2272 TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithConstant) {
2273 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2274
2275 const char* module_str = R"(
2276 HloModule test_module
2277
2278 %cond.v0 {
2279 %param = s32[] parameter(0)
2280 ROOT %constant = pred[] constant(true)
2281 }
2282
2283 %cond.v1 {
2284 %param.0 = s32[] parameter(0)
2285 ROOT %constant.0 = pred[] constant(true)
2286 }
2287
2288 %body.v0 {
2289 ROOT %param.1 = s32[] parameter(0)
2290 }
2291
2292 %body.v1 {
2293 %param.2 = s32[] parameter(0)
2294 ROOT add = s32[] add(%param.2, %param.2)
2295 }
2296
2297 ENTRY %test_module {
2298 %constant.42 = s32[] constant(42)
2299 %while.0 = s32[] while(%constant.42), condition=%cond.v0, body=%body.v0
2300 %mul = s32[] multiply(%while.0, %while.0)
2301 %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1
2302 ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={}
2303 })";
2304
2305 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2306
2307 // Run CopyInsertion and check if the graph constructed above doesn't need
2308 // any copies inserted for BufferAssignment to run.
2309 int64_t instruction_count = m->instruction_count();
2310 CopyInsertion copy_insertion;
2311 ASSERT_IS_OK(copy_insertion.Run(m.get()).status());
2312 ASSERT_EQ(instruction_count, m->instruction_count());
2313
2314 // Get the instructions in the module.
2315 const HloInstruction* bcast = m->entry_computation()->root_instruction();
2316 const HloInstruction* constant =
2317 m->entry_computation()->GetInstructionWithName("constant.42");
2318 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
2319 const HloInstruction* while1 = bcast->operand(0);
2320 ASSERT_EQ(while1->opcode(), HloOpcode::kWhile);
2321 const HloInstruction* while0 = while1->operand(0)->operand(0);
2322 ASSERT_EQ(while0->opcode(), HloOpcode::kWhile);
2323
2324 // Run buffer assignment.
2325 auto assignment = RunBufferAssignment(m.get());
2326 TF_ASSERT_OK_AND_ASSIGN(auto slice_constant,
2327 assignment->GetUniqueSlice(constant, {}));
2328 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2329 assignment->GetUniqueSlice(while0, {}));
2330 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2331 assignment->GetUniqueSlice(while1, {}));
2332
2333 // The constant slice is part of while0's colocation set (as the init
2334 // value), but is not merged into while1's colocation set.
2335 EXPECT_EQ(slice_constant, slice_while0);
2336 EXPECT_NE(slice_constant, slice_while1);
2337 }
2338
2339 // Tests that the colocated buffers for while instructions are properly assigned
2340 // during buffer assignment such that the result tuple elements are not assigned
2341 // to the same buffer.
2342 //
2343 // %infeed --> %while.0 --> %while.1 --+
2344 // +-- %tuple
2345 // %zero --> %add --> %while.2 --+
2346 //
2347 // Execution Order:
2348 // %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple
2349 //
2350 // The HLO computation used in this test requires specific ordering to expose
2351 // the bug (b/72496031). During buffer assignment, the visitation order of
2352 // colocated buffers is %while.2 -> %while.0 -> %while.1, and the buffer
2353 // assignment was coalescing the colocated buffers for all 3 while instructions,
2354 // thereby assigning the same buffer to the two result tuple elements.
2355 TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
2356 const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
2357
2358 // Builds a condition computation: x -> x < 4
2359 auto build_cond = [&]() {
2360 auto builder = HloComputation::Builder("cond");
2361 auto const4 = builder.AddInstruction(
2362 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
2363 auto param =
2364 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2365 builder.AddInstruction(
2366 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param,
2367 const4, ComparisonDirection::kLt));
2368 return builder.Build();
2369 };
2370
2371 // Builds a body computation: x -> x + 9
2372 auto build_body = [&]() {
2373 auto builder = HloComputation::Builder("body");
2374 auto const9 = builder.AddInstruction(
2375 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
2376 auto param =
2377 builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
2378 builder.AddInstruction(
2379 HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9));
2380 return builder.Build();
2381 };
2382
2383 // Build the entry computation as described in the comment above.
2384 auto module = CreateNewVerifiedModule();
2385 auto builder = HloComputation::Builder("entry");
2386
2387 auto token = builder.AddInstruction(HloInstruction::CreateToken());
2388 auto infeed =
2389 builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
2390 auto infeed_data = builder.AddInstruction(
2391 HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
2392 auto cond0 = module->AddEmbeddedComputation(build_cond());
2393 auto body0 = module->AddEmbeddedComputation(build_body());
2394 auto while0 = builder.AddInstruction(
2395 HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));
2396
2397 auto cond1 = module->AddEmbeddedComputation(build_cond());
2398 auto body1 = module->AddEmbeddedComputation(build_body());
2399 auto while1 = builder.AddInstruction(
2400 HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
2401
2402 auto zero = builder.AddInstruction(
2403 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
2404 auto add = builder.AddInstruction(
2405 HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
2406 auto cond2 = module->AddEmbeddedComputation(build_cond());
2407 auto body2 = module->AddEmbeddedComputation(build_body());
2408 auto while2 = builder.AddInstruction(
2409 HloInstruction::CreateWhile(r0s32, cond2, body2, add));
2410
2411 auto tuple =
2412 builder.AddInstruction(HloInstruction::CreateTuple({while2, while1}));
2413 module->AddEntryComputation(builder.Build());
2414
2415 // Run CopyInsertion and verify that the graph constructed above needs no
2416 // copies inserted for BufferAssignment to run.
2417 int64_t instruction_count = module->instruction_count();
2418 CopyInsertion copy_insertion;
2419 ASSERT_IS_OK(copy_insertion.Run(module.get()).status());
2420 ASSERT_EQ(instruction_count, module->instruction_count());
2421
2422 // Create a sequential order among all the instructions in the entry
2423 // computation, since the issue this test stresses depends on the order in
2424 // which nodes are traversed during BufferAssignment.
2425 TF_ASSERT_OK_AND_ASSIGN(
2426 HloSchedule schedule,
2427 ScheduleModule(module.get(), [](const BufferValue& buffer) {
2428 return ShapeUtil::ByteSizeOf(buffer.shape(),
2429 /*pointer_size=*/sizeof(void*));
2430 }));
2431 schedule.set_sequence(
2432 module->entry_computation(),
2433 {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
2434 TF_ASSERT_OK(schedule.Verify());
2435
2436 TF_ASSERT_OK_AND_ASSIGN(
2437 auto assignment,
2438 BufferAssigner::Run(
2439 module.get(), std::make_unique<SequentialHloOrdering>(schedule),
2440 backend().compiler()->BufferSizeBytesFunction(),
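// Trivial color-alignment function: assume 1-byte alignment for every buffer color.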
2441 [](LogicalBuffer::Color) { return 1; },
2442 /*allocate_buffers_for_constants=*/true));
2443
2444 // The result tuple elements must be assigned with different buffers.
2445 TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
2446 TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1}));
2447 EXPECT_NE(slice0, slice1);
2448
2449 // while0 and while1 result buffers must be equal to slice1.
2450 TF_ASSERT_OK_AND_ASSIGN(auto slice_while0,
2451 assignment->GetUniqueSlice(while0, {}));
2452 TF_ASSERT_OK_AND_ASSIGN(auto slice_while1,
2453 assignment->GetUniqueSlice(while1, {}));
2454 EXPECT_EQ(slice1, slice_while0);
2455 EXPECT_EQ(slice1, slice_while1);
2456
2457 // while2 result buffer must be equal to slice0.
2458 TF_ASSERT_OK_AND_ASSIGN(auto slice_while2,
2459 assignment->GetUniqueSlice(while2, {}));
2460 EXPECT_EQ(slice0, slice_while2);
2461 }
2462
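// Tests that two chained while loops with the same loop-state shape reuse buffers:
// each tuple element of while0's result shares a slice with the corresponding
// element of while1's result.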
2463 TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
2464 auto module = CreateNewVerifiedModule();
2465 auto builder = HloComputation::Builder("entry");
2466
2467 auto input0 = builder.AddInstruction(
2468 HloInstruction::CreateParameter(0, data_shape_, "input0"));
2469 auto weights0 = builder.AddInstruction(
2470 HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2471
2472 auto zero = builder.AddInstruction(
2473 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2474 auto output0 = builder.AddInstruction(
2475 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2476
2477 auto cond0 =
2478 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2479 auto body0 =
2480 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2481
2482 auto tuple0 = builder.AddInstruction(
2483 HloInstruction::CreateTuple({input0, weights0, output0}));
2484 auto while0 = builder.AddInstruction(
2485 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2486
2487 auto cond1 =
2488 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2489 auto body1 =
2490 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2491
2492 auto while1 = builder.AddInstruction(
2493 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
2494
2495 module->AddEntryComputation(builder.Build());
2496 RunCopyInsertion(module.get());
2497 auto assignment = RunBufferAssignment(module.get());
2498
2499 // Each tuple element of while0 and while1 should be assigned the same buffer slice.
2500 EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).value(),
2501 assignment->GetUniqueSlice(while1, {0}).value());
2502 EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).value(),
2503 assignment->GetUniqueSlice(while1, {1}).value());
2504 EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).value(),
2505 assignment->GetUniqueSlice(while1, {2}).value());
2506 }
2507
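// Tests that two call sites of the same computation are assigned distinct result
// buffers once the call graph has been flattened.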
2508 TEST_F(BufferAssignmentTest, TwoCalls) {
2509 auto module = CreateNewVerifiedModule();
2510 Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
2511 HloComputation* sub_computation;
2512 {
2513 auto builder = HloComputation::Builder(TestName() + "_sub_comp");
2514 auto param = builder.AddInstruction(
2515 HloInstruction::CreateParameter(0, r0f32, "param"));
2516 auto constant1 = builder.AddInstruction(
2517 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2518 auto add = builder.AddInstruction(
2519 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
2520 sub_computation = module->AddEmbeddedComputation(builder.Build(add));
2521 }
2522 auto builder = HloComputation::Builder(TestName());
2523 auto constant2 = builder.AddInstruction(
2524 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
2525 auto constant3 = builder.AddInstruction(
2526 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
2527 auto call1 = builder.AddInstruction(
2528 HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
2529 auto call2 = builder.AddInstruction(
2530 HloInstruction::CreateCall(r0f32, {constant3}, sub_computation));
2531 auto add1 = builder.AddInstruction(
2532 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2));
2533 auto add2 = builder.AddInstruction(
2534 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1));
2535 module->AddEntryComputation(builder.Build(add2));
2536
2537 {
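// Flatten the call graph so each call site gets its own clone of sub_computation,
// letting buffer assignment treat the two calls independently.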
2538 FlattenCallGraph flatten;
2539 TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
2540 EXPECT_TRUE(result);
2541 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
2542 }
2543
2544 RunCopyInsertion(module.get());
2545 auto assignment = RunBufferAssignment(module.get());
2546
2547 EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment));
2548 }
2549
2550 TEST_F(BufferAssignmentTest, CallParamCoAllocation) {
2551 const char* hlo_text = R"(
2552 HloModule CallParamCoAllocation
2553
2554 Callee {
2555 param0 = (f32[100],(f32[200],f32[300])) parameter(0)
2556 param1 = s32[20] parameter(1)
2557 ROOT constant = f32[] constant(1)
2558 }
2559
2560 ENTRY Main {
2561 entry_param0 = f32[100] parameter(0)
2562 entry_param1 = s32[20] parameter(1)
2563 custom_call = (f32[200],f32[300]) custom-call(), custom_call_target="call-target"
2564 call_op0 = (f32[100],(f32[200],f32[300])) tuple(entry_param0, custom_call)
2565 ROOT call_result = f32[] call(call_op0, entry_param1), to_apply=Callee
2566 }
2567 )";
2568
2569 HloModuleConfig config;
2570 config.set_debug_options(GetDebugOptionsFromFlags());
2571 TF_ASSERT_OK_AND_ASSIGN(auto m,
2572 ParseAndReturnVerifiedModule(hlo_text, config));
2573
2574 auto buffers = RunBufferAssignment(m.get());
2575
2576 HloComputation* main = m->entry_computation();
2577 HloComputation* callee = m->GetComputationWithName("Callee");
2578 EXPECT_NE(callee, nullptr);
2579
2580 HloInstruction* param0 = callee->parameter_instruction(0);
2581 HloInstruction* param1 = callee->parameter_instruction(1);
2582
2583 HloInstruction* entry_param0 = main->parameter_instruction(0);
2584 HloInstruction* entry_param1 = main->parameter_instruction(1);
2585 HloInstruction* custom_call = main->GetInstructionWithName("custom_call");
2586
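// Each callee parameter must be co-allocated with the corresponding call operand
// in the entry computation, including the nested tuple elements.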
2587 EXPECT_EQ(GetAllocation(*buffers, entry_param0, {}),
2588 GetAllocation(*buffers, param0, {0}));
2589 EXPECT_EQ(GetAllocation(*buffers, entry_param1, {}),
2590 GetAllocation(*buffers, param1, {}));
2591
2592 EXPECT_EQ(GetAllocation(*buffers, custom_call, {}),
2593 GetAllocation(*buffers, param0, {1}));
2594 EXPECT_EQ(GetAllocation(*buffers, custom_call, {0}),
2595 GetAllocation(*buffers, param0, {1, 0}));
2596 EXPECT_EQ(GetAllocation(*buffers, custom_call, {1}),
2597 GetAllocation(*buffers, param0, {1, 1}));
2598 }
2599
2600 TEST_F(BufferAssignmentTest, AsyncCall) {
2601 const char* hlo_text = R"(
2602 HloModule AsyncCall, is_scheduled=true
2603
2604 %called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] {
2605 %param_0 = f32[4096]{0} parameter(0)
2606 %param_1 = f32[4096]{0} parameter(1)
2607 %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0)
2608 %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1)
2609 %negate_2 = f32[4096]{0} negate(f32[4096]{0} %negate_1)
2610 %negate_3 = f32[4096]{0} negate(f32[4096]{0} %negate_2)
2611 ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_3)
2612 }
2613
2614 ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
2615 %a = f32[4096]{0} parameter(0)
2616 %b = f32[4096]{0} parameter(1)
2617 %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-start(f32[4096]{0} %a, f32[4096]{0} %b), to_apply=%called_computation
2618 %negate_4 = f32[4096]{0} negate(f32[4096]{0} %a)
2619 %negate_5 = f32[4096]{0} negate(f32[4096]{0} %b)
2620 %negate_6 = f32[4096]{0} negate(f32[4096]{0} %negate_5)
2621 %negate_7 = f32[4096]{0} negate(f32[4096]{0} %negate_6)
2622 %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_4, f32[4096]{0} %negate_7)
2623 %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), to_apply=%called_computation
2624 ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done)
2625 }
2626 )";
2627
2628 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
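// Canonicalize the async call-start/call-done pair and clean up with DCE before
// running buffer assignment with a sequential ordering.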
2629 AsyncOpCanonicalizer async_op_canonicalizer;
2630 EXPECT_TRUE(async_op_canonicalizer.Run(m.get()).ok());
2631 HloDCE dce;
2632 EXPECT_TRUE(dce.Run(m.get()).ok());
2633
2634 auto buffers = RunBufferAssignmentWithSequentialOrdering(m.get());
2635
2636 LOG(INFO) << buffers->ToString();
2637
2638 auto get_slice = [&](absl::string_view hlo_name, const ShapeIndex& index) {
2639 return buffers->GetUniqueSlice(FindInstruction(m.get(), hlo_name), index)
2640 .value();
2641 };
2642
2643 // Make sure the parameters and root of the async-called computation have the
2644 // same slices as the corresponding async call operands/output.
2645 EXPECT_EQ(get_slice("param_0", {}), get_slice("a", {}));
2646 EXPECT_EQ(get_slice("param_1", {}), get_slice("b", {}));
2647 EXPECT_EQ(get_slice("result.1", {}), get_slice("async-done", {}));
2648
2649 // Make sure the intermediate values in the async-called computation have
2650 // different slices than the values whose live ranges overlap with them.
2651 for (const auto& hlo_name :
2652 {"negate_0", "negate_1", "negate_2", "negate_3"}) {
2653 EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_4", {}));
2654 EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_5", {}));
2655 EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_6", {}));
2656 EXPECT_NE(get_slice(hlo_name, {}), get_slice("negate_7", {}));
2657 EXPECT_NE(get_slice(hlo_name, {}), get_slice("add_0", {}));
2658 }
2659 }
2660
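// Tests the CSV report produced by BufferAssignment::BufferInfoString() on a
// small elementwise graph.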
2661 TEST_F(BufferAssignmentTest, BufferInfoStringTest) {
2662 absl::string_view module_str = R"(
2663 HloModule test_module
2664
2665 ENTRY %test_module {
2666 %param.0 = s32[1024]{0} parameter(0)
2667 %param.1 = s32[1024]{0} parameter(1)
2668 %mul = s32[1024]{0} multiply(%param.0, %param.1)
2669 %add = s32[1024]{0} add(%mul, %param.0)
2670 ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[1024] %add), dimensions={0}
2671 })";
2672
2673 absl::string_view reference_str =
2674 R"(buffer_id,buffer_name,offset,size,definition_time,end_time,num_uses,use_times,use_names
2675 0,"<0 param.0 @0>",0,4096,0,5,2,"2;3","mul, operand 0;add, operand 1"
2676 1,"<1 param.1 @0>",0,4096,1,5,1,"2","mul, operand 1"
2677 2,"<2 mul @0>",0,4096,2,3,1,"3","add, operand 0"
2678 3,"<3 add @0>",0,4096,3,4,1,"4","bcast, operand 0"
2679 4,"<4 bcast @0>",0,4194304,4,5,0,"",""
2680 )";
2681
2682 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
2683 HloInstruction* const param0 = FindInstruction(m.get(), "param.0");
2684 HloInstruction* const param1 = FindInstruction(m.get(), "param.1");
2685 HloInstruction* const mul = FindInstruction(m.get(), "mul");
2686 HloInstruction* const add = FindInstruction(m.get(), "add");
2687 HloInstruction* const bcast = FindInstruction(m.get(), "bcast");
2688 // Run buffer assignment.
2689 auto assignment = RunBufferAssignmentWithInstructionSequence(
2690 m.get(), {param0, param1, mul, add, bcast});
2691 const std::string buffer_info_str = assignment->BufferInfoString();
2692
2693 EXPECT_EQ(buffer_info_str, reference_str);
2694 }
2695
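// Tests that two independent while loops feeding the root add get distinct
// buffers under a manually crafted schedule (regression test for b/38494731).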
2696 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
2697 auto module = CreateNewVerifiedModule();
2698 auto builder = HloComputation::Builder(TestName());
2699
2700 auto zero = builder.AddInstruction(
2701 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2702 auto one = builder.AddInstruction(
2703 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2704
2705 auto input0 = builder.AddInstruction(
2706 HloInstruction::CreateParameter(0, data_shape_, "input0"));
2707 auto weights0 = builder.AddInstruction(
2708 HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2709 auto output0 = builder.AddInstruction(
2710 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2711
2712 auto input1 = builder.AddInstruction(
2713 HloInstruction::CreateParameter(2, data_shape_, "input1"));
2714 auto weights1 = builder.AddInstruction(
2715 HloInstruction::CreateParameter(3, data_shape_, "weights1"));
2716 auto output1 = builder.AddInstruction(
2717 HloInstruction::CreateBroadcast(data_shape_, one, {}));
2718
2719 auto cond =
2720 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2721 auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2722
2723 auto tuple0 = builder.AddInstruction(
2724 HloInstruction::CreateTuple({input0, weights0, output0}));
2725 auto tuple1 = builder.AddInstruction(
2726 HloInstruction::CreateTuple({input1, weights1, output1}));
2727
2728 auto while0 = builder.AddInstruction(
2729 HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0));
2730 auto while1 = builder.AddInstruction(
2731 HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
2732
2733 auto gte0 = builder.AddInstruction(
2734 HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
2735 auto gte1 = builder.AddInstruction(
2736 HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
2737 auto root_add = builder.AddInstruction(
2738 HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1));
2739
2740 module->AddEntryComputation(builder.Build());
2741
2742 {
2743 FlattenCallGraph flatten;
2744 TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
2745 EXPECT_TRUE(result);
2746 }
2747
2748 RunCopyInsertion(module.get());
2749
2750 HloSchedule schedule = ScheduleModule(module.get(), ByteSizeOf).value();
2751
2752 // To trigger b/38494731, we want a specific HLO schedule for the entry
2753 // computation, so we overwrite its schedule entry with a manually
2754 // crafted sequence.
2755 schedule.set_sequence(
2756 module->entry_computation(),
2757 {input1, weights1, one, output1, while1->mutable_operand(0), while1,
2758 input0, weights0, zero, output0, while0->mutable_operand(0), while0,
2759 gte0, gte1, root_add});
2760
2761 // If this ASSERT fails, we constructed a bogus sequence above and this test
2762 // itself is buggy.
2763 TF_ASSERT_OK(schedule.Verify());
2764
2765 auto assignment =
2766 BufferAssigner::Run(
2767 module.get(), std::make_unique<SequentialHloOrdering>(schedule),
2768 ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
2769 /*allocate_buffers_for_constants=*/true)
2770 .value();
2771
2772 EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
2773 }
2774
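// Tests that a while result that is live out of the entry computation is not
// assigned to an entry-parameter allocation.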
2775 TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
2776 auto module = CreateNewVerifiedModule();
2777 auto builder = HloComputation::Builder("entry");
2778
2779 auto input0 = builder.AddInstruction(
2780 HloInstruction::CreateParameter(0, data_shape_, "input0"));
2781 auto weights0 = builder.AddInstruction(
2782 HloInstruction::CreateParameter(1, data_shape_, "weights0"));
2783
2784 auto zero = builder.AddInstruction(
2785 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
2786 auto output0 = builder.AddInstruction(
2787 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2788 auto output1 = builder.AddInstruction(
2789 HloInstruction::CreateBroadcast(data_shape_, zero, {}));
2790
2791 auto cond0 =
2792 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2793 auto body0 =
2794 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2795
2796 auto tuple0 = builder.AddInstruction(
2797 HloInstruction::CreateTuple({input0, weights0, output0}));
2798 auto while0 = builder.AddInstruction(
2799 HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
2800
2801 // Get the output of 'while0' and feed it as input to 'while1'.
2802 auto while0_out = builder.AddInstruction(
2803 HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
2804
2805 auto cond1 =
2806 module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
2807 auto body1 =
2808 module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
2809
2810 auto tuple1 = builder.AddInstruction(
2811 HloInstruction::CreateTuple({while0_out, weights0, output1}));
2812 auto while1 = builder.AddInstruction(
2813 HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
2814
2815 // Get the output of 'while1' so that it is live out of the computation.
2816 auto while1_out = builder.AddInstruction(
2817 HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
2818
2819 module->AddEntryComputation(builder.Build());
2820 RunCopyInsertion(module.get());
2821 auto assignment = RunBufferAssignment(module.get());
2822 // Get BufferAllocation for root instruction.
2823 auto* root_alloc =
2824 assignment->GetUniqueTopLevelSlice(while1_out).value().allocation();
2825 // Test that root instruction allocation is live out.
2826 EXPECT_TRUE(root_alloc->maybe_live_out());
2827 // Test that root instruction allocation is not an entry parameter.
2828 EXPECT_FALSE(root_alloc->is_entry_computation_parameter());
2829 }
2830
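// Tests that chained in-place dynamic-update-slice instructions inside a while
// body share a single allocation slice.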
2831 TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) {
2832 const char* const hlo_string = R"(
2833 HloModule test
2834
2835 while_body {
2836 state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
2837 constant.1 = f32[] constant(0)
2838 broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
2839 get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
2840 get-tuple-element.3 = s32[] get-tuple-element(state), index=0
2841 constant.2 = s32[] constant(128)
2842 add.5 = s32[] add(get-tuple-element.3, constant.2)
2843 constant.3 = s32[] constant(0)
2844 dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
2845 dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
2846 ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
2847 }
2848
2849 while_condition {
2850 state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
2851 get-tuple-element = s32[] get-tuple-element(state), index=0
2852 get-tuple-element.1 = s32[] constant(3)
2853 ROOT less-than.339.338 = pred[] compare(get-tuple-element, get-tuple-element.1), direction=LT
2854 }
2855
2856 ENTRY entry_computation {
2857 constant.7 = s32[] constant(0)
2858 copy.1 = s32[] copy(constant.7)
2859 constant.6 = f32[] constant(0)
2860 broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={}
2861 tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6)
2862 while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body
2863 ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0
2864 }
2865
2866 )";
2867 auto module = ParseAndReturnVerifiedModule(hlo_string).value();
2868
2869 RunCopyInsertion(module.get());
2870 auto assignment = RunBufferAssignment(module.get());
2871 // Get the allocation slices for the two dynamic-update-slice instructions.
2872 auto dus9 = FindInstruction(module.get(), "dynamic-update-slice.9");
2873 auto dus9_alloc_slice = assignment->GetUniqueTopLevelSlice(dus9).value();
2874 auto dus5 = FindInstruction(module.get(), "dynamic-update-slice.5");
2875 auto dus5_alloc_slice = assignment->GetUniqueTopLevelSlice(dus5).value();
2876 // Test that the two dynamic-update-slice ops share the same allocation slice.
2877 EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation());
2878 EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice);
2879 }
2880 } // namespace
2881 } // namespace xla
2882