1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/instruction_hoister.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22
23 namespace xla {
24 namespace {
25
namespace op = xla::testing::opcode_matchers;
using memory_space_assignment::AsynchronousCopy;
using memory_space_assignment::AsynchronousCopyResource;
using memory_space_assignment::CostAnalysisPrefetchIntervalPicker;
using memory_space_assignment::InstructionCountPrefetchIntervalPicker;
using memory_space_assignment::MemorySpaceAssignment;
using memory_space_assignment::MemorySpaceAssignmentCostAnalysis;
using memory_space_assignment::Options;
using memory_space_assignment::PrefetchIntervalPicker;
using memory_space_assignment::PresetAssignments;

// Pointer size (in bytes) passed to ShapeUtil::ByteSizeOf when computing
// shape sizes in these tests.
constexpr int64_t kPointerSize = 8;
// Bandwidth and throughput constants used to configure the cost analysis in
// the tests below. Units follow the corresponding Options /
// HloCostAnalysis::Options fields they are assigned to.
constexpr float kAsyncCopyBandwidth = 100;
constexpr float kAlternateMemBandwidth = 1000;
constexpr float kBytesPerSecond = 100;
constexpr float kFlopsPerSecond = 1000;
constexpr float kTranscendentalsPerSecond = 10;
43
// Computes the byte size of `shape` via ShapeUtil::ByteSizeOf, using
// kPointerSize as the pointer size.
int64_t ShapeSize(const Shape& shape) {
  return ShapeUtil::ByteSizeOf(shape, kPointerSize);
}
47
48 class MemorySpaceAssignmentTest : public HloTestBase,
49 public ::testing::WithParamInterface<bool> {
50 protected:
51 // We use the following two memory space values to describe the default (slow
52 // and large) and alternate (fast and small) memory spaces.
53 const int64_t kDefaultMemorySpace = 0;
54 const int64_t kAlternateMemorySpace = 1;
55
AssignMemorySpaceUsingCostAnalysis(HloModule * module,std::optional<Options> memory_space_assignment_options=std::nullopt)56 std::unique_ptr<PresetAssignments> AssignMemorySpaceUsingCostAnalysis(
57 HloModule* module,
58 std::optional<Options> memory_space_assignment_options = std::nullopt) {
59 HloCostAnalysis::Options cost_options{ShapeSize};
60 cost_options.set_flops_per_second(kFlopsPerSecond);
61 cost_options.set_bytes_per_second(kBytesPerSecond);
62 cost_options.set_transcendentals_per_second(kTranscendentalsPerSecond);
63 HloCostAnalysis hlo_cost_analysis(cost_options);
64 for (HloComputation* computation : module->MakeNonfusionComputations()) {
65 TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
66 }
67 auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
68
69 Options options;
70 if (memory_space_assignment_options.has_value()) {
71 options = *memory_space_assignment_options;
72 } else {
73 options.async_copy_bandwidth_bytes_per_second = kAsyncCopyBandwidth;
74 options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth;
75 }
76 auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create(
77 hlo_cost_analysis, options, *module)
78 .ValueOrDie();
79 CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
80 CostAnalysisPrefetchIntervalPicker(
81 *cost_analysis, /*min_overlap_to_async_copy_ratio=*/0.8,
82 /*preferred_overlap_to_async_copy_ratio=*/1.5,
83 /*max_overlap_to_mem_size_async_copy_ratio=*/10.0,
84 /*mem_size_bytes=*/128));
85 return AssignMemorySpace(
86 module, /*max_outstanding_async_copies=*/-1,
87 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
88 *cost_analysis, &cache_),
89 &prefetch_interval_picker, memory_space_assignment_options);
90 }
91
AssignMemorySpace(HloModule * module,int64_t max_outstanding_async_copies=-1,int64_t max_prefetch_interval=10,int64_t min_prefetch_interval=2,std::optional<Options> options=std::nullopt)92 std::unique_ptr<PresetAssignments> AssignMemorySpace(
93 HloModule* module, int64_t max_outstanding_async_copies = -1,
94 int64_t max_prefetch_interval = 10, int64_t min_prefetch_interval = 2,
95 std::optional<Options> options = std::nullopt) {
96 InstructionHoister instruction_hoister;
97 TF_CHECK_OK(instruction_hoister.Run(module).status());
98 InstructionCountPrefetchIntervalPicker prefetch_interval_picker(
99 min_prefetch_interval, max_prefetch_interval);
100 return AssignMemorySpace(module, max_outstanding_async_copies,
101 /*buffer_interval_compare=*/{},
102 &prefetch_interval_picker, options);
103 }
104
AssignMemorySpace(HloModule * module,int64_t max_outstanding_async_copies,std::optional<MemorySpaceAssignment::BufferIntervalCompare> buffer_interval_compare,PrefetchIntervalPicker * prefetch_interval_picker,std::optional<Options> memory_space_assignment_options=std::nullopt)105 std::unique_ptr<PresetAssignments> AssignMemorySpace(
106 HloModule* module, int64_t max_outstanding_async_copies,
107 std::optional<MemorySpaceAssignment::BufferIntervalCompare>
108 buffer_interval_compare,
109 PrefetchIntervalPicker* prefetch_interval_picker,
110 std::optional<Options> memory_space_assignment_options = std::nullopt) {
111 auto size_fn = [](const BufferValue& buffer) {
112 return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
113 };
114
115 auto is_allowed_in_alternate_mem = [](const HloValue& value) {
116 // Check if the value belongs to the entry computation.
117 HloInstruction* instruction = value.instruction();
118 HloComputation* computation = instruction->parent();
119 bool in_entry_computation =
120 (computation == computation->parent()->entry_computation());
121 if (in_entry_computation &&
122 instruction->opcode() == HloOpcode::kParameter) {
123 return false;
124 }
125 return true;
126 };
127
128 // Only check parameters in default memory if the original module didn't
129 // have the parameters in alternate memory.
130 bool check_parameters_in_default_memory = true;
131 for (const HloInstruction* parameter :
132 module->entry_computation()->parameter_instructions()) {
133 ShapeUtil::ForEachSubshape(
134 parameter->shape(),
135 [&](const Shape& subshape, const ShapeIndex& /*index*/) {
136 if (subshape.has_layout() &&
137 subshape.layout().memory_space() == kAlternateMemorySpace) {
138 check_parameters_in_default_memory = false;
139 }
140 });
141 }
142
143 Options options;
144 if (memory_space_assignment_options) {
145 options = *memory_space_assignment_options;
146 } else {
147 options.max_size_in_bytes = 128;
148 options.alignment_in_bytes = 8;
149 options.verify = true;
150 }
151
152 options.alternate_memory_space = kAlternateMemorySpace;
153 options.buffer_interval_compare = buffer_interval_compare;
154 options.prefetch_interval_picker = prefetch_interval_picker;
155 options.size_fn = size_fn;
156 if (options.is_allowed_in_alternate_mem_fn == nullptr) {
157 options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
158 }
159 options.max_outstanding_prefetches = max_outstanding_async_copies;
160 options.max_outstanding_evictions = max_outstanding_async_copies;
161 options.allocate_across_sequential_calls = GetParam();
162
163 auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
164 std::unique_ptr<HloLiveRange> hlo_live_range =
165 HloLiveRange::Run(module->schedule(), *alias_analysis,
166 module->entry_computation())
167 .ValueOrDie();
168
169 std::unique_ptr<PresetAssignments> preset_assignments =
170 MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
171 options)
172 .ValueOrDie();
173 if (check_parameters_in_default_memory) {
174 CheckParametersInDefaultMemory(module);
175 }
176 CheckRootInDefaultMemory(module);
177 CheckPresetAssignments(preset_assignments.get());
178 return preset_assignments;
179 }
180
CheckPresetAssignments(const PresetAssignments * preset_assignments)181 void CheckPresetAssignments(const PresetAssignments* preset_assignments) {
182 // Ensure that the exported preset assignments point to layouts in the
183 // alternate memory. Also ensure that the positions are unique. Note that
184 // we're using a std::set instead of absl::flat_hash_set because we can make
185 // use of HloPosition's comparator logic instead of providing a hasher.
186 std::set<HloPosition> positions_in_preset_assignments;
187 for (auto& position_and_chunk : preset_assignments->chunks()) {
188 HloPosition position = position_and_chunk.first;
189 EXPECT_EQ(positions_in_preset_assignments.find(position),
190 positions_in_preset_assignments.end());
191 positions_in_preset_assignments.insert(position);
192 const Shape& subshape =
193 ShapeUtil::GetSubshape(position.instruction->shape(), position.index);
194 EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace)
195 << "Exported position is not in alternate mem: "
196 << position.ToString();
197 }
198 }
199
CheckParametersInDefaultMemory(const HloModule * module)200 void CheckParametersInDefaultMemory(const HloModule* module) {
201 // Check that all the entry parameter subshapes are placed in default
202 // memory.
203 const HloComputation* entry_computation = module->entry_computation();
204 for (const HloInstruction* parameter :
205 entry_computation->parameter_instructions()) {
206 ShapeUtil::ForEachSubshape(
207 parameter->shape(),
208 [&](const Shape& subshape, const ShapeIndex& /*index*/) {
209 if (subshape.has_layout()) {
210 EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace)
211 << "Parameter not in default memory: "
212 << parameter->ToString();
213 }
214 });
215 }
216 }
217
CheckRootInDefaultMemory(const HloModule * module)218 void CheckRootInDefaultMemory(const HloModule* module) {
219 const HloInstruction* root =
220 module->entry_computation()->root_instruction();
221 if (root->shape().IsArray()) {
222 EXPECT_EQ(root->shape().layout().memory_space(), kDefaultMemorySpace);
223 }
224 }
225
226 struct OutstandingAsyncCopies {
227 int64_t max_copies;
228 int64_t max_prefetches;
229 int64_t max_evictions;
230 };
231
CountMaximumOutstandingAsyncCopies(const HloModule & module)232 /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies(
233 const HloModule& module) {
234 OutstandingAsyncCopies copies{0, 0, 0};
235 int64_t current_copies = 0;
236 int64_t current_prefetches = 0;
237 int64_t current_evictions = 0;
238 for (HloInstruction* instruction : module.schedule()
239 .sequence(module.entry_computation())
240 .instructions()) {
241 if (instruction->opcode() == HloOpcode::kCopyStart) {
242 current_copies++;
243 if (ShapeUtil::GetSubshape(instruction->shape(), {0})
244 .layout()
245 .memory_space() == kAlternateMemorySpace) {
246 current_prefetches++;
247 } else {
248 current_evictions++;
249 }
250 } else if (instruction->opcode() == HloOpcode::kCopyDone) {
251 current_copies--;
252 if (instruction->shape().layout().memory_space() ==
253 kAlternateMemorySpace) {
254 current_prefetches--;
255 } else {
256 current_evictions--;
257 }
258 }
259 copies.max_copies = std::max(copies.max_copies, current_copies);
260 copies.max_prefetches =
261 std::max(copies.max_prefetches, current_prefetches);
262 copies.max_prefetches = std::max(copies.max_evictions, current_evictions);
263 }
264 return copies;
265 }
266
GetAlternateMemoryOffset(const PresetAssignments & preset_assignments,const HloInstruction * instruction,const ShapeIndex & index={}) const267 int64_t GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
268 const HloInstruction* instruction,
269 const ShapeIndex& index = {}) const {
270 // Returns the offset of the assignment, -1 if it's not in the alternate
271 // memory.
272 const HloModule* module = instruction->parent()->parent();
273 auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
274 HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
275 for (auto& pos_and_chunk : preset_assignments.chunks()) {
276 for (auto& value : buffer.values()) {
277 if (pos_and_chunk.first == value->defining_position()) {
278 return pos_and_chunk.second.offset;
279 }
280 }
281 }
282 return -1;
283 }
284
CreateEvictAndPrefetchModule()285 std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
286 HloComputation::Builder builder(TestName());
287 Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
288 HloInstruction* p0 =
289 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
290 HloInstruction* p1 =
291 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
292 HloInstruction* tanh = builder.AddInstruction(
293 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
294 // tanh should be placed in the alternate memory since there isn't much
295 // contention in the beginning. However, tanh has another consumer at the
296 // end. So it should be kicked out to default memory and prefetched back in.
297 // The graph below is meant to increase the contention to force
298 // eviction/prefetch behavior.
299 HloInstruction* a = builder.AddInstruction(
300 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
301 HloInstruction* b = builder.AddInstruction(
302 HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
303 HloInstruction* c = builder.AddInstruction(
304 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
305 HloInstruction* d = builder.AddInstruction(
306 HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
307 HloInstruction* e = builder.AddInstruction(
308 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
309 HloInstruction* f = builder.AddInstruction(
310 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
311 HloInstruction* g = builder.AddInstruction(
312 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
313 HloInstruction* h = builder.AddInstruction(
314 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
315 HloInstruction* i = builder.AddInstruction(
316 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
317 HloInstruction* j = builder.AddInstruction(
318 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
319 HloInstruction* k = builder.AddInstruction(
320 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
321 HloInstruction* l = builder.AddInstruction(
322 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
323 HloInstruction* m = builder.AddInstruction(
324 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
325 HloInstruction* n = builder.AddInstruction(
326 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
327 HloInstruction* o = builder.AddInstruction(
328 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
329 // tanh is being used at the root instruction, and this should be
330 // prefetched.
331 HloInstruction* add = builder.AddInstruction(
332 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
333
334 auto module = CreateNewVerifiedModule();
335 HloComputation* computation = module->AddEntryComputation(builder.Build());
336
337 HloSchedule schedule(module.get());
338 schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
339 j, k, l, m, n, o, add});
340 TF_CHECK_OK(module->set_schedule(schedule));
341 return module;
342 }
343
344 MemorySpaceAssignmentCostAnalysis::Cache cache_;
345 };
346
347 // For testing purposes, we define a cost analysis where we can control the
348 // elapsed times of each HLO and asynchronous copy.
349 class FakeMemorySpaceAssignmentCostAnalysis
350 : public MemorySpaceAssignmentCostAnalysis {
351 public:
352 static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
Create(const HloCostAnalysis & cost_analysis,const HloModule & module,const Options & options)353 Create(const HloCostAnalysis& cost_analysis, const HloModule& module,
354 const Options& options) {
355 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
356 TF_ASSIGN_OR_RETURN(auto hlo_live_range,
357 HloLiveRange::Run(module.schedule(), *alias_analysis,
358 module.entry_computation()));
359 auto call_graph = CallGraph::Build(&module);
360 return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
361 cost_analysis, options, std::move(alias_analysis),
362 std::move(hlo_live_range), std::move(call_graph)));
363 }
364
GetInstructionElapsed(const HloInstruction & instruction) const365 float GetInstructionElapsed(
366 const HloInstruction& instruction) const override {
367 if (get_instruction_elapsed_override_) {
368 return get_instruction_elapsed_override_(instruction);
369 }
370 return 1.0;
371 }
372
GetInstructionElapsedInAlternateMemory(const HloInstruction & instruction,absl::Span<const std::pair<int64_t,ShapeIndex>> operands_in_alternate_mem,absl::Span<const ShapeIndex> outputs_in_alternate_mem) const373 float GetInstructionElapsedInAlternateMemory(
374 const HloInstruction& instruction,
375 absl::Span<const std::pair<int64_t, ShapeIndex>>
376 operands_in_alternate_mem,
377 absl::Span<const ShapeIndex> outputs_in_alternate_mem) const override {
378 if (get_instruction_elapsed_in_alternate_memory_override_) {
379 return get_instruction_elapsed_in_alternate_memory_override_(
380 instruction, operands_in_alternate_mem, outputs_in_alternate_mem);
381 }
382 if (!operands_in_alternate_mem.empty()) {
383 return 0.5;
384 } else {
385 return 1.0;
386 }
387 }
388
GetAsyncCopyElapsed(const Shape & shape) const389 float GetAsyncCopyElapsed(const Shape& shape) const override {
390 if (get_async_copy_elapsed_override_) {
391 return get_async_copy_elapsed_override_(shape);
392 }
393 return 3.0;
394 }
395
396 // The following methods can be used to override what the above API calls
397 // return.
SetOverrideForGetInstructionElapsed(std::function<float (const HloInstruction &)> function)398 void SetOverrideForGetInstructionElapsed(
399 std::function<float(const HloInstruction&)> function) {
400 get_instruction_elapsed_override_ = function;
401 }
SetOverrideForGetInstructionElapsedInAlternateMemory(std::function<float (const HloInstruction &,absl::Span<const std::pair<int64_t,ShapeIndex>>,absl::Span<const ShapeIndex>)> function)402 void SetOverrideForGetInstructionElapsedInAlternateMemory(
403 std::function<float(const HloInstruction&,
404 absl::Span<const std::pair<int64_t, ShapeIndex>>,
405 absl::Span<const ShapeIndex>)>
406 function) {
407 get_instruction_elapsed_in_alternate_memory_override_ = function;
408 }
SetOverrideForGetAsyncCopyElapsed(std::function<float (const Shape &)> function)409 void SetOverrideForGetAsyncCopyElapsed(
410 std::function<float(const Shape&)> function) {
411 get_async_copy_elapsed_override_ = function;
412 }
413
414 protected:
FakeMemorySpaceAssignmentCostAnalysis(const HloCostAnalysis & cost_analysis,const Options & options,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range,std::unique_ptr<CallGraph> call_graph)415 FakeMemorySpaceAssignmentCostAnalysis(
416 const HloCostAnalysis& cost_analysis, const Options& options,
417 std::unique_ptr<HloAliasAnalysis> alias_analysis,
418 std::unique_ptr<HloLiveRange> hlo_live_range,
419 std::unique_ptr<CallGraph> call_graph)
420 : MemorySpaceAssignmentCostAnalysis(
421 cost_analysis, options, std::move(alias_analysis),
422 std::move(hlo_live_range), std::move(call_graph)) {}
423
424 private:
425 std::function<float(const HloInstruction&)>
426 get_instruction_elapsed_override_ = nullptr;
427 std::function<float(const HloInstruction&,
428 absl::Span<const std::pair<int64_t, ShapeIndex>>,
429 absl::Span<const ShapeIndex>)>
430 get_instruction_elapsed_in_alternate_memory_override_ = nullptr;
431 std::function<float(const Shape&)> get_async_copy_elapsed_override_ = nullptr;
432 };
433
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
  // A module consisting of a single parameter. Inputs/outputs are currently
  // excluded from memory space assignment, so the parameter keeps its
  // original layout.
  HloComputation::Builder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_THAT(param, op::ShapeWithLayout(shape));
}
453
TEST_P(MemorySpaceAssignmentTest, Simple) {
  // A simple module with a few simple instructions. Expect this to be
  // transformed with CopyStart and CopyDone instructions inserted after inputs
  // and before outputs.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Helper that appends an elementwise binary op to the computation.
  auto make_binary = [&](HloOpcode opcode, HloInstruction* lhs,
                         HloInstruction* rhs) {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
  };
  HloInstruction* add = make_binary(HloOpcode::kAdd, p0, p1);
  HloInstruction* sub = make_binary(HloOpcode::kSubtract, p0, p1);
  HloInstruction* mul = make_binary(HloOpcode::kMultiply, add, sub);

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, add, sub, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  auto preset_assignments = AssignMemorySpace(module.get());

  // Inputs and outputs are currently placed in the default memory. Everything
  // else should be in the alternate memory.
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{},
      /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  EXPECT_THAT(mul, op::ShapeWithLayout(shape));
  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem));

  // Make sure the preset assignments is sane.
  EXPECT_EQ(preset_assignments->chunks().size(), 3);
  EXPECT_EQ(preset_assignments->assignment_informations().size(), 1);
  // Ensure the offset assigned to add and sub are different.
  EXPECT_NE(preset_assignments->chunks()[0].second.offset,
            preset_assignments->chunks()[1].second.offset);
}
500
TEST_P(MemorySpaceAssignmentTest, NegateChain) {
  // The negate chain is long enough for asynchronous copy to be inserted
  // between p1 and add.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Helper that appends a negate of `operand` to the computation.
  auto make_negate = [&](HloInstruction* operand) {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, operand));
  };
  HloInstruction* negate0 = make_negate(p0);
  HloInstruction* negate1 = make_negate(negate0);
  HloInstruction* negate2 = make_negate(negate1);
  HloInstruction* negate3 = make_negate(negate2);
  HloInstruction* negate4 = make_negate(negate3);
  HloInstruction* negate5 = make_negate(negate4);
  HloInstruction* negate6 = make_negate(negate5);
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
                                      negate3, negate4, negate5, negate6, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // p1 should reach add through a prefetch (default -> alternate).
  EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::Parameter(1))));
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  // Negate instructions are in the alternate memory space (1).
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{},
      /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  for (HloInstruction* negate :
       {negate0, negate1, negate2, negate3, negate4, negate5, negate6}) {
    EXPECT_THAT(negate, op::ShapeWithLayout(shape_in_alternate_mem));
  }
  // Ensure the CopyStart/CopyDone schedules.
  const HloInstructionSequence& sequence =
      module->schedule().sequence(computation);
  EXPECT_THAT(sequence.instructions()[0], op::Parameter(0));
  EXPECT_THAT(sequence.instructions()[1], op::Parameter(1));
  EXPECT_THAT(sequence.instructions()[2], op::CopyStart());
  EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
}
564
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get());

  // The tanh value is evicted (alternate -> default) and later prefetched
  // back (default -> alternate) for its use at the root.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}
577
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0);

  // With a limit of 0, no prefetch or eviction may ever be outstanding.
  const OutstandingAsyncCopies copies =
      CountMaximumOutstandingAsyncCopies(*module);
  EXPECT_LE(copies.max_prefetches, 0);
  EXPECT_LE(copies.max_evictions, 0);
}
586
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);

  // With a limit of 1, at most one prefetch and one eviction may be
  // outstanding at any point in the schedule.
  const OutstandingAsyncCopies copies =
      CountMaximumOutstandingAsyncCopies(*module);
  EXPECT_LE(copies.max_prefetches, 1);
  EXPECT_LE(copies.max_evictions, 1);
}
595
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2);

  // With a limit of 2, at most two prefetches and two evictions may be
  // outstanding at any point in the schedule.
  const OutstandingAsyncCopies copies =
      CountMaximumOutstandingAsyncCopies(*module);
  EXPECT_LE(copies.max_prefetches, 2);
  EXPECT_LE(copies.max_evictions, 2);
}
604
// TODO(berkin): This test is broken with some prefetch timing improvements.
TEST_P(MemorySpaceAssignmentTest,
       DISABLED_DontEvictWhenThereIsDefaultMemAllocation) {
  // This test is the same as EvictAndPrefetchLimitAsyncCopies1, except we check
  // that there is no eviction if not necessary (due to an existing allocation
  // in default memory).
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // tanh should be placed in the alternate memory since there isn't much
  // contention in the beginning. However, tanh has another consumer at the end.
  // So it should be kicked out to default memory and prefetched back in. The
  // graph below is meant to increase the contention to force eviction/prefetch
  // behavior.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  // Pairwise products of {a, b, c, d} below create the memory pressure.
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  // Reduction tree summing the products into a single value `o`.
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  // tanh is being used at the root instruction, and this should be
  // prefetched.
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // The schedule must list the instructions in the same order they were
  // created above.
  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
                                      j, k, l, m, n, o, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);

  // We expect the second argument to multiply is prefetched c.
  EXPECT_THAT(f, op::Multiply(op::Add(), op::CopyDone()));
  // We make sure that the second argument to this multiply is not evicted
  // CopyDone but is the original c.
  EXPECT_THAT(h, op::Multiply(op::Subtract(), op::Multiply()));
}
675
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchAndPrefetch) {
  // Test for a memory corruption bug involving evict/prefetch/prefetch pattern,
  // where the last prefetch copied from the original buffer in alternate buffer
  // instead of evicted buffer.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // tanh is used twice far apart (by add0 and add1 below); the test expects it
  // to be evicted and then prefetched back for each of those uses.
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // a-d, their pairwise products e-j, and the reduction tree k-o keep many
  // values live at once, creating the memory pressure that forces tanh out of
  // the alternate memory.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  // First (late) use of tanh: triggers the first prefetch.
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
  // The chain of negates puts scheduling distance between the two uses of
  // tanh, so the second use (add1) needs a second, independent prefetch.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
  HloInstruction* negate9 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
  // Second (even later) use of tanh.
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate9, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // The schedule lists instructions in creation order; the memory space
  // assignment decisions below depend on this exact sequence.
  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation,
      {p0, p1, tanh, a, b, c, d, e,
       f, g, h, i, j, k, l, m,
       n, o, add0, negate0, negate1, negate2, negate3, negate4,
       negate5, negate6, negate7, negate8, negate9, add1});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Check that both prefetches (add0 and add1) prefetch from the eviction
  // instead of tanh, which will be placed in the alternate memory directly.
  // The nested AsyncCopy matchers encode: tanh (alternate) -> eviction to
  // default -> prefetch back to alternate.
  EXPECT_THAT(
      add0,
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
  EXPECT_THAT(
      add1,
      op::Add(op::Negate(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}
772
TEST_P(MemorySpaceAssignmentTest, While) {
  // Tests memory space assignment around a while loop: buffers threaded
  // through the while (the tuple operand and values aliased with it) must not
  // use the alternate memory when allocation across sequential calls is
  // disabled, while a buffer that lives entirely inside the loop body
  // (body_data_mul) can be placed in the alternate memory either way.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});

  // While condition: loop while the scalar iteration counter (tuple index 1)
  // is less than 50.
  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // While body: increments the counter and computes
  // data_next = (data + increment) + (data * data).
  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data_increment =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
  // body_data_mul is produced and consumed entirely within the body, so it is
  // not aliased with the while tuple.
  HloInstruction* body_data_mul =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, body_data, body_data));
  HloInstruction* body_data_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data, body_data_increment));
  HloInstruction* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // NOTE: The HLO names of these two parameters were previously swapped
  // ("param_iter" on the data parameter and vice versa); they are fixed here
  // to match the C++ variables. The names are not referenced by any assertion.
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data_increment, body_data_mul,
                         body_data_add, body_data_next, body_out});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Ensure the tuple value and buffers used in the while instruction are
  // exempted from using the alternate memory when allocating across sequential
  // calls is disabled. However, body_data_mul is independent and can be safely
  // be placed in the alternate memory.
  const bool allocate_across_sequential_calls = GetParam();
  if (!allocate_across_sequential_calls) {
    // Unchanged layouts imply these values stayed in the default memory.
    EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
    EXPECT_THAT(data, op::ShapeWithLayout(shape));
    EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
    EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
    EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
    EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
  }
  // body_data_mul must carry the alternate memory space in its layout in both
  // parameterizations.
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
}
867
TEST_P(MemorySpaceAssignmentTest, Tuple) {
  // Tests that elements of a (nested) tuple parameter that are used late can
  // be prefetched into the alternate memory: p1 (tuple index 1) and p2.0 (the
  // element of the nested tuple at index 2) are only consumed after a chain of
  // negates, and the assertions below expect each to arrive via an async copy.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape});
  HloInstruction* p = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 0));
  // Negate chain creates scheduling distance between the parameter and the
  // late uses of its other tuple elements.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
  // Nested tuple element: p2 = p.2, then p2_0 = p2.0.
  HloInstruction* p2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner_tuple_shape, p, 2));
  HloInstruction* p2_0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p2, 0));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, p2_0));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
                    negate6, p1, add, p2, p2_0, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Both the flat element (p1) and the nested element (p2.0) are expected to
  // be prefetched default -> alternate before their uses.
  EXPECT_THAT(
      mul,
      op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::GetTupleElement())),
                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                 op::GetTupleElement(op::GetTupleElement()))));
}
922
TEST_P(MemorySpaceAssignmentTest, Bitcast) {
  // Bitcasts can cause the position in the alternate memory to appear multiple
  // times in the preset assignments. This test ensure the preset assignments
  // refer to unique positions.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  // The bitcast aliases negate's buffer under the flat {6} shape.
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(param_shape, negate));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(param_shape, HloOpcode::kAdd, bitcast, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate, bitcast, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Re-fetch the bitcast through add's operand, since memory space assignment
  // may have replaced the original instruction, and verify it was assigned to
  // the alternate memory.
  bitcast = add->mutable_operand(0);
  EXPECT_EQ(bitcast->opcode(), HloOpcode::kBitcast);
  EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
}
954
TEST_P(MemorySpaceAssignmentTest, Bitcast2) {
  // A bitcast of a parameter (p1) is consumed only after a chain of negates;
  // the test checks that the add's first operand (the bitcast, or its
  // replacement) ends up with the alternate memory space in its layout.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  // Negate chain delays the use of p1's bitcast.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
                                      negate3, negate4, bitcast, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
}
991
TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
  // Exercises multiple bitcasts of the same parameter, including a chain
  // (bitcast3 of bitcast2) and a bitcast of a computed value (bitcast4 of
  // add). Verifies chained bitcasts collapse and all uses end up in the
  // alternate memory.
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
  Shape shape3 = ShapeUtil::MakeShape(F32, {1, 6});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  // Negate chain delays the bitcast uses of p1.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast1 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape1, p1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, bitcast1, negate4));
  HloInstruction* bitcast2 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape3, p1));
  HloInstruction* bitcast3 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, bitcast2));
  HloInstruction* bitcast4 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, add));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      shape2, HloOpcode::kMultiply, bitcast3, bitcast4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         bitcast1, add, bitcast2, bitcast3, bitcast4, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // We expect one bitcast on the LHS of multiply since bitcast(bitcast(foo)) is
  // converted to bitcast(foo).
  EXPECT_THAT(
      mul,
      op::Multiply(
          op::Bitcast(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                    op::Parameter(1))),
          op::Bitcast(op::Add(
              op::Bitcast(op::AsyncCopy(kAlternateMemorySpace,
                                        kDefaultMemorySpace, op::Parameter(1))),
              op::Negate()))));
  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(add->shape().layout().memory_space(), kAlternateMemorySpace);
  // bitcast2 will no longer have a consumer and should get DCE'd, so we don't
  // care about its memory space.
  EXPECT_EQ(mul->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(mul->operand(1)->shape().layout().memory_space(),
            kAlternateMemorySpace);
}
1057
TEST_P(MemorySpaceAssignmentTest, BitcastTuple) {
  // Regression-style test: a bitcast feeding a tuple that is the operand of a
  // fusion. The test only checks that AssignMemorySpace completes on this
  // pattern; there are no assertions on the resulting assignment.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});

  auto module = CreateNewVerifiedModule();
  // Fused computation: adds the two elements of its tuple parameter.
  HloComputation::Builder fusion_builder("fusion");
  HloInstruction* fusion_param = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
  HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
  fusion_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  // Negate chain delays the use of the bitcast via the tuple/fusion.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bitcast, p0}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         bitcast, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
}
1108
TEST_P(MemorySpaceAssignmentTest, BitcastGetTupleElementTuple) {
  // This test pattern was encountered in
  // //third_party/tensorflow/compiler/xla/tests:slice_test and was causing a
  // breakage when there is a GetTupleElement(Tuple(Bitcast())) pattern. Also
  // added a GetTupleElement(GetTupleElement(Tuple(Tuple(Bitcast())))) pattern.
  // The test passes if AssignMemorySpace runs without crashing; there are no
  // assertions on the resulting assignment.
  absl::string_view hlo_string = R"(
  HloModule DoIt_S64_10_0_5_1.3, is_scheduled=true

  ENTRY %DoIt_S64_10_0_5_1.3 (p0.1: (u32[10], u32[10])) -> (u32[5], u32[5]) {
    %p0.1 = (u32[10]{0:T(128)}, u32[10]{0:T(128)}) parameter(0)
    %get-tuple-element.1 = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=1
    %bitcast.1 = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element.1)
    %get-tuple-element = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=0
    %bitcast = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element)
    %tuple.1 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %bitcast, u32[5]{0:T(128)} %bitcast.1)
    %tuple.3 = ((u32[5]{0:T(128)}, u32[5]{0:T(128)}), (u32[5]{0:T(128)}, u32[5]{0:T(128)})) tuple(%tuple.1, %tuple.1)
    %get-tuple-element.4 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) get-tuple-element(%tuple.3), index=0
    %get-tuple-element.6 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %get-tuple-element.5), index=1
    %copy.2 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.4)
    %copy.3 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.6)
    ROOT %tuple.2 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %copy.2, u32[5]{0:T(128)} %copy.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1138
TEST_P(MemorySpaceAssignmentTest, GetSimplifiedOperandBug) {
  // Test case for a bug finding Bitcasts in GTE(Tuple(...)) pattern.
  // The test passes if AssignMemorySpace runs without crashing; there are no
  // assertions on the resulting assignment.
  absl::string_view hlo_string = R"(
  HloModule sort.16, is_scheduled=true

  ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
    %param.3.4 = s32[1]{0:T(128)} parameter(3)
    %param.2.3 = u32[1]{0:T(128)} parameter(2)
    %param.1.2 = f32[1]{0:T(128)} parameter(1)
    %param.0.1 = s32[1]{0:T(128)} parameter(0)
    %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4)
    %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
    %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
    %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
    %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
    %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
    %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
    %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
    ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1166
TEST_P(MemorySpaceAssignmentTest, BitcastMultiUse) {
  // When there is a pattern where a bitcast has multiple uses (negate0 and add)
  // and one is in the default memory and the other is in alternate memory, they
  // both need their own bitcast.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  // NOTE: The HLO name was previously "p1" despite being parameter 0 and the
  // only parameter; fixed to "p0" to match the C++ variable. The name is not
  // referenced by any assertion.
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "p0"));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
  // First use of the bitcast, immediately after the parameter.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  // Second, late use of the bitcast.
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
                                      negate3, negate4, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  // The early use keeps the default-memory layout; the late use gets its own
  // bitcast with the alternate-memory layout.
  EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
  EXPECT_THAT(add->operand(0), op::ShapeWithLayout(shape_in_alternate_mem));
}
1207
TEST_P(MemorySpaceAssignmentTest, BitcastMultiUseTuple) {
  // Same as BitcastMultUse but the second use is a tuple.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});

  auto module = CreateNewVerifiedModule();
  // Fused computation: adds the two elements of its tuple parameter.
  HloComputation::Builder fusion_builder("fusion");
  HloInstruction* fusion_param = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
  HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
  fusion_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  // NOTE: The HLO name was previously "p1" despite being parameter 0 and the
  // only parameter; fixed to "p0" to match the C++ variable. The name is not
  // referenced by any assertion.
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "p0"));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
  // First use of the bitcast, immediately after the parameter.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  // Second, late use of the bitcast, this time through a tuple fed to a
  // fusion.
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bitcast, negate4}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
                                      negate3, negate4, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  // The early use keeps the default-memory layout; the tuple element feeding
  // the fusion gets the alternate-memory layout.
  EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
  EXPECT_THAT(fusion->operand(0)->operand(0),
              op::ShapeWithLayout(shape_in_alternate_mem));
}
1263
TEST_P(MemorySpaceAssignmentTest, BitcastScheduleBug) {
  // Regression test: a bitcast of a prefetched value used to let the
  // asynchronous copy (CopyStart/CopyDone) slide too early in the schedule,
  // which could corrupt memory.
  // Bug:
  // p0------------------>neg-->neg-->neg ... -->neg-->neg-->neg->add
  //                                                            /
  // p1->cs->cd->bitcast-----------------------------------------+
  //
  // Expected:
  // p0-->neg-->neg-->neg ... -->neg-->neg-->neg------------->add
  //                                                         /
  // p1--------------------->cs----------------->cd->bitcast-+
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
  // Build a chain of ten negates rooted at p0; the chain gives the prefetch
  // of p1 room to be scheduled late.
  std::vector<HloInstruction*> negates;
  HloInstruction* chain_tail = p0;
  for (int i = 0; i < 10; ++i) {
    chain_tail = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, chain_tail));
    negates.push_back(chain_tail);
  }
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, bitcast, chain_tail));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> sequence = {p0, p1, bitcast};
  sequence.insert(sequence.end(), negates.begin(), negates.end());
  sequence.push_back(add);
  schedule.set_sequence(computation, sequence);
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/4);

  // The add must consume the prefetched (alternate-memory) copy of p1.
  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  const auto& instructions =
      module->schedule().sequence(module->entry_computation()).instructions();
  for (int i = 0; i < instructions.size(); ++i) {
    // The async copy must stay sandwiched inside the negate chain: a negate
    // immediately before and after the CopyStart, and a negate immediately
    // before the CopyDone.
    const HloOpcode opcode = instructions.at(i)->opcode();
    if (opcode == HloOpcode::kCopyStart) {
      EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
      EXPECT_EQ(instructions.at(i + 1)->opcode(), HloOpcode::kNegate);
    } else if (opcode == HloOpcode::kCopyDone) {
      EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
    }
  }
}
1335
TEST_P(MemorySpaceAssignmentTest, AddDependency) {
  // Make sure add-dependency is not optimized away.
  // The add-dependency op ties %p to a token to encode an ordering
  // constraint; memory space assignment must leave the op in place rather
  // than forwarding the parameter directly to the root add.
  absl::string_view hlo_string = R"(
  HloModule AddDependency, is_scheduled=true

  ENTRY %AddDependency (p: f32[3]) -> f32[3] {
    %p = f32[3]{0} parameter(0)
    %neg0 = f32[3]{0} negate(f32[3]{0} %p)
    %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
    %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
    %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
    %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
    %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
    %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
    %token0 = token[] after-all()
    %add_dep = f32[3]{0} add-dependency(f32[3]{0} %p, token[] %token0)
    ROOT %add = f32[3]{0} add(f32[3]{0} %add_dep, f32[3]{0} %neg6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // The root must still consume the add-dependency op (not a bare
  // parameter), proving it survived the pass.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Add(op::AddDependency(), op::Negate()));
}
1363
TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) {
  // This test is carefully crafted to include two multiply ops sized [4,3] in
  // a while body. The BufferIntervalCompare below makes the pass allocate
  // first multiplies, then tanhs, then all other HloValues. The alternate
  // memory is sized just large enough for two [4,3] buffers, so once the
  // multiplies inside the while body claim it, the tanh fed into the while
  // loop must not also live in alternate memory across the loop — otherwise
  // memory would be corrupted.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[])) -> (f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=1
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=0
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.2)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %cond_param), index=1
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_iter: f32[4,3], param_data: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_data = f32[] parameter(1)
    %param_iter = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %tanh = f32[4,3]{1,0} tanh(f32[4,3]{1,0} %param_iter)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %tanh)
    %tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %tanh, f32[] %param_data)
    %while = (f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %while), index=0
    ROOT %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.3, f32[4,3]{1,0} %add.4)
  }
  )";

  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        // Order: multiplies first, then tanhs, then everything else.
        // Ties are broken by buffer id so the order is deterministic.
        auto rank = [](const MemorySpaceAssignment::BufferInterval& interval) {
          switch (interval.buffer->defining_instruction()->opcode()) {
            case HloOpcode::kMultiply:
              return 0;
            case HloOpcode::kTanh:
              return 1;
            default:
              return 2;
          }
        };
        const int a_rank = rank(a);
        const int b_rank = rank(b);
        if (a_rank != b_rank) {
          return a_rank < b_rank;
        }
        return a.buffer->id() < b.buffer->id();
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker);

  for (const HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() != HloOpcode::kWhile) {
      continue;
    }
    const Shape& while_subshape =
        ShapeUtil::GetSubshape(instruction->shape(), {0});
    // We expect shape {0} to either be in default memory for the entire while
    // loop or there has to be an eviction within the while loop.
    if (while_subshape.layout().memory_space() != kAlternateMemorySpace) {
      continue;
    }
    const HloInstruction* body_param =
        instruction->while_body()->parameter_instruction(0);
    // Locate the get-tuple-element for index 0 inside the while body.
    const HloInstruction* gte = nullptr;
    for (const HloInstruction* user : body_param->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement &&
          user->tuple_index() == 0) {
        gte = user;
        break;
      }
    }
    EXPECT_NE(gte, nullptr);
    // That value must be evicted: it has to feed a CopyStart whose
    // destination is NOT the alternate memory space.
    const HloInstruction* copy_start = nullptr;
    for (const HloInstruction* user : gte->users()) {
      if (user->opcode() == HloOpcode::kCopyStart) {
        copy_start = user;
        break;
      }
    }
    EXPECT_NE(copy_start, nullptr);
    const Shape& copy_start_subshape =
        ShapeUtil::GetSubshape(copy_start->shape(), {0});

    EXPECT_NE(copy_start_subshape.layout().memory_space(),
              kAlternateMemorySpace);
  }
}
1484
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) {
  // Two back-to-back while loops that thread (f32[4,3], f32[4,3], f32[])
  // tuples through both bodies, with the output of the first loop feeding the
  // second. The test only checks that memory space assignment runs to
  // completion on this pattern; success of AssignMemorySpace (including its
  // internal verification) is the assertion.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
    %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter)
    %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
    %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // No explicit expectations: the pass (and its internal verifier) must not
  // fail on consecutive while loops sharing threaded buffers.
  AssignMemorySpace(module.get());
}
1561
TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) {
  // Tests against while live ranges being incorrect and the verifier
  // complaining about a conflict.
  // The long negate chain inside the while body stretches the live ranges of
  // the tuple elements; the assertion is simply that AssignMemorySpace
  // succeeds (its internal verifier reports no overlap).
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1626
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
  // Tests against a bug when there are consecutive while loops with one buffer
  // (the value doesn't change in the buffer), the parameter can be colored in
  // the alternate memory space.
  // Here %param_data is passed unmodified through both loops (tuple index 1
  // is forwarded as-is by both bodies); success of AssignMemorySpace is the
  // assertion.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
    %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1728
TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
  // While loop is the root of the entry computation. We should ensure the
  // output of the entry computation remains to be in default memory space.
  // Test from //third_party/tensorflow/compiler/xla/tests:while_test
  // WhileTest.WhileWithPrngScalarResult.
  // The while condition bitcasts the loop-carried buffer, which aliases the
  // entry output; the assertion is that AssignMemorySpace succeeds.
  absl::string_view hlo_string = R"(
  HloModule WhileWithPrngScalarResult.18, is_scheduled=true

  %fused_computation (param_0.1: s32[6], param_1.3: s32[1], param_2.3: s32[5]) -> s32[6] {
    %param_1.3 = s32[1]{0:T(128)} parameter(1)
    %constant.2 = s32[]{:T(128)} constant(-2147483648)
    %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
    %param_2.3 = s32[5]{0:T(128)} parameter(2)
    %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
    %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
    %param_0.1 = s32[6]{0:T(128)} parameter(0)
    ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
  }

  %body.3 (prev.4: s32[6]) -> s32[6] {
    %constant.7 = s32[]{:T(128)} constant(100)
    %constant.6 = s32[]{:T(128)} constant(0)
    %constant.5 = s32[1]{0:T(128)} constant({1})
    %prev.4 = s32[6]{0:T(128)} parameter(0)
    %rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
    %neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
    ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
  }

  %WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
    %constant.15 = s32[]{:T(128)} constant(1)
    %prev.12 = s32[6]{0:T(128)} parameter(0)
    %bitcast.1 = s32[1]{0:T(128)} bitcast(s32[6]{0:T(128)} %prev.12)
    %bitcast = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %bitcast.1)
    ROOT %compare.16 = pred[]{:T(128)E(32)} compare(s32[]{:T(128)} %constant.15, s32[]{:T(128)} %bitcast), direction=GT
  }

  ENTRY %WhileWithPrngScalarResult.18 () -> s32[6] {
    %constant.1 = s32[]{:T(128)} constant(0)
    %broadcast.2 = s32[6]{0:T(128)} broadcast(s32[]{:T(128)} %constant.1), dimensions={}
    ROOT %while.17 = s32[6]{0:T(128)} while(s32[6]{0:T(128)} %broadcast.2), condition=%WhileWithPrngScalarResult.11, body=%body.3
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1777
TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
  // Ensure that a dynamic update slice within a while loop is able to get an
  // alternate memory allocation.
  // The fusion in the while body performs an in-place dynamic-update-slice on
  // tuple element {1}; after the pass that element should live in alternate
  // memory when the parameterized mode enables it.
  absl::string_view hlo_string = R"(
  HloModule Module, is_scheduled=true

  fused_computation {
    param0 = f32[2,3] parameter(0)
    constant.1 = f32[] constant(0)
    broadcast = f32[2,1] broadcast(constant.1), dimensions={}
    constant.3 = s32[] constant(0)
    ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
  }

  %WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
    %body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
    %fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
    %multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
    ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
  }

  %WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
    %cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[2,3]{1,0} parameter(0)
    %p2 = f32[2,3]{1,0} parameter(2)
    %copy1 = f32[2,3]{1,0} copy(param_data)
    %copy2 = f32[2,3]{1,0} copy(p2)
    %tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
    %while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
    ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
  const HloInstruction* while_op =
      module->entry_computation()->GetInstructionWithName("while");
  if (GetParam()) {
    // NOTE(review): GetParam() presumably selects the mode in which alternate
    // memory allocation is enabled — confirm against the test-suite
    // instantiation. In that mode, the in-place DUS buffer at tuple index {1}
    // of the while must be placed in alternate memory.
    EXPECT_EQ(
        ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(),
        kAlternateMemorySpace);
  }
}
1833
TEST_P(MemorySpaceAssignmentTest, WhileSharedBufferVerificationBug) {
  // Tests a spurious verification failure when a while has the same value
  // passed in twice (copy0) and that value is evicted within the while loop.
  // Note the while body's root tuple also forwards gte0 twice, mirroring the
  // duplicated operand. The test passes as long as AssignMemorySpace completes
  // without tripping its internal verification.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=3
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = f32[3]{0} get-tuple-element(p0), index=2
    gte3 = pred[] get-tuple-element(p0), index=3
    add = f32[3]{0} add(gte0, gte0)
    negate0 = f32[3]{0} negate(add)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    negate15 = f32[3]{0} negate(negate14)
    negate16 = f32[3]{0} negate(negate15)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, gte0, negate16, gte3)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    ROOT gte = f32[3]{0} get-tuple-element(while), index=2
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1886
TEST_P(MemorySpaceAssignmentTest, b228599972) {
  // Regression test for b/228599972. Both outputs of the fusion (and the
  // negates fed by them) are dead code: the ROOT depends only on %p1. Memory
  // space assignment must handle such unused values without crashing or
  // failing verification; there are no EXPECTs beyond a clean run.
  absl::string_view hlo_string = R"(
HloModule entry, is_scheduled=true

fused_computation {
  %p0 = f32[2,3]{1,0} parameter(0)
  %result0 = f32[2,3]{1,0} copy(%p0)
  %result1 = f32[2,3]{1,0} copy(%p0)
  ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}) tuple(%result0, %result1)
}

ENTRY entry {
  %p0 = f32[2,3]{1,0} parameter(0)
  %p1 = f32[2,3]{1,0} parameter(1)
  %unused = (f32[2,3]{1,0}, f32[2,3]{1,0}) fusion(%p0), kind=kLoop, calls=%fused_computation
  %unused.0 = f32[2,3]{1,0} get-tuple-element(%unused), index=0
  %unused.1 = f32[2,3]{1,0} get-tuple-element(%unused), index=1
  %negate.0 = f32[2,3]{1,0} negate(f32[2,3]{1,0} %unused.0)
  %negate.1 = f32[2,3]{1,0} negate(f32[2,3]{1,0} %unused.1)

  ROOT %result = f32[2,3]{1,0} negate(%p1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1914
TEST_P(MemorySpaceAssignmentTest, b172243149) {
  // Tests for the failure in b/172243149, where if we skip processing
  // non-copy allocations that are in default memory can actually cause
  // failures. In this case, the problem tensor is copy0, where it is fed to
  // both negate, while, and add0. The copy0->negate dependency can be allocated
  // in the alternate memory. Then the algorithm attempts to place the
  // copy0->while edge in the alternate memory, but since this value isn't used
  // in the while loop, it won't get an alternate memory allocation. Finally for
  // the copy0->add0 edge, the algorithm will actually replace it with
  // while{0}->add0, since this is equivalent and while is defined later than
  // copy0. However, if we actually skip processing this while{0}->add0
  // allocation, we won't replace this edge, and will end up with the
  // copy0->add0 edge, which illegally extends the lifetime of the alternate
  // memory buffer in copy0.
  //
  // The test passes as long as AssignMemorySpace completes without a
  // verification failure; no additional EXPECTs are needed.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=3
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = f32[3]{0} get-tuple-element(p0), index=2
    gte3 = pred[] get-tuple-element(p0), index=3
    add = f32[3]{0} add(gte1, gte2)
    negate0 = f32[3]{0} negate(add)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    negate15 = f32[3]{0} negate(negate14)
    negate16 = f32[3]{0} negate(negate15)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, negate16, gte3)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    copy2 = f32[3]{0} copy(p0)
    negate = f32[3]{0} negate(copy0)
    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, copy2, p1)
    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    gte = f32[3]{0} get-tuple-element(while), index=2
    add0 = f32[3]{0} add(negate, copy0)
    ROOT add1 = f32[3]{0} add(add0, gte)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
1982
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
  // Having control_predecessors on an HLO was preventing us from DCEing an op
  // that doesn't have any users (tuple.1). The scheduler assumes the graph is
  // fully DCEed, which causes some instructions not to be scheduled.
  // Here %tuple.1 carries control-predecessors={%param.0.1}; the test passes
  // if AssignMemorySpace completes on such a module.
  absl::string_view hlo_string = R"(
HloModule sort.16, is_scheduled=true

ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
  %param.3.4 = s32[1]{0:T(128)} parameter(3)
  %param.2.3 = u32[1]{0:T(128)} parameter(2)
  %param.1.2 = f32[1]{0:T(128)} parameter(1)
  %param.0.1 = s32[1]{0:T(128)} parameter(0)
  %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4), control-predecessors={%param.0.1}
  %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
  %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
  %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
  %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
  %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
  %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
  %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
  %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
  ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
2012
TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) {
  // Checks if simple conditionals get alternate memory allocations.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg1 = f32[3]{0} negate(gte)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // GetParam() appears to gate whether alternate memory is enabled; the layout
  // checks below only apply in that mode (instantiation is outside this
  // excerpt) -- TODO confirm.
  if (GetParam()) {
    // Check that copy and gtes got alternate memory allocations.
    auto copy =
        module->GetComputationWithName("entry")->GetInstructionWithName("copy");
    EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
    auto neg1 = module->GetComputationWithName("true_computation")
                    ->GetInstructionWithName("neg1");
    auto neg1_operand = neg1->operand(0);
    EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
              kAlternateMemorySpace);
    auto neg2 = module->GetComputationWithName("false_computation")
                    ->GetInstructionWithName("neg2");
    auto neg2_operand = neg2->operand(0);
    EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
              kAlternateMemorySpace);
  }
}
2059
TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) {
  // Checks if we avoid unnecessary allocation in alternate memory if the input
  // won't be used in the computation for a long time. Here, copy1 is only
  // consumed by gte1/add at the very end of true_computation, after a long
  // negate chain.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    neg0 = f32[3]{0} negate(gte0)
    neg1 = f32[3]{0} negate(neg0)
    neg2 = f32[3]{0} negate(neg1)
    neg3 = f32[3]{0} negate(neg2)
    neg4 = f32[3]{0} negate(neg3)
    neg5 = f32[3]{0} negate(neg4)
    neg6 = f32[3]{0} negate(neg5)
    neg7 = f32[3]{0} negate(neg6)
    neg8 = f32[3]{0} negate(neg7)
    neg9 = f32[3]{0} negate(neg8)
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    ROOT add = f32[3]{0} add(neg9, gte1)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
    tuple1 = (f32[3]{0}) tuple(copy0)
    ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // Check that copy1 doesn't get unnecessarily allocated in alternate mem
    // (due to long negate chain in true_computation) but is prefetched before
    // add.
    auto copy0 =
        module->GetComputationWithName("entry")->GetInstructionWithName(
            "copy0");
    EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
    auto copy1 =
        module->GetComputationWithName("entry")->GetInstructionWithName(
            "copy1");
    EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace);
    // add's operand(1) is the prefetched value of copy1 inside the branch.
    auto add = module->GetComputationWithName("true_computation")
                   ->GetInstructionWithName("add");
    auto add_operand = add->operand(1);
    EXPECT_EQ(add_operand->shape().layout().memory_space(),
              kAlternateMemorySpace);
  }
}
2122
TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) {
  // Make sure there is an evict when there is a conditional use followed by
  // another use. copy1 is used both inside true_computation (add0) and after
  // the conditional (add1) in the entry computation.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    add0 = f32[3]{0} add(gte0, gte1)
    neg0 = f32[3]{0} negate(add0)
    neg1 = f32[3]{0} negate(neg0)
    neg2 = f32[3]{0} negate(neg1)
    neg3 = f32[3]{0} negate(neg2)
    neg4 = f32[3]{0} negate(neg3)
    neg5 = f32[3]{0} negate(neg4)
    neg6 = f32[3]{0} negate(neg5)
    neg7 = f32[3]{0} negate(neg6)
    neg8 = f32[3]{0} negate(neg7)
    ROOT neg9 = f32[3]{0} negate(neg8)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
    tuple1 = (f32[3]{0}) tuple(copy0)
    conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
    ROOT add1 = f32[3]{0} add(copy1, conditional)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // Make sure the copy1->add edge is in alternate memory. Before conditional,
    // this should be evicted to default memory and neg uses the input from
    // default memory.
    auto copy1 =
        module->GetComputationWithName("entry")->GetInstructionWithName(
            "copy1");
    EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace);
    auto add0 = module->GetComputationWithName("true_computation")
                    ->GetInstructionWithName("add0");
    auto add0_operand = add0->operand(1);
    EXPECT_EQ(add0_operand->shape().layout().memory_space(),
              kAlternateMemorySpace);
    // add1's operand(0) should come from the eviction (a copy-done producing a
    // default-memory value), not directly from the alternate-memory copy1.
    auto add1 =
        module->GetComputationWithName("entry")->GetInstructionWithName("add1");
    auto add1_operand = add1->operand(0);
    EXPECT_EQ(add1_operand->shape().layout().memory_space(),
              kDefaultMemorySpace);
    EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone);
  }
}
2188
TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) {
  // Tests a conditional nested inside a while body where the conditional's
  // operand value (gte0, forwarded through cond_tuple and the root tuple) is
  // used across loop iterations, forcing an eviction and a prefetch around the
  // while body root.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg1 = f32[3]{0} negate(gte)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = pred[] get-tuple-element(p0), index=2
    cond_tuple = (f32[3]{0}) tuple(gte0)
    conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation
    add = f32[3]{0} add(conditional, gte1)
    neg0 = f32[3]{0} negate(add)
    neg1 = f32[3]{0} negate(neg0)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    ROOT gte = f32[3]{0} get-tuple-element(while), index=1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // Make sure copy1/while{0}/cond_tuple{0} gets alternate memory allocation.
    // This will force an eviction and a prefetch for while body root.
    auto copy0 =
        module->GetComputationWithName("entry")->GetInstructionWithName(
            "copy0");
    EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
    auto conditional = module->GetComputationWithName("while_body")
                           ->GetInstructionWithName("conditional");
    auto conditional_operand = conditional->operand(1);
    EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0})
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    // The while body root's first tuple element should be an eviction
    // (alternate->default) followed by a prefetch (default->alternate).
    auto while_root =
        module->GetComputationWithName("while_body")->root_instruction();
    auto while_root_operand = while_root->operand(0);
    EXPECT_THAT(
        while_root_operand,
        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                      op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
                                    op::GetTupleElement(op::Parameter(0)))));
  }
}
2261
TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
  // Checks that an alternate memory allocation made in the entry computation
  // (copy) propagates through two levels of nested conditionals: the operands
  // of neg1/neg2 (inner level) and neg3 (outer level) should all be placed in
  // alternate memory.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg1 = f32[3]{0} negate(gte)
  }

  false_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  true_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    slice = f32[1]{0} slice(gte), slice={[0:1]}
    bitcast = f32[] bitcast(slice)
    constant = f32[] constant(0.0)
    compare = pred[] compare(bitcast, constant), direction=GT
    ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2
  }

  false_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg3 = f32[3]{0} negate(gte)
  }


  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // Make sure alternate memory allocation gets propagated into both levels of
    // conditional.
    auto copy =
        module->GetComputationWithName("entry")->GetInstructionWithName("copy");
    EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
    auto neg1_operand = module->GetComputationWithName("true_computation2")
                            ->GetInstructionWithName("neg1")
                            ->operand(0);
    auto neg2_operand = module->GetComputationWithName("false_computation2")
                            ->GetInstructionWithName("neg2")
                            ->operand(0);
    auto neg3_operand = module->GetComputationWithName("false_computation1")
                            ->GetInstructionWithName("neg3")
                            ->operand(0);
    EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(neg3_operand->shape().layout().memory_space(),
              kAlternateMemorySpace);
  }
}
2330
TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) {
  // Tests a spurious verification failure when there are nested conditionals
  // and the innermost conditional computation reuses the buffer. Here, both the
  // parameter of true_computation2 and neg2 will get the same buffer. Make sure
  // that verification doesn't claim a failure in this case.
  // The test passes as long as AssignMemorySpace completes cleanly.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    neg1 = f32[3]{0} negate(gte)
    neg2 = f32[3]{0} negate(neg1)
    ROOT neg3 = f32[3]{0} negate(neg2)
  }

  false_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg4 = f32[3]{0} negate(gte)
  }

  true_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    slice = f32[1]{0} slice(gte), slice={[0:1]}
    bitcast = f32[] bitcast(slice)
    constant = f32[] constant(0.0)
    compare = pred[] compare(bitcast, constant), direction=GT
    tuple = (f32[3]{0}) tuple(gte)
    ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2
  }

  false_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg5 = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
2382
TEST_P(MemorySpaceAssignmentTest,
       ConditionalComputationBufferOverlapBeforeParam) {
  // In false_computation, neg0/neg1 are computed from a constant before the
  // conditional parameter (and hence copy's buffer) is first used. Their
  // alternate-memory chunks must not be given the same offset as copy's chunk,
  // since both are live inside the branch.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  false_computation {
    c = f32[3]{0} constant({0.0, 1.0, 2.0})
    neg0 = f32[3]{0} negate(c)
    neg1 = f32[3]{0} negate(neg0)
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT add = f32[3]{0} add(gte, neg1)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto preset_assignments = AssignMemorySpace(module.get());

  // Returns the offset of the chunk assigned to the named instruction, or -1
  // if the instruction has no chunk in the preset assignments.
  auto get_offset = [&](absl::string_view hlo_name) {
    for (const auto& chunk : preset_assignments->chunks()) {
      if (chunk.first.instruction->name() == hlo_name) {
        return chunk.second.offset;
      }
    }
    return static_cast<int64_t>(-1);
  };

  // Both values must have chunks, and the chunks must be at distinct offsets.
  int64_t copy_offset = get_offset("copy");
  int64_t neg0_offset = get_offset("neg0");
  EXPECT_NE(copy_offset, -1);
  EXPECT_NE(neg0_offset, -1);
  EXPECT_NE(copy_offset, neg0_offset);
}
2430
TEST_P(MemorySpaceAssignmentTest,
       RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
  // Ensure that request identifier returned by Send/Recv HLOs are not allocated
  // in the alternate memory. The request identifier is tuple element {1} of a
  // Send/Recv result.
  absl::string_view hlo_string = R"(
HloModule SendRecv, is_scheduled=true

ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p = f32[3]{0} parameter(0)
  %after-all = token[] after-all()
  %recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
  %recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
  %token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
  %data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
  %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
  %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
  ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Verify every Send/Recv in the entry computation keeps its request
  // identifier (subshape {1}) out of the alternate memory space.
  for (const HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      const Shape& request_identifier_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {1});
      EXPECT_NE(request_identifier_shape.layout().memory_space(),
                kAlternateMemorySpace);
    }
  }
}
2466
TEST_P(MemorySpaceAssignmentTest, SendDoneShouldHaveSendOperand) {
  // Ensure that SendDone has only a Send operand. The negate chain between
  // Send and SendDone gives the pass an opportunity to insert copies; the test
  // passes if AssignMemorySpace completes without breaking the Send/SendDone
  // pairing.
  absl::string_view hlo_string = R"(
HloModule SendRecv, is_scheduled=true

ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p0 = f32[3]{0} parameter(0)
  %p1 = f32[3]{0} parameter(1)
  %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
  %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
  %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
  %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
  %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
  %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
  %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
  %after-all = token[] after-all()
  %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
  %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
  ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
2493
TEST_P(MemorySpaceAssignmentTest, SendAndSendDoneShouldGetSameAllocation) {
  // Ensure that Send and SendDone have the same allocation. Unlike the test
  // above, the negate chain sits between Send and SendDone, so the Send buffer
  // stays live across it. Runs with an unlimited outstanding-copy budget and a
  // widened prefetch interval to encourage copies.
  absl::string_view hlo_string = R"(
HloModule SendRecv, is_scheduled=true

ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p0 = f32[3]{0} parameter(0)
  %p1 = f32[3]{0} parameter(1)
  %after-all = token[] after-all()
  %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
  %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
  %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
  %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
  %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
  %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
  %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
  %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
  %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
  ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/4);
}
2521
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
  // Exercises the last-use optimization using two buffers that should both be
  // placed in alternate memory.
  //
  //      +-------+
  //     /         \
  // add1--->sub1   +-------->mul2
  // mul1===>add2
  //
  // Without the last-use optimization, the larger mul1 buffer is assigned
  // first, at offset 0. add1 is then scheduled for the add1->sub1 segment and,
  // since offset 0 is free there, takes it too. But offset 0 is not free for
  // the sub1->mul2 segment, which would force redundant copies. The last-use
  // optimization eliminates those copies.
  HloComputation::Builder b(TestName());
  const Shape small_shape = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape large_shape = ShapeUtil::MakeShape(F32, {2, 4});
  const PaddingConfig pad_cfg = MakeEdgePaddingConfig({{0, 0}, {0, 1}});

  HloInstruction* param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, small_shape, "p0"));
  HloInstruction* param1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, large_shape, "p1"));
  HloInstruction* sum1 = b.AddInstruction(HloInstruction::CreateBinary(
      small_shape, HloOpcode::kAdd, param0, param0));
  HloInstruction* diff1 = b.AddInstruction(HloInstruction::CreateBinary(
      small_shape, HloOpcode::kSubtract, param0, sum1));
  HloInstruction* prod1 = b.AddInstruction(HloInstruction::CreateBinary(
      large_shape, HloOpcode::kMultiply, param1, param1));
  HloInstruction* sum2 = b.AddInstruction(HloInstruction::CreateBinary(
      large_shape, HloOpcode::kAdd, prod1, param1));
  HloInstruction* prod2 = b.AddInstruction(HloInstruction::CreateBinary(
      small_shape, HloOpcode::kMultiply, sum1, diff1));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* padded_prod2 = b.AddInstruction(
      HloInstruction::CreatePad(large_shape, prod2, zero, pad_cfg));
  HloInstruction* sum3 = b.AddInstruction(HloInstruction::CreateBinary(
      large_shape, HloOpcode::kAdd, sum2, padded_prod2));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  HloSchedule hlo_schedule(hlo_module.get());
  hlo_schedule.set_sequence(entry, {param0, param1, sum1, diff1, prod1, sum2,
                                    prod2, zero, padded_prod2, sum3});
  TF_CHECK_OK(hlo_module->set_schedule(hlo_schedule));

  AssignMemorySpace(hlo_module.get());

  // mul2 (prod2) should see sub1's value arrive through a single async copy
  // on the subtract's first operand -- no extra copies around add1.
  EXPECT_THAT(
      prod2,
      op::Multiply(
          op::Add(op::Parameter(0), op::Parameter(0)),
          op::Subtract(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                     op::Parameter(0)),
                       op::Add(op::Parameter(0), op::Parameter(0)))));
}
2580
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
  // Test to ensure CopyStart/CopyDone is placed only in the entry computation.
  // Builds a while loop whose body squares/increments the data tensor and
  // bumps a scalar iterator, then runs memory space assignment with a large
  // max prefetch interval (50).
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data_increment =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
  HloInstruction* body_data_mul =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, body_data, body_data));
  HloInstruction* body_data_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data, body_data_increment));
  // New data value: (data + increment) + data * data.
  HloInstruction* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // Fix: the parameter names used to be swapped ("param_iter" labeled the
  // data tensor and "param_data" the scalar iterator). Name them to match
  // what they hold, consistent with NonEntryComputationSchedule5.
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* p2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, while_data, p2));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data_increment, body_data_mul,
                         body_data_add, body_data_next, body_out});
  schedule.set_sequence(entry_computation,
                        {iter, data, p2, tuple, while_op, while_data, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);
}
2664
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
  // Exercises memory space assignment in the presence of a kCall into a
  // non-entry computation. In the caller, add1 is both a call operand and
  // used after the call returns (by add3), so its live range spans the
  // callee's long negate chain. Runs with a small max prefetch interval (5).
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  auto call_builder = HloComputation::Builder("Call");
  HloInstruction* call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "call_param"));
  HloInstruction* call_param2 = call_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape2, "call_param2"));
  // Slice the larger parameter down to `shape` so it can be combined with
  // call_param.
  HloInstruction* slice = call_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, call_param2, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, call_param, slice));
  // Chain of negates to create schedule distance inside the callee.
  HloInstruction* negate0 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  // call_param is read again at the very end of the callee, so it stays live
  // across the entire negate chain.
  HloInstruction* add0 =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, call_param, negate7));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(call_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape2, HloOpcode::kNegate, p1));
  // add1 and negate8 are passed into the called computation.
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(shape, {add1, negate8}, call_computation));
  // add1 is also used after the call returns.
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, add1));
  HloInstruction* add4 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add3));
  HloInstruction* add5 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add2, add4));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      call_computation,
      {call_param, call_param2, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(entry_computation,
                        {p0, p1, add1, add2, negate8, call, add3, add4, add5});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
}
2735
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
  // Like NonEntryComputationSchedule2, but the callee first materializes a
  // larger iota buffer to perturb alternate-memory offsets inside the called
  // computation (see the comment on `iota` below).
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  auto call_builder = HloComputation::Builder("Call");
  HloInstruction* call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "call_param"));
  // Use shape2 here which is larger (scheduled earlier) to occupy alternate
  // memory at the beginning. This should cause a situation where the prefetch
  // of add1 later in the function body gets the wrong offset which cannot be
  // communicated to the outside the function.
  HloInstruction* iota =
      call_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
  HloInstruction* slice = call_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, call_param, slice));
  // Chain of negates to create schedule distance inside the callee.
  HloInstruction* negate0 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  // call_param is read again at the end, keeping it live across the chain.
  HloInstruction* add0 =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, call_param, negate7));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(call_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  // add1 is both the call operand and used after the call (by add3).
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(shape, {add1}, call_computation));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add1));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      call_computation,
      {call_param, iota, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(entry_computation, {p0, add1, add2, call, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
}
2801
// TODO(berkin): This might be an incorrect input graph, investigate.
TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) {
  // Variant of the preceding tests using a kConditional instead of kCall:
  // the true branch contains the iota + negate chain, the false branch is a
  // pass-through of its parameter. Currently disabled (see TODO above).
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  auto true_builder = HloComputation::Builder("True");
  HloInstruction* true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "true_param"));
  // The larger iota occupies alternate memory early in the branch.
  HloInstruction* iota =
      true_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
  HloInstruction* slice = true_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      true_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, true_param, slice));
  // Chain of negates to create schedule distance inside the branch.
  HloInstruction* negate0 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  // true_param is read again at the end, keeping it live across the chain.
  HloInstruction* add0 =
      true_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, true_param, negate7));
  HloComputation* true_computation =
      module->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder("False");
  // False branch just returns its parameter.
  HloInstruction* false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "false_param"));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(false_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  // Constant-true predicate; add1 feeds the true branch, add2 the false one.
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          shape, pred, add1, true_computation, add2, false_computation));
  // add1 is also used after the conditional.
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, conditional, add1));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      true_computation,
      {true_param, iota, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(false_computation, {false_param});
  schedule.set_sequence(entry_computation,
                        {p0, add1, add2, pred, conditional, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
}
2875
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
  // This test reproduces the failure in b/143288178. Given a graph like the
  // following:
  //
  // ... = foo(a)
  // tuple = tuple(..., a)
  // ... = while(tuple) {
  //   p = param(0)
  //   a1 = get-tuple-element(p), index=n-1
  //   ...
  //   ROOT tuple(..., a1)
  // }
  //
  // If a copy to alternate memory is inserted before foo, and if the size of
  // the while body is less than max prefetch interval so that the copy-done is
  // kept in the alternate memory, then we end up referring to the copy-done in
  // the root instruction of the while loop body. I.e.,
  //
  // cs = copy-start(a)
  // ...
  // cd = copy-done(cs)
  // ... = foo(cd)
  // tuple = tuple(..., cd)
  // ... = while(tuple) {
  //   p = param(0)
  //   a1 = get-tuple-element(p), index=n-1
  //   ...
  //   ROOT tuple(..., cd)  <-- Error: cd belongs to outside computation.
  // }
  //
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({shape, scalar_shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  // The condition only reads the scalar iterator at tuple index 1.
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data2 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 2));
  // Tuple elements 0 and 2 are passed through the body unchanged; only the
  // iterator (index 1) is updated.
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data, body_iter_next, body_data2}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* data2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape, "param_data2"));
  // Long negate chain on `data` before it enters the while loop; this is the
  // schedule distance across which a prefetch of `data` could be inserted.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kSubtract, iter, data2));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate7, iter, data2}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
  HloInstruction* root =
      builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data2, body_out});
  schedule.set_sequence(
      entry_computation,
      {iter, data, data2, negate0, negate1, negate2, negate3, negate4, negate5,
       negate6, negate7, sub, tuple, while_op, while_data, root});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Set a large max prefetch interval so that the buffer can be kept in
  // alternate memory.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/20);
}
2996
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
  // While loop where tuple element {0} is loop-invariant (the body emits
  // body_data unchanged) while element {2} is rewritten every iteration
  // (body_negate7). After memory space assignment, verifies the resulting
  // layouts: element {0} in the alternate memory space, elements {1} and {2}
  // in the default space.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  // The condition only reads the scalar iterator at tuple index 1.
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  // Chain of negates inside the body; its result replaces tuple element {2}.
  HloInstruction* body_negate0 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
  HloInstruction* body_negate1 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
  HloInstruction* body_negate2 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
  HloInstruction* body_negate3 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
  HloInstruction* body_negate4 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
  HloInstruction* body_negate5 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
  HloInstruction* body_negate6 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
  HloInstruction* body_negate7 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  // Element {0} (body_data) is passed through unchanged; element {2} is
  // written with body_negate7.
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  // Negate chain in the entry computation feeding tuple element {2}.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({data, iter, negate7}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
  HloInstruction* while_data2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 2));
  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, while_data, while_data2));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(
      body_computation,
      {body_param, body_iter, body_data, body_negate0, body_negate1,
       body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
       body_negate7, body_iter_increment, body_iter_next, body_out});
  schedule.set_sequence(
      entry_computation,
      {iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
       negate6, negate7, tuple, while_op, while_data, while_data2, root});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Pick a large max prefetch interval to ensure all the while inputs are
  // allocated in the alternate memory.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/25);

  // Build the expected tuple layout in place, then match it below.
  // Index {0} of the while loop argument is not written inside the while loop,
  // so it can be trivially placed in the alternate memory space.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
          /*element_size_in_bits=*/0, kAlternateMemorySpace);
  // Index {1} is a scalar, so it is always placed in the default memory.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{}, /*dim_level_types=*/{}, /*tiles=*/{},
          /*element_size_in_bits=*/0, kDefaultMemorySpace);
  // Index {2} of the while loop is placed in the default memory.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
          /*element_size_in_bits=*/0, kDefaultMemorySpace);

  // Expect the layout for the while loop and its aliased buffers.
  EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
}
3126
TEST_P(MemorySpaceAssignmentTest, DanglingCopy) {
  // Regression test (originally encountered in vss): a copy whose result has
  // no users caused the memory space recorded in preset assignments to
  // disagree with the memory space in the output graph.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});

  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, tuple_param, 0));
  // First read of tuple element 1; its only user is the dangling copy below.
  HloInstruction* elem1_first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, tuple_param, 1));
  // The copy itself has no users.
  HloInstruction* dangling_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopy, elem1_first));
  // Chain of negates on element 0 to create schedule distance.
  HloInstruction* neg0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, elem0));
  HloInstruction* neg1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, neg0));
  HloInstruction* neg2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, neg1));
  HloInstruction* neg3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, neg2));
  HloInstruction* neg4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, neg3));
  HloInstruction* neg5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, neg4));
  HloInstruction* neg6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, neg5));
  // Second read of tuple element 1; this one feeds the root add.
  HloInstruction* elem1_second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, tuple_param, 1));
  HloInstruction* root_add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, neg6, elem1_second));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Note the schedule places the first element-1 read and its copy after the
  // negate chain, even though they were created earlier.
  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {tuple_param, elem0, neg0, neg1, neg2, neg3, neg4, neg5,
                         neg6, elem1_first, dangling_copy, elem1_second,
                         root_add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
}
3172
TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) {
  // Exercises memory space assignment on a fusion that produces a tuple of
  // two outputs, which the entry computation unpacks and adds together.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  auto module = CreateNewVerifiedModule();

  // Fused computation: simply repackages its two parameters into a tuple.
  HloComputation::Builder fusion_builder("fusion");
  HloInstruction* param_a = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* param_b = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  fusion_builder.AddInstruction(
      HloInstruction::CreateTuple({param_a, param_b}));
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  // Entry computation: feed p0 into both fusion operands, then add the two
  // tuple elements of the fusion result.
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion_computation));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion, 1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, fusion, gte0, gte1, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
}
3209
TEST_P(MemorySpaceAssignmentTest, TupleInput) {
  // Exercises a fusion whose single operand is a tuple; the fused computation
  // unpacks the tuple elements and adds them.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  auto module = CreateNewVerifiedModule();

  // Fused computation: add the two elements of the tuple parameter.
  HloComputation::Builder fusion_builder("fusion");
  HloInstruction* fused_param = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* lhs = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fused_param, 0));
  HloInstruction* rhs = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fused_param, 1));
  fusion_builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  // Entry computation: negate each parameter, tuple the results, and feed the
  // tuple into the fusion.
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* neg0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* neg1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({neg0, neg1}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, neg0, neg1, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
}
3249
TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
  // A tuple-producing fusion (fusion0) feeds a tuple-consuming fusion
  // (fusion1). A long chain of negates in between separates the two in the
  // schedule; the test expects both elements of fusion0's output to be
  // async-copied into alternate memory right before fusion1 consumes them.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  auto module = CreateNewVerifiedModule();

  // fusion0: packages its two parameters into a tuple.
  HloComputation::Builder fusion0_builder("fusion0");
  HloInstruction* f0_param0 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* f0_param1 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({f0_param0, f0_param1}));
  HloComputation* fusion0_computation =
      module->AddEmbeddedComputation(fusion0_builder.Build());

  // fusion1: unpacks its tuple parameter and adds the two elements.
  HloComputation::Builder fusion1_builder("fusion1");
  HloInstruction* f1_param = fusion1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* f1_lhs = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, f1_param, 0));
  HloInstruction* f1_rhs = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, f1_param, 1));
  fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, f1_lhs, f1_rhs));
  HloComputation* fusion1_computation =
      module->AddEmbeddedComputation(fusion1_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
      tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion0_computation));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion0, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion0, 1));
  // Chain of seven negates rooted at p0 to push fusion1 far away from fusion0
  // in the schedule.
  std::vector<HloInstruction*> negates;
  negates.push_back(builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)));
  while (negates.size() < 7) {
    negates.push_back(builder.AddInstruction(HloInstruction::CreateUnary(
        shape, HloOpcode::kNegate, negates.back())));
  }
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
  HloInstruction* add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, add0, negates.back()));
  HloInstruction* fusion1 = builder.AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
                                   {fusion0}, fusion1_computation));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add1, fusion1));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> sequence = {p0, fusion0, element0, element1};
  sequence.insert(sequence.end(), negates.begin(), negates.end());
  sequence.insert(sequence.end(), {add0, add1, fusion1, mul});
  schedule.set_sequence(computation, sequence);
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 5);
  // Both tuple elements should be prefetched into alternate memory before
  // fusion1 consumes them.
  EXPECT_THAT(fusion1,
              op::Fusion(op::Tuple(
                  op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                op::GetTupleElement(op::Fusion(), 0)),
                  op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                op::GetTupleElement(op::Fusion(), 1)))));
}
3328
TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
  // Like TupleToTuple1, but fusion0's output is a nested tuple. The test
  // expects every leaf of the nested tuple to be async-copied into alternate
  // memory before fusion1 consumes it.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({shape, tuple_shape});
  auto module = CreateNewVerifiedModule();

  // fusion0: produces the nested tuple (p0, (p0, p1)).
  HloComputation::Builder fusion0_builder("fusion0");
  HloInstruction* f0_param0 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* f0_param1 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* f0_inner_tuple = fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({f0_param0, f0_param1}));
  fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({f0_param0, f0_inner_tuple}));
  HloComputation* fusion0_computation =
      module->AddEmbeddedComputation(fusion0_builder.Build());

  // fusion1: adds the outer leaf and the second inner leaf of its nested
  // tuple parameter.
  HloComputation::Builder fusion1_builder("fusion1");
  HloInstruction* f1_param = fusion1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, nested_tuple_shape, "p"));
  HloInstruction* f1_outer_leaf = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, f1_param, 0));
  HloInstruction* f1_inner = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(tuple_shape, f1_param, 1));
  HloInstruction* f1_inner_leaf = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, f1_inner, 1));
  fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, f1_outer_leaf, f1_inner_leaf));
  HloComputation* fusion1_computation =
      module->AddEmbeddedComputation(fusion1_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
      nested_tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion0_computation));
  // Chain of seven negates to separate fusion0 and fusion1 in the schedule.
  std::vector<HloInstruction*> negates;
  negates.push_back(builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)));
  while (negates.size() < 7) {
    negates.push_back(builder.AddInstruction(HloInstruction::CreateUnary(
        shape, HloOpcode::kNegate, negates.back())));
  }
  HloInstruction* fusion1 = builder.AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
                                   {fusion0}, fusion1_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> sequence = {p0, fusion0};
  sequence.insert(sequence.end(), negates.begin(), negates.end());
  sequence.push_back(fusion1);
  schedule.set_sequence(computation, sequence);
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 5);

  EXPECT_THAT(
      fusion1,
      op::Fusion(op::Tuple(
          op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                        op::GetTupleElement(op::Fusion(), 0)),
          op::Tuple(
              op::AsyncCopy(
                  kAlternateMemorySpace, kDefaultMemorySpace,
                  op::GetTupleElement(op::GetTupleElement(op::Fusion(), 1), 0)),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::GetTupleElement(
                                op::GetTupleElement(op::Fusion(), 1), 1))))));
}
3408
TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) {
  // Back-to-back tuple fusions with nothing scheduled in between: fusion1
  // should consume fusion0's output directly, with no async copies inserted.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  auto module = CreateNewVerifiedModule();

  // fusion0: packages its two parameters into a tuple.
  HloComputation::Builder fusion0_builder("fusion0");
  HloInstruction* f0_param0 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* f0_param1 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({f0_param0, f0_param1}));
  HloComputation* fusion0_computation =
      module->AddEmbeddedComputation(fusion0_builder.Build());

  // fusion1: unpacks its tuple parameter and adds the two elements.
  HloComputation::Builder fusion1_builder("fusion1");
  HloInstruction* f1_param = fusion1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* f1_lhs = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, f1_param, 0));
  HloInstruction* f1_rhs = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, f1_param, 1));
  fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, f1_lhs, f1_rhs));
  HloComputation* fusion1_computation =
      module->AddEmbeddedComputation(fusion1_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
      tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion0_computation));
  HloInstruction* fusion1 = builder.AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
                                   {fusion0}, fusion1_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, fusion0, fusion1});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
  // No copies expected between the two fusions.
  EXPECT_THAT(fusion1, op::Fusion(op::Fusion()));
}
3455
TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
  // Verifies that parameters aliased with module outputs stay in the default
  // memory space after memory space assignment.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  HloInstruction* p = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 0));
  // Chain of seven negates rooted at p0.
  std::vector<HloInstruction*> negates;
  negates.push_back(builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)));
  while (negates.size() < 7) {
    negates.push_back(builder.AddInstruction(HloInstruction::CreateUnary(
        shape, HloOpcode::kNegate, negates.back())));
  }
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, negates.back(), p1));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({p0, add}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> sequence = {p, p0};
  sequence.insert(sequence.end(), negates.begin(), negates.end());
  sequence.insert(sequence.end(), {p1, add, negate7, tuple});
  schedule.set_sequence(computation, sequence);
  TF_CHECK_OK(module->set_schedule(schedule));

  // Make input {0} alias with output {0} and input {1} alias with output {1}.
  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));

  AssignMemorySpace(module.get());

  // The aliased inputs must remain in the default memory space.
  EXPECT_EQ(p->shape().tuple_shapes(0).layout().memory_space(),
            kDefaultMemorySpace);
  EXPECT_EQ(p->shape().tuple_shapes(1).layout().memory_space(),
            kDefaultMemorySpace);
}
3508
TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
  // This is mostly a smoke test since it's difficult and brittle to work out
  // the cost of the HLO instructions.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Chain of seven negates rooted at p0, then an add with p1.
  std::vector<HloInstruction*> negates;
  negates.push_back(builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)));
  while (negates.size() < 7) {
    negates.push_back(builder.AddInstruction(HloInstruction::CreateUnary(
        shape, HloOpcode::kNegate, negates.back())));
  }
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, negates.back(), p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> sequence = {p0, p1};
  sequence.insert(sequence.end(), negates.begin(), negates.end());
  sequence.push_back(add);
  schedule.set_sequence(computation, sequence);
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpaceUsingCostAnalysis(module.get());
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  // Every negate in the chain should end up in the alternate memory space.
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  for (HloInstruction* negate : negates) {
    EXPECT_THAT(negate, op::ShapeWithLayout(shape_in_alternate_mem));
  }
}
3560
TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
  // This test is carefully crafted to force only negates to be allocated to
  // the alternate memory. The graph consists of interleaving negate and tanh
  // operations:
  //
  //        +------+       +-------+      +-----
  //       /        \     /         \    /
  //  negate  tanh   negate  tanh   negate  tanh
  //          \       /      \       /
  //           +--------+      +---------+
  //
  // The alternate memory is sized to fit only two f32[4,3] tensors at a time.
  // Also, transcendentals are made to be lower bandwidth than FLOPs. So, the
  // MemoryBoundednessBufferIntervalCompare should prioritize the negates,
  // which are more memory bound.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Build five tanhs chained from p0 and five negates chained from p1,
  // interleaving the instruction creation order (tanh first at each step).
  std::vector<HloInstruction*> tanhs;
  std::vector<HloInstruction*> negates;
  tanhs.push_back(builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)));
  negates.push_back(builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1)));
  for (int i = 1; i < 5; ++i) {
    tanhs.push_back(builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanhs.back())));
    negates.push_back(builder.AddInstruction(HloInstruction::CreateUnary(
        shape, HloOpcode::kNegate, negates.back())));
  }
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({tanhs.back(), negates.back()}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> sequence = {p0, p1};
  for (int i = 0; i < 5; ++i) {
    sequence.push_back(tanhs[i]);
    sequence.push_back(negates[i]);
  }
  sequence.push_back(tuple);
  schedule.set_sequence(computation, sequence);
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpaceUsingCostAnalysis(module.get());
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {4, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kDefaultMemorySpace);
  // Expect only negates to be in alternate memory space. Not all might fit
  // but make sure at least one does.
  int64_t num_negates_in_alternate_mem = absl::c_count_if(
      negates, [&](const HloInstruction* instruction) {
        return instruction->shape().layout().memory_space() ==
               kAlternateMemorySpace;
      });
  EXPECT_GE(num_negates_in_alternate_mem, 1);
  // All tanhs should remain in the default memory space.
  for (HloInstruction* tanh : tanhs) {
    EXPECT_THAT(tanh, op::ShapeWithLayout(shape_in_default_mem));
  }
}
3638
TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
  // Tests a simple while loop whose parameters are aliased with the output
  // buffers; both the entry parameter and the while result are expected to
  // remain in the default memory space.
  Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
  Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
  Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1});
  auto module = CreateNewVerifiedModule("SimpleWhile");
  HloSchedule schedule(module.get());

  // A simple compare-to-limit (x < 4) computation for a While.
  //
  // condition:
  //   const4[s32] -----------------------------------\
  //                                                   \
  //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
  //
  HloComputation* cond_computation;
  {
    auto builder = HloComputation::Builder("WhileCond");
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
    auto index = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
    auto compare = builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
                                      const4, ComparisonDirection::kLt));
    cond_computation = module->AddEmbeddedComputation(builder.Build());
    schedule.set_sequence(cond_computation, {const4, param, index, compare});
  }

  // Builds a simple body computation for a While.
  //
  // body:
  //   constv[f32[1]] --------------------------------------\
  //                                                         \
  //                  /--- get-tuple-elementv[1] --- addv ---\
  //   param[(s32,f32[1])] ---|                              tuple
  //                  \--- get-tuple-elementc[0] --- addc ---/
  //                                                         /
  //   const1[s32] -----------------------------------------/
  //
  HloComputation* body_computation;
  {
    auto builder = HloComputation::Builder("WhileBody");
    auto const1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
    auto constv = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.1f})));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
    auto indexc = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
    auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
        indexc->shape(), HloOpcode::kAdd, indexc, const1));
    auto indexv = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
    auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
        constv->shape(), HloOpcode::kAdd, indexv, constv));
    auto tuple =
        builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
    body_computation = module->AddEmbeddedComputation(builder.Build());
    schedule.set_sequence(body_computation, {const1, constv, param, indexc,
                                             addc, indexv, addv, tuple});
  }

  // This tests a simple while loop where the parameters are aliased with the
  // output buffers.
  auto builder = HloComputation::Builder("SimpleWhile");
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, t_s32_f32v1, "param"));
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32, param, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32v1, param, 1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
  auto while0 = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v1, cond_computation, body_computation, tuple));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);

  // Ensure all parameters and while are placed in default memory.
  // (NOTE: an unused f32[4,6] shape local previously declared here was dead
  // code and has been removed; only the shapes below are asserted.)
  Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      xla::S32, {},
      /*minor_to_major=*/{}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kDefaultMemorySpace);
  Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {1},
      /*minor_to_major=*/{0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kDefaultMemorySpace);
  Shape t_s32_f32v1_in_default_mem =
      ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem});
  EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
  EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
}
3743
TEST_P(MemorySpaceAssignmentTest,EvictionsShouldntBeDelayed)3744 TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) {
3745 // This test reproduces an eviction scheduling bug where evictions to default
3746 // memory can happen later than intended, causing memory corruption. This test
3747 // is a variant of MemoryBoundednessBufferIntervalCompare but uses f32[4,3]
3748 // tensors instead, so at most two tensors should fit in the alternate memory
3749 // space at a given time. We have a number of redundant operations
3750 // (tanh_redundant ops) that do not have users. The bug was due to
3751 // SimplifyGraph removing dead instructions, and removing them from the
3752 // schedule. However, the CopyStart/CopyDone insertion relies on the schedule
3753 // indexes, so they could be inserted too late.
3754 HloComputation::Builder builder(TestName());
3755 Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
3756 HloInstruction* p0 =
3757 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3758 HloInstruction* tanh0 = builder.AddInstruction(
3759 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3760 HloInstruction* tanh_redundant0 = builder.AddInstruction(
3761 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3762 HloInstruction* tanh_redundant1 = builder.AddInstruction(
3763 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3764 HloInstruction* tanh_redundant2 = builder.AddInstruction(
3765 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3766 HloInstruction* tanh_redundant3 = builder.AddInstruction(
3767 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3768 HloInstruction* tanh_redundant4 = builder.AddInstruction(
3769 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3770 HloInstruction* tanh_redundant5 = builder.AddInstruction(
3771 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3772 HloInstruction* tanh_redundant6 = builder.AddInstruction(
3773 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3774 HloInstruction* negate0 = builder.AddInstruction(
3775 HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0));
3776 HloInstruction* tanh1 = builder.AddInstruction(
3777 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0));
3778 HloInstruction* negate1 = builder.AddInstruction(
3779 HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3780 HloInstruction* tanh2 = builder.AddInstruction(
3781 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
3782 HloInstruction* negate2 = builder.AddInstruction(
3783 HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3784 HloInstruction* tanh3 = builder.AddInstruction(
3785 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
3786 HloInstruction* negate3 = builder.AddInstruction(
3787 HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3788 HloInstruction* tuple = builder.AddInstruction(
3789 HloInstruction::CreateTuple({tanh3, negate3, tanh0}));
3790
3791 auto module = CreateNewVerifiedModule();
3792 HloComputation* computation = module->AddEntryComputation(builder.Build());
3793
3794 HloSchedule schedule(module.get());
3795 schedule.set_sequence(
3796 computation,
3797 {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2,
3798 tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6,
3799 negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple});
3800 TF_CHECK_OK(module->set_schedule(schedule));
3801
3802 AssignMemorySpaceUsingCostAnalysis(module.get());
3803
3804 TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
3805 HloAliasAnalysis::Run(module.get()));
3806 TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range,
3807 HloLiveRange::Run(module->schedule(), *alias_analysis,
3808 module->entry_computation()));
3809
3810 std::vector<int> num_live_buffers_in_alternate_mem(
3811 hlo_live_range->flattened_instruction_sequence().size() + 1, 0);
3812
3813 // Go through each value and for those that are allocated in the alternate
3814 // memory space, increment (inclusive) num_live_buffers_in_alternate_mem for
3815 // every time step that they are live.
3816 for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
3817 const Shape& shape = value->shape();
3818 if (!shape.has_layout() ||
3819 shape.layout().memory_space() == kDefaultMemorySpace) {
3820 continue;
3821 }
3822
3823 HloLiveRange::TimeBound time_bound =
3824 hlo_live_range->buffer_live_ranges().at(value);
3825 for (int i = time_bound.start; i <= time_bound.end; ++i) {
3826 ++num_live_buffers_in_alternate_mem[i];
3827 }
3828 }
3829
3830 // The test memory can at most hold two f32[4,3] buffers at a time. If there
3831 // is more than that, it means we have memory corruption.
3832 for (int i = 0; i < num_live_buffers_in_alternate_mem.size(); ++i) {
3833 EXPECT_LE(num_live_buffers_in_alternate_mem[i], 2);
3834 }
3835 }
3836
TEST_P(MemorySpaceAssignmentTest,
       InputOutputsInAlternateMemShouldntBeAssigned) {
  // When input/outputs are marked to be in the alternate memory (e.g.
  // go/tpu-fast-mem-inference), do not allocate those and assume they will
  // live in the alternate memory for the entire computation. The
  // BufferAssignment pass, which is run after this, will allocate those
  // buffers.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  // p0 is in the default memory space.
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  // p1 is pre-placed in the alternate memory space via its layout.
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape_in_alternate_mem, "p1"));
  // Build a chain of 7 negates off of p0 (negate0 ... negate6).
  std::vector<HloInstruction*> negates;
  HloInstruction* chain = p0;
  for (int i = 0; i < 7; ++i) {
    chain = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, chain));
    negates.push_back(chain);
  }
  // add produces its result directly in the alternate memory space.
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_in_alternate_mem, HloOpcode::kAdd, negates[6], p1));
  // Index {0} of the root instruction is in the alternate memory space, index
  // {1} is in the default memory space.
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({add, negates[5]}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Schedule: p0, p1, negate0..negate6, add, tuple.
  HloSchedule schedule(module.get());
  std::vector<HloInstruction*> sequence = {p0, p1};
  sequence.insert(sequence.end(), negates.begin(), negates.end());
  sequence.push_back(add);
  sequence.push_back(tuple);
  schedule.set_sequence(computation, sequence);
  TF_CHECK_OK(module->set_schedule(schedule));

  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
    return true;
  };
  std::unique_ptr<PresetAssignments> preset_assignments = AssignMemorySpace(
      module.get(),
      /*max_outstanding_async_copies=*/-1,
      /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2, options);

  // Ensure that p1 is in the alternate memory and add, which has p1 as an
  // operand, has a direct dependency to p1 (no CopyStart/CopyDone).
  EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1)));
  // Make sure add is still in the alternate memory space.
  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));

  // Check the preset assignments and ensure the inputs/outputs in the
  // alternate memory space aren't in the preset assignments. Inputs/outputs
  // in the alternate memory space are left to BufferAssignment to be
  // allocated.
  for (const auto& position_and_chunk : preset_assignments->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    EXPECT_NE(position.instruction, p1);
    EXPECT_NE(position.instruction, add);
  }
}
3913
TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
  // Tests a memory corruption bug where the allocated chunk overlaps with a
  // pending chunk. To test this, we provide a new buffer interval compare where
  // we prioritize the allocation of sine, cosine, and tanh to create the
  // situation:
  //
  // Max memory
  //  -------------------------------------------
  //      +------------+
  //      |     b      |
  //      +------------+
  //  +-------+
  //  |       |
  //  |       |
  //  |   a   |
  //  |       |                 +------------+
  //  |       |                 |     n      |
  //  +-------+                 +------------+
  //  -------------------------------------------
  // Min memory          time ->
  //
  //
  // Then allocating for buffer d, we have these two prefetch buffers
  // overlapping:
  //
  // Max memory
  //  -------------------------------------------
  //      +------------+        +----------+
  //      |     b      |        | prefetch |
  //      +------------+        |  for o   |
  //  +-------+         +---------+        |
  //  |       |         |         |        |
  //  |       |         |         |        |
  //  |   a   |         |    +----|-----+
  //  |       |         | prefetch| +------------+
  //  |       |         |  for m  | |     n      |
  //  +-------+         +---------+ +------------+
  //  -------------------------------------------
  // Min memory          time ->
  //
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY %Entry {
    %param0 = f32[8,3] parameter(0)
    %param1 = f32[2,4] parameter(1)
    %a = f32[8,3] sine(%param0)
    %b = f32[2,4] cosine(%param1)
    %d = f32[8,3] tanh(%a)
    %c = f32[8,3] negate(%a)
    %e = f32[2,4] negate(%b)
    %f = f32[2,4] negate(%e)
    %g = f32[2,4] negate(%f)
    %h = f32[2,4] negate(%g)
    %i = f32[2,4] negate(%h)
    %j = f32[2,4] negate(%i)
    %k = f32[2,4] negate(%j)
    %l = f32[2,4] negate(%k)
    %m = f32[8,3] negate(%d)
    %n = f32[2,4] sine(%l)
    %o = f32[8,3] negate(%d)
    %p = f32[2,4] negate(%n)
    %q = f32[8,3] negate(%m)
    ROOT %tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(%p, %q, %o)
  }
  )";

  // Custom buffer interval comparator: allocate sine results first, then
  // cosine, then tanh, so the chunk layout matches the diagrams above.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        auto get_opcode_priority = [](const HloOpcode& opcode) {
          switch (opcode) {
            case HloOpcode::kSin:
              return 0;
            case HloOpcode::kCos:
              return 1;
            case HloOpcode::kTanh:
              return 2;
            default:
              return 3;
          }
        };

        return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
               get_opcode_priority(b.buffer->defining_instruction()->opcode());
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Instruction-count-based prefetch window of [2, 10] instructions before
  // each use.
  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  // Should complete without corrupting memory (options.verify would catch
  // overlapping chunks).
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker);
}
4007
TEST_P(MemorySpaceAssignmentTest, WhileAliasedArgumentRequiredAssignmentBug) {
  // Tests an overly pessimistic assertion when the same HloValue is passed
  // multiple times to a while HLO. We already handle this case that the two
  // arguments must alias and get the same allocation in AllocateSegment so the
  // assertion isn't necessary.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  while_condition {
    param1 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
    ROOT cond = pred[] constant(true)
  }

  while_body {
    param2 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
    gte2 = f32[2,4] get-tuple-element(param2), index=0
    gte3 = f32[2,4] get-tuple-element(param2), index=1
    gte4 = f32[2,4] get-tuple-element(param2), index=2
    add = f32[2,4] add(gte2, gte3)
    ROOT tuple2 = (f32[2,4], f32[2,4], f32[2,4]) tuple(add, gte3, gte4)
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(param0)
    tuple = (f32[2,4], f32[2,4], f32[2,4]) tuple(a, b, b)
    while = (f32[2,4], f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
    gte1 = f32[2,4] get-tuple-element(while), index=0
    gte2 = f32[2,4] get-tuple-element(while), index=1
    ROOT root = f32[2,4] add(gte1, gte2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Note that `b` appears twice in the while operand tuple above. Running
  // memory space assignment should not crash or trip the aliasing assertion.
  AssignMemorySpace(module.get());
}
4045
TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) {
  // When we have a disallowed use (in this case tanh), we aren't allowed to
  // allocate this use in alternate memory. However, if we have another use
  // after this on the same buffer (o), this use may refer to "a" instead of
  // the evicted value, which is illegal because "a" will be allocated in the
  // alternate memory space.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY Entry {
    param0 = f32[8,3] parameter(0)
    param1 = f32[2,4] parameter(1)
    a = f32[8,3] cosine(param0)
    b = f32[2,4] negate(param1)
    d = f32[8,3] negate(a)
    c = f32[2,4] negate(b)
    e = f32[2,4] negate(c)
    f = f32[8,3] tanh(a)
    g = f32[2,4] negate(e)
    h = f32[2,4] negate(g)
    i = f32[2,4] negate(h)
    j = f32[2,4] negate(i)
    k = f32[2,4] negate(j)
    l = f32[2,4] negate(k)
    m = f32[2,4] negate(l)
    n = f32[2,4] sine(m)
    o = f32[8,3] negate(a)
    p = f32[2,4] negate(n)
    q = f32[8,3] add(o, f)
    r = f32[8,3] add(q, d)
    ROOT tuple = (f32[2,4], f32[8,3]) tuple(p, r)
  }
  )";

  // Prioritize allocating sine results first, then cosine, then tanh;
  // everything else comes last.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& lhs,
         const MemorySpaceAssignment::BufferInterval& rhs) {
        auto priority = [](HloOpcode opcode) {
          if (opcode == HloOpcode::kSin) {
            return 0;
          }
          if (opcode == HloOpcode::kCos) {
            return 1;
          }
          if (opcode == HloOpcode::kTanh) {
            return 2;
          }
          return 3;
        };
        int lhs_priority =
            priority(lhs.buffer->defining_instruction()->opcode());
        int rhs_priority =
            priority(rhs.buffer->defining_instruction()->opcode());
        return lhs_priority < rhs_priority;
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  // Tanh uses are not allowed to live in the alternate memory space.
  options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
    return use.instruction->opcode() != HloOpcode::kTanh;
  };
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker,
                    options);
}
4114
TEST_P(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) {
  // Test for situations where we disallow a use (tanh in this case) in the
  // alternate memory space and there is a subsequent use that also requires the
  // buffer to be in the default memory space. In this case, the allocation in
  // the default memory space might not be the very last one, so we need to
  // search the allocation sequence and find the one in the default memory
  // space.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=3
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = f32[3]{0} get-tuple-element(p0), index=2
    gte3 = pred[] get-tuple-element(p0), index=3
    add = f32[3]{0} add(gte0, gte0)
    negate0 = f32[3]{0} negate(add)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    negate15 = f32[3]{0} negate(gte2)
    tanh = f32[3]{0} tanh(gte2)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(negate14, tanh, gte2, gte3)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    ROOT gte = f32[3]{0} get-tuple-element(while), index=2
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  // Tanh uses are not allowed to live in the alternate memory space.
  options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
    return use.instruction->opcode() != HloOpcode::kTanh;
  };
  // Should complete without crashing; options.verify checks the result.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
                    options);
}
4181
TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionInWhile) {
  // The while body reads gte1 (via tanh) but passes it through unmodified.
  // Expect that the buffer stays in alternate memory inside the loop and is
  // prefetched at the end of each iteration rather than being re-evicted,
  // since an up-to-date copy already exists in default memory (verified in
  // the expectations below).
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    tanh = f32[3]{0} tanh(gte1)
    gte2 = pred[] get-tuple-element(p0), index=2
    negate0 = f32[3]{0} negate(gte0)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    add = f32[3]{0} add(negate14, tanh)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    gte = f32[3]{0} get-tuple-element(while), index=1
    ROOT negate = f32[3]{0} negate(gte)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // GetParam() selects the alternate-memory-enabled configuration.
  if (GetParam()) {
    // Expect that while{1} is allocated to alternate memory space. Also expect
    // that this buffer is prefetched at the end of the while loop body but is
    // never evicted (since it has a copy in the default memory space).
    const HloInstruction* while_instr = FindInstruction(module.get(), "while");
    EXPECT_EQ(while_instr->shape().tuple_shapes(1).layout().memory_space(),
              kAlternateMemorySpace);
    const HloInstruction* gte1 = FindInstruction(module.get(), "gte1");
    EXPECT_EQ(gte1->user_count(), 1);
    EXPECT_EQ(gte1->users()[0]->opcode(), HloOpcode::kTanh);
    const HloInstruction* while_root =
        while_instr->while_body()->root_instruction();
    EXPECT_THAT(while_root->operand(1),
                op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                              op::GetTupleElement(op::Parameter(0))));
  }
}
4248
TEST_P(MemorySpaceAssignmentTest,
       RedundantEvictionEliminationShouldntAddRedundantParam) {
  // Check that if there wasn't an eviction in the while loop, we don't add the
  // buffer in default memory as an additional parameter to the while loop.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    tanh = f32[3]{0} tanh(gte1)
    gte2 = pred[] get-tuple-element(p0), index=2
    negate0 = f32[3]{0} negate(gte0)
    negate1 = f32[3]{0} negate(negate0)
    add = f32[3]{0} add(negate1, tanh)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    gte = f32[3]{0} get-tuple-element(while), index=1
    ROOT negate = f32[3]{0} negate(gte)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Run MSA; the while tuple shape must remain unchanged (checked below).
  AssignMemorySpace(module.get());

  // Expect that while tuple shape contains 3 elements like the original.
  const HloInstruction* while_instr = FindInstruction(module.get(), "while");
  EXPECT_EQ(while_instr->shape().tuple_shapes_size(), 3);
}
4292
TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionInNestedWhile) {
  // Same as AvoidRedundantEvictionInWhile, but with the eviction-avoidance
  // candidate (tuple element 1) threaded through a nested while: while1's
  // body immediately calls while2. The expectations below verify the buffer
  // stays in alternate memory at both while levels without a redundant
  // eviction.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond2 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body2 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    tanh = f32[3]{0} tanh(gte1)
    gte2 = pred[] get-tuple-element(p0), index=2
    negate0 = f32[3]{0} negate(gte0)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    add = f32[3]{0} add(negate14, tanh)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
  }

  while_cond1 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body1 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(p0), condition=while_cond2, body=while_body2
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
    while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond1, body=while_body1
    gte = f32[3]{0} get-tuple-element(while1), index=1
    ROOT negate = f32[3]{0} negate(gte)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // GetParam() selects the alternate-memory-enabled configuration.
  if (GetParam()) {
    // Expect that while1{1} and while2{1} are allocated to alternate memory
    // space. Also expect that this buffer is prefetched at the end of the while
    // loop body but is never evicted (since it has a copy in the default memory
    // space).
    const HloInstruction* while1_instr =
        FindInstruction(module.get(), "while1");
    EXPECT_EQ(while1_instr->shape().tuple_shapes(1).layout().memory_space(),
              kAlternateMemorySpace);
    const HloInstruction* while2_instr =
        FindInstruction(module.get(), "while2");
    EXPECT_EQ(while2_instr->shape().tuple_shapes(1).layout().memory_space(),
              kAlternateMemorySpace);
    const HloInstruction* gte1 = FindInstruction(module.get(), "gte1");
    EXPECT_EQ(gte1->user_count(), 1);
    EXPECT_EQ(gte1->users()[0]->opcode(), HloOpcode::kTanh);
    const HloInstruction* while_root =
        while2_instr->while_body()->root_instruction();
    EXPECT_THAT(while_root->operand(1),
                op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                              op::GetTupleElement(op::Parameter(0))));
  }
}
4375
TEST_P(MemorySpaceAssignmentTest, RedundantEvictionEliminationBug) {
  // Regression test: redundant eviction elimination must not kick in when the
  // candidate tuple element (while{1}) is actually updated inside the loop
  // body (here, `negate` feeds tuple element 1). See expectations below.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    tanh = f32[3]{0} tanh(gte1)
    gte2 = pred[] get-tuple-element(p0), index=2
    negate0 = f32[3]{0} negate(gte0)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    add0 = f32[3]{0} add(negate14, tanh)
    add1 = f32[3]{0} add(add0, gte1)
    negate = f32[3]{0} negate(add1)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add1, negate, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    gte = f32[3]{0} get-tuple-element(while), index=1
    ROOT negate = f32[3]{0} negate(gte)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect that redundant eviction elimination doesn't kick in because
  // while{1} is updated within the body.
  const HloInstruction* while_instr = FindInstruction(module.get(), "while");
  EXPECT_EQ(while_instr->shape().tuple_shapes_size(), 3);
  if (GetParam()) {
    EXPECT_EQ(while_instr->shape().tuple_shapes(1).layout().memory_space(),
              kAlternateMemorySpace);
    const HloInstruction* gte1 = FindInstruction(module.get(), "gte1");
    EXPECT_EQ(gte1->user_count(), 2);
    // One of gte1's users must be a CopyStart, i.e. an eviction did happen.
    EXPECT_NE(absl::c_find_if(gte1->users(),
                              [](const HloInstruction* use) {
                                return use->opcode() == HloOpcode::kCopyStart;
                              }),
              gte1->users().end());
  }
}
4443
TEST_P(MemorySpaceAssignmentTest, RedundantEvictionEliminationInChainedWhile) {
  // Check against a bug where a while HLO feeding to another while HLO can
  // cause a crash if we perform a redundant eviction elimination to the first
  // while but not the other (while operand/parameter shapes would mismatch).
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond1 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body1 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    tanh = f32[3]{0} tanh(gte1)
    gte2 = pred[] get-tuple-element(p0), index=2
    negate0 = f32[3]{0} negate(gte0)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    add = f32[3]{0} add(negate14, tanh)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
  }

  while_cond2 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body2 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    tanh = f32[3]{0} tanh(gte1)
    gte2 = pred[] get-tuple-element(p0), index=2
    negate0 = f32[3]{0} negate(gte0)
    add = f32[3]{0} add(negate0, tanh)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
    while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond1, body=while_body1
    while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(while1), condition=while_cond2, body=while_body2
    gte = f32[3]{0} get-tuple-element(while2), index=1
    ROOT negate = f32[3]{0} negate(gte)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // GetParam() selects the alternate-memory-enabled configuration.
  if (GetParam()) {
    // Expect that while1 has one more value than while2 in its shape.
    EXPECT_EQ(
        FindInstruction(module.get(), "while1")->shape().tuple_shapes_size(),
        FindInstruction(module.get(), "while2")->shape().tuple_shapes_size() +
            1);
  }
}
4521
TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile) {
  // `copy` passes unmodified through the while (tuple element 0) and is only
  // used again by the root add, far after the loop. Expect that the use is
  // fed by a prefetch of the original copy from default memory rather than by
  // a fresh eviction after the while (checked below for the alternate-memory
  // configuration).
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = pred[] get-tuple-element(p0), index=2
    add = f32[3]{0} add(gte0, gte1)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    negate0 = f32[3]{0} negate(p0)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    gte0 = f32[3]{0} get-tuple-element(while), index=0
    gte1 = f32[3]{0} get-tuple-element(while), index=1
    negate20 = f32[3]{0} negate(gte1)
    negate21 = f32[3]{0} negate(negate20)
    negate22 = f32[3]{0} negate(negate21)
    negate23 = f32[3]{0} negate(negate22)
    negate24 = f32[3]{0} negate(negate23)
    negate25 = f32[3]{0} negate(negate24)
    negate26 = f32[3]{0} negate(negate25)
    negate27 = f32[3]{0} negate(negate26)
    negate28 = f32[3]{0} negate(negate27)
    negate29 = f32[3]{0} negate(negate28)
    negate30 = f32[3]{0} negate(negate29)
    negate31 = f32[3]{0} negate(negate30)
    negate32 = f32[3]{0} negate(negate31)
    negate33 = f32[3]{0} negate(negate32)
    negate34 = f32[3]{0} negate(negate33)
    ROOT add = f32[3]{0} add(negate34, gte0)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // GetParam() selects the alternate-memory-enabled configuration.
  if (GetParam()) {
    // The root's second operand (the value of `copy` coming out of the while)
    // should be a prefetch (default -> alternate) directly from `copy`.
    EXPECT_THAT(
        module->entry_computation()->root_instruction()->operand(1),
        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, op::Copy()));
  }
}
4592
TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile2) {
  // Variant with two back-to-back while loops; element 0 flows through both
  // loops unmodified (each body re-tuples its gte0).
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond1 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body1 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = pred[] get-tuple-element(p0), index=2
    add = f32[3]{0} add(gte0, gte1)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
  }

  while_cond2 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body2 {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = pred[] get-tuple-element(p0), index=2
    add = f32[3]{0} add(gte0, gte1)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple1 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
    while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple1), condition=while_cond1, body=while_body1
    gte0 = f32[3]{0} get-tuple-element(while1), index=0
    gte1 = f32[3]{0} get-tuple-element(while1), index=1
    negate0 = f32[3]{0} negate(gte1)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    tuple2 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, negate14, p1)
    while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple2), condition=while_cond2, body=while_body2
    gte2 = f32[3]{0} get-tuple-element(while2), index=0
    gte3 = f32[3]{0} get-tuple-element(while2), index=1
    negate20 = f32[3]{0} negate(gte3)
    negate21 = f32[3]{0} negate(negate20)
    negate22 = f32[3]{0} negate(negate21)
    negate23 = f32[3]{0} negate(negate22)
    negate24 = f32[3]{0} negate(negate23)
    negate25 = f32[3]{0} negate(negate24)
    negate26 = f32[3]{0} negate(negate25)
    negate27 = f32[3]{0} negate(negate26)
    negate28 = f32[3]{0} negate(negate27)
    negate29 = f32[3]{0} negate(negate28)
    negate30 = f32[3]{0} negate(negate29)
    negate31 = f32[3]{0} negate(negate30)
    negate32 = f32[3]{0} negate(negate31)
    negate33 = f32[3]{0} negate(negate32)
    negate34 = f32[3]{0} negate(negate33)
    ROOT add = f32[3]{0} add(negate34, gte2)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Here the root's operand is expected to be an eviction
  // (alternate -> default) of the get-tuple-element of a while, followed by a
  // prefetch (default -> alternate) back for the final use.
  if (GetParam()) {
    EXPECT_THAT(
        module->entry_computation()->root_instruction()->operand(1),
        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                      op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
                                    op::GetTupleElement(op::While()))));
  }
}
4683
TEST_P(MemorySpaceAssignmentTest,
       AfterWhileRedundantEarlierEvictionModifiedBuffer) {
  // Unlike AvoidRedundantEvictionAfterWhile, this while body writes a new
  // value into tuple element 0 (`negate(gte0)`), so any pre-loop copy of the
  // buffer in default memory is stale after the loop.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = pred[] get-tuple-element(p0), index=2
    add = f32[3]{0} add(gte0, gte1)
    negate = f32[3]{0} negate(gte0)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(negate, add, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    negate0 = f32[3]{0} negate(p0)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    gte0 = f32[3]{0} get-tuple-element(while), index=0
    gte1 = f32[3]{0} get-tuple-element(while), index=1
    negate20 = f32[3]{0} negate(gte1)
    negate21 = f32[3]{0} negate(negate20)
    negate22 = f32[3]{0} negate(negate21)
    negate23 = f32[3]{0} negate(negate22)
    negate24 = f32[3]{0} negate(negate23)
    negate25 = f32[3]{0} negate(negate24)
    negate26 = f32[3]{0} negate(negate25)
    negate27 = f32[3]{0} negate(negate26)
    negate28 = f32[3]{0} negate(negate27)
    negate29 = f32[3]{0} negate(negate28)
    negate30 = f32[3]{0} negate(negate29)
    negate31 = f32[3]{0} negate(negate30)
    negate32 = f32[3]{0} negate(negate31)
    negate33 = f32[3]{0} negate(negate32)
    negate34 = f32[3]{0} negate(negate33)
    ROOT add = f32[3]{0} add(negate34, gte0)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Because the buffer was modified inside the loop, a fresh eviction of the
  // while's output followed by a prefetch is expected (the earlier eviction
  // of `copy` cannot be reused).
  if (GetParam()) {
    EXPECT_THAT(
        module->entry_computation()->root_instruction()->operand(1),
        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                      op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
                                    op::GetTupleElement(op::While()))));
  }
}
4758
TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
  // Tests against a bug where the root of entry computation is a bitcast
  // instruction and it ends up getting an allocation in the alternate memory.
  absl::string_view hlo_string = R"(
  HloModule primitive_computation_gather.4, is_scheduled=true

  %while_body {
    %param.1 = (s32[], f32[3,3,3]) parameter(0)
    %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0
    %copy.6 = s32[] copy(s32[] %get-tuple-element.32)
    %constant.8 = s32[] constant(1)
    %add = s32[] add(s32[] %copy.6, s32[] %constant.8)
    %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1
    negate = f32[3,3,3] negate(get-tuple-element.35)
    ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate)
  }

  %while_cond {
    %param.0 = (s32[], f32[3,3,3]) parameter(0)
    %get-tuple-element = s32[] get-tuple-element(%param.0), index=0
    %constant.3 = s32[] constant(3)
    ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT
  }

  ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] {
    %constant.1 = s32[] constant(0)
    %copy.11 = s32[] copy(s32[] %constant.1)
    %constant = f32[] constant(0)
    %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={}
    %tuple.8 = (s32[], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast)
    %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body
    %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1
    ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // The entry root must live in the default memory space (or have no layout
  // at all); an alternate-memory root would reproduce the original bug.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(!root->shape().has_layout() ||
              root->shape().layout().memory_space() == kDefaultMemorySpace);
}
4803
TEST_P(MemorySpaceAssignmentTest, AsyncOpShortLiveRange) {
  // The collective-permute start/done pair spans only three intervening
  // negates, so both of its buffers have short live ranges.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    negate0 = bf16[4]{0} negate(param)
    collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
    negate1 = bf16[4]{0} negate(param)
    negate2 = bf16[4]{0} negate(negate1)
    negate3 = bf16[4]{0} negate(negate2)
    collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
    ROOT add = add(collective-permute-done, negate3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect both the source and destination buffers to get alternate memory
  // allocations. EXPECT_EQ (rather than EXPECT_TRUE on a comparison) reports
  // the actual memory-space values on failure.
  HloInstruction* collective_permute_start =
      module->entry_computation()->GetInstructionWithName(
          "collective-permute-start");
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(0).layout().memory_space(),
      kAlternateMemorySpace);
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(1).layout().memory_space(),
      kAlternateMemorySpace);
}
4838
TEST_P(MemorySpaceAssignmentTest, AsyncOpShortLiveRangeInputBufferConsumer) {
  // Same as AsyncOpShortLiveRange, except negate1 consumes negate0, giving
  // the collective-permute's input buffer an additional user.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    negate0 = bf16[4]{0} negate(param)
    collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
    negate1 = bf16[4]{0} negate(negate0)
    negate2 = bf16[4]{0} negate(negate1)
    negate3 = bf16[4]{0} negate(negate2)
    collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
    ROOT add = add(collective-permute-done, negate3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect only the destination buffer to get alternate memory allocation
  // because negate0 is also used by negate1. EXPECT_EQ gives diagnostic
  // values on failure, unlike EXPECT_TRUE on a boolean comparison.
  HloInstruction* collective_permute_start =
      module->entry_computation()->GetInstructionWithName(
          "collective-permute-start");
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(0).layout().memory_space(),
      kDefaultMemorySpace);
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(1).layout().memory_space(),
      kAlternateMemorySpace);
}
4873
TEST_P(MemorySpaceAssignmentTest, AsyncOpLongLiveRange) {
  // Thirteen negates between start and done stretch the live range of the
  // collective-permute buffers.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    negate0 = bf16[4]{0} negate(param)
    collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
    negate1 = bf16[4]{0} negate(param)
    negate2 = bf16[4]{0} negate(negate1)
    negate3 = bf16[4]{0} negate(negate2)
    negate4 = bf16[4]{0} negate(negate3)
    negate5 = bf16[4]{0} negate(negate4)
    negate6 = bf16[4]{0} negate(negate5)
    negate7 = bf16[4]{0} negate(negate6)
    negate8 = bf16[4]{0} negate(negate7)
    negate9 = bf16[4]{0} negate(negate8)
    negate10 = bf16[4]{0} negate(negate9)
    negate11 = bf16[4]{0} negate(negate10)
    negate12 = bf16[4]{0} negate(negate11)
    negate13 = bf16[4]{0} negate(negate12)
    collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
    ROOT add = add(collective-permute-done, negate13)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect none of the buffers to get alternate memory allocations because of
  // the long live range. EXPECT_EQ reports the offending values on failure.
  HloInstruction* collective_permute_start =
      module->entry_computation()->GetInstructionWithName(
          "collective-permute-start");
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(0).layout().memory_space(),
      kDefaultMemorySpace);
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(1).layout().memory_space(),
      kDefaultMemorySpace);
}
4918
TEST_P(MemorySpaceAssignmentTest, AsyncOpLongLiveRangeInputBufferConsumer) {
  // Long live range as in AsyncOpLongLiveRange, plus negate1 consumes negate0
  // so the input buffer also has another user.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    negate0 = bf16[4]{0} negate(param)
    collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
    negate1 = bf16[4]{0} negate(negate0)
    negate2 = bf16[4]{0} negate(negate1)
    negate3 = bf16[4]{0} negate(negate2)
    negate4 = bf16[4]{0} negate(negate3)
    negate5 = bf16[4]{0} negate(negate4)
    negate6 = bf16[4]{0} negate(negate5)
    negate7 = bf16[4]{0} negate(negate6)
    negate8 = bf16[4]{0} negate(negate7)
    negate9 = bf16[4]{0} negate(negate8)
    negate10 = bf16[4]{0} negate(negate9)
    negate11 = bf16[4]{0} negate(negate10)
    negate12 = bf16[4]{0} negate(negate11)
    negate13 = bf16[4]{0} negate(negate12)
    collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
    ROOT add = add(collective-permute-done, negate13)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect none of the buffers to get alternate memory allocations because of
  // the long live range and because negate0 is also used by negate1.
  // EXPECT_EQ reports the offending values on failure.
  HloInstruction* collective_permute_start =
      module->entry_computation()->GetInstructionWithName(
          "collective-permute-start");
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(0).layout().memory_space(),
      kDefaultMemorySpace);
  EXPECT_EQ(
      collective_permute_start->shape().tuple_shapes(1).layout().memory_space(),
      kDefaultMemorySpace);
}
4963
TEST_P(MemorySpaceAssignmentTest, InPlaceAsyncCollectivePermute) {
  // In-place variant: collective-permute-start takes explicit input and
  // output buffers (negate0, negate1) plus slice index tuples.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    negate0 = bf16[4]{0} negate(param)
    negate1 = bf16[4]{0} negate(param)
    const0 = s32[] constant(0)
    const1 = s32[] constant(1)
    tuple0 = (s32[]) tuple(const0)
    tuple1 = (s32[]) tuple(const1)
    collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0, negate1, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
    negate2 = bf16[4]{0} negate(param)
    negate3 = bf16[4]{0} negate(negate2)
    negate4 = bf16[4]{0} negate(negate3)
    collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
    ROOT add = add(collective-permute-done, negate4)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect both the source and destination buffers to get alternate memory
  // allocations. EXPECT_EQ reports the actual memory-space values on failure.
  if (GetParam()) {
    HloInstruction* collective_permute_start =
        module->entry_computation()->GetInstructionWithName(
            "collective-permute-start");
    EXPECT_EQ(collective_permute_start->shape()
                  .tuple_shapes(0)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(collective_permute_start->shape()
                  .tuple_shapes(1)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
  }
}
5005
TEST_P(MemorySpaceAssignmentTest, InPlaceAsyncCollectivePermuteSameBuffer) {
  // Like InPlaceAsyncCollectivePermute, but negate0 is passed as both the
  // input and the output buffer of the collective-permute.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    negate0 = bf16[4]{0} negate(param)
    const0 = s32[] constant(0)
    const1 = s32[] constant(1)
    tuple0 = (s32[]) tuple(const0)
    tuple1 = (s32[]) tuple(const1)
    collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0, negate0, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
    negate2 = bf16[4]{0} negate(param)
    negate3 = bf16[4]{0} negate(negate2)
    negate4 = bf16[4]{0} negate(negate3)
    collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
    ROOT add = add(collective-permute-done, negate4)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect both the source and destination buffers to get alternate memory
  // allocations. EXPECT_EQ reports the actual memory-space values on failure.
  if (GetParam()) {
    HloInstruction* collective_permute_start =
        module->entry_computation()->GetInstructionWithName(
            "collective-permute-start");
    EXPECT_EQ(collective_permute_start->shape()
                  .tuple_shapes(0)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(collective_permute_start->shape()
                  .tuple_shapes(1)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
  }
}
5046
TEST_P(MemorySpaceAssignmentTest,
       InPlaceAsyncCollectivePermuteSameBufferChained) {
  // Two chained in-place collective-permutes: the second one uses the done
  // result of the first as both its input and output buffer.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    negate0 = bf16[4]{0} negate(param)
    const0 = s32[] constant(0)
    const1 = s32[] constant(1)
    tuple0 = (s32[]) tuple(const0)
    tuple1 = (s32[]) tuple(const1)
    collective-permute-start.1 = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0, negate0, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
    negate2 = bf16[4]{0} negate(param)
    negate3 = bf16[4]{0} negate(negate2)
    negate4 = bf16[4]{0} negate(negate3)
    collective-permute-done.1 = bf16[4]{0} collective-permute-done(collective-permute-start.1)
    collective-permute-start.2 = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(collective-permute-done.1, collective-permute-done.1, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
    negate5 = bf16[4]{0} negate(negate4)
    negate6 = bf16[4]{0} negate(negate5)
    negate7 = bf16[4]{0} negate(negate6)
    collective-permute-done.2 = bf16[4]{0} collective-permute-done(collective-permute-start.2)
    ROOT add = add(collective-permute-done.2, negate7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Expect both the source and destination buffers of both permutes to get
  // alternate memory allocations. EXPECT_EQ reports the actual memory-space
  // values on failure.
  if (GetParam()) {
    HloInstruction* collective_permute_start_1 =
        module->entry_computation()->GetInstructionWithName(
            "collective-permute-start.1");
    EXPECT_EQ(collective_permute_start_1->shape()
                  .tuple_shapes(0)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(collective_permute_start_1->shape()
                  .tuple_shapes(1)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    HloInstruction* collective_permute_start_2 =
        module->entry_computation()->GetInstructionWithName(
            "collective-permute-start.2");
    EXPECT_EQ(collective_permute_start_2->shape()
                  .tuple_shapes(0)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(collective_permute_start_2->shape()
                  .tuple_shapes(1)
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
  }
}
5104
TEST_P(MemorySpaceAssignmentTest,
       TupleInPlaceAsyncCollectivePermuteSameBufferChained) {
  // Tuple-shaped variant of the chained in-place collective-permute: the
  // input/output buffers are (bf16[48], bf16[48]) tuples.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param = bf16[4]{0} parameter(0)
    param2 = bf16[48]{0} parameter(1)
    negate0.1 = bf16[48]{0} negate(param2)
    negate0.2 = bf16[48]{0} negate(param2)
    const0 = s32[] constant(0)
    const1 = s32[] constant(1)
    tuple0.0 = (s32[]) tuple(const0)
    tuple0 = ((s32[]), (s32[])) tuple(tuple0.0, tuple0.0)
    tuple1.0 = (s32[]) tuple(const1)
    tuple1 = ((s32[]), (s32[])) tuple(tuple1.0, tuple1.0)
    tuple2 = (bf16[48]{0}, bf16[48]{0}) tuple(negate0.1, negate0.2)
    collective-permute-start.1 = ((bf16[48]{0}, bf16[48]{0}), (bf16[48]{0}, bf16[48]{0}), u32[], u32[]) collective-permute-start(tuple2, tuple2, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
    negate2 = bf16[4]{0} negate(param)
    negate3 = bf16[4]{0} negate(negate2)
    negate4 = bf16[4]{0} negate(negate3)
    collective-permute-done.1 = (bf16[48]{0}, bf16[48]{0}) collective-permute-done(collective-permute-start.1)
    collective-permute-start.2 = ((bf16[48]{0}, bf16[48]{0}), (bf16[48]{0}, bf16[48]{0}), u32[], u32[]) collective-permute-start(collective-permute-done.1, collective-permute-done.1, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
    negate5 = bf16[4]{0} negate(negate4)
    negate6 = bf16[4]{0} negate(negate5)
    negate7 = bf16[4]{0} negate(negate6)
    collective-permute-done.2 = (bf16[48]{0}, bf16[48]{0}) collective-permute-done(collective-permute-start.2)
    gte = bf16[48]{0} get-tuple-element(collective-permute-done.2), index=0
    ROOT root = (bf16[48]{0}, bf16[4]{0}) tuple(gte, negate7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Each collective-permute-done must still be fed directly by a
  // collective-permute-start, i.e. memory space assignment did not insert
  // anything (such as a copy) between the start/done pairs.
  const HloInstruction* cp_done1 =
      FindInstruction(module.get(), "collective-permute-done.1");
  EXPECT_EQ(cp_done1->operand(0)->opcode(), HloOpcode::kCollectivePermuteStart);
  const HloInstruction* cp_done2 =
      FindInstruction(module.get(), "collective-permute-done.2");
  EXPECT_EQ(cp_done2->operand(0)->opcode(), HloOpcode::kCollectivePermuteStart);
}
5148
TEST_P(MemorySpaceAssignmentTest, ReservedScopedMemory) {
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    ROOT f = f32[2,4] add(e, b)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  // Make instruction c reserve 100 bytes in the alternate memory (the
  // previous comment said 64, but the callback below returns 100). This
  // should prevent both b and c from putting their outputs in the alternate
  // memory.
  options.reserved_scoped_memory_fn = [&](const HloInstruction* instruction) {
    if (instruction->name() == "c") {
      return 100;
    }
    return 0;
  };
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
                    options);
  // Helper: the memory space assigned to the named instruction's output.
  auto get_memory_space = [&](absl::string_view instruction_name) {
    return module->entry_computation()
        ->GetInstructionWithName(instruction_name)
        ->shape()
        .layout()
        .memory_space();
  };
  // EXPECT_EQ reports the actual memory-space values on failure, unlike
  // EXPECT_TRUE on a boolean comparison.
  EXPECT_EQ(get_memory_space("a"), kAlternateMemorySpace);
  EXPECT_EQ(get_memory_space("b"), kDefaultMemorySpace);
  EXPECT_EQ(get_memory_space("c"), kDefaultMemorySpace);
  EXPECT_EQ(get_memory_space("d"), kAlternateMemorySpace);
  EXPECT_EQ(get_memory_space("e"), kAlternateMemorySpace);
}
5194
TEST_P(MemorySpaceAssignmentTest, ConstantAllocationFar) {
  // The constant is defined far (several negates) before its single use.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param0 = f32[2,4] parameter(0)
    const = f32[2,4] constant({...})
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    ROOT negate = f32[2,4] add(const, e)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
  // The constant itself stays in default memory, while its use at the root is
  // fed from alternate memory. EXPECT_EQ reports the actual memory-space
  // values on failure.
  EXPECT_EQ(module->entry_computation()
                ->GetInstructionWithName("const")
                ->shape()
                .layout()
                .memory_space(),
            kDefaultMemorySpace);
  EXPECT_EQ(module->entry_computation()
                ->GetInstructionWithName("negate")
                ->operand(0)
                ->shape()
                .layout()
                .memory_space(),
            kAlternateMemorySpace);
}
5226
TEST_P(MemorySpaceAssignmentTest, ConstantAllocationNear) {
  // The constant is defined immediately before its single use.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    const = f32[2,4] constant({...})
    ROOT negate = f32[2,4] add(const, e)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
  // The constant itself stays in default memory, while its use at the root is
  // fed from alternate memory. EXPECT_EQ reports the actual memory-space
  // values on failure.
  EXPECT_EQ(module->entry_computation()
                ->GetInstructionWithName("const")
                ->shape()
                .layout()
                .memory_space(),
            kDefaultMemorySpace);
  EXPECT_EQ(module->entry_computation()
                ->GetInstructionWithName("negate")
                ->operand(0)
                ->shape()
                .layout()
                .memory_space(),
            kAlternateMemorySpace);
}
5258
// A mock MemorySpaceAssignmentRepacker class that accepts a map of
// (start_time, offset) -> new_offset values. Using this map, the repacker
// repacks the allocations to the new_offset.
5262 class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker {
5263 public:
FakeMemorySpaceAssignmentRepacker(absl::flat_hash_map<std::pair<int64_t,int64_t>,int64_t> & repack_map,std::function<void (absl::Span<AllocationBlock * >)> check_fun=nullptr,bool always_return_modified=false)5264 explicit FakeMemorySpaceAssignmentRepacker(
5265 absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t>& repack_map,
5266 std::function<void(absl::Span<AllocationBlock*>)> check_fun = nullptr,
5267 bool always_return_modified = false)
5268 : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8),
5269 repack_map_(repack_map),
5270 check_fun_(check_fun),
5271 always_return_modified_(always_return_modified) {}
5272
Repack(absl::Span<AllocationBlock * > allocations)5273 StatusOr<bool> Repack(absl::Span<AllocationBlock*> allocations) override {
5274 bool modified = false;
5275 for (AllocationBlock* block : allocations) {
5276 absl::flat_hash_set<int64_t> colocations;
5277 std::string colocations_str;
5278 for (const AllocationBlock* colocation : block->colocations) {
5279 absl::StrAppend(&colocations_str, colocation->id, ", ");
5280 colocations.insert(colocation->id);
5281 }
5282 VLOG(1) << "Alloc id: " << block->id << " time: [" << block->start_time
5283 << ", " << block->end_time << "] size: " << block->size
5284 << " init offset: " << block->initial_offset << " colocations: {"
5285 << colocations_str << "}";
5286 auto it = repack_map_.find({block->start_time, block->initial_offset});
5287 if (it != repack_map_.end()) {
5288 modified = true;
5289 block->offset = it->second;
5290 } else {
5291 block->offset = block->initial_offset;
5292 }
5293 for (AllocationBlock* colocation : block->colocations) {
5294 if (it != repack_map_.end()) {
5295 colocation->offset = it->second;
5296 } else {
5297 colocation->offset = colocation->initial_offset;
5298 }
5299 }
5300 }
5301 if (check_fun_) {
5302 check_fun_(allocations);
5303 }
5304
5305 return always_return_modified_ || modified;
5306 }
5307
5308 private:
5309 // A map from (start_time, offset) to new_offset.
5310 absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map_;
5311 std::function<void(absl::Span<AllocationBlock*>)> check_fun_;
5312 bool always_return_modified_;
5313 };
5314
TEST_P(MemorySpaceAssignmentTest, Repack) {
  // Tests that a failed prefetch allocation triggers the repacker and that,
  // after a successful repack, the prefetch can be allocated.
  //
  // We initially perform the following allocations at these offsets.
  //
  //    Max memory
  //  -------------------------------------------
  //
  //
  //
  //
  //      +------------+
  //      |     b      |
  //      +------------+
  //  +-------+            +------------+
  //  |   a   |            |     n      |
  //  +-------+            +------------+
  //  -------------------------------------------
  //    Min memory                 time ->
  //
  // Next up, we try to allocate the prefetch for m. However due to
  // fragmentation, this won't be possible:
  //
  //    Max memory
  //  -------------------------------------------
  //
  //
  //
  //                    +---------+
  //      +------------+|         |
  //      |     b      ||         |
  //      +------------+|         |
  //  +-------+         |         |+------------+
  //  |   a   |         |    d    ||     n      |
  //  +-------+         +---------++------------+
  //  -------------------------------------------
  //    Min memory                 time ->
  //
  // We then call repack to repack the existing allocations which allows us to
  // allocate the prefetch for m:
  //
  //    Max memory
  //  -------------------------------------------
  //                    +---------+
  //                    |         |
  //                    |         |
  //                    |         |
  //  +-------+         |         |
  //  |   a   |         |    d    |
  //  +-------+         +---------+
  //      +------------+           +------------+
  //      |     b      |           |     n      |
  //      +------------+           +------------+
  //  -------------------------------------------
  //    Min memory                 time ->
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY Entry {
    param0 = f32[8,3] parameter(0)
    param1 = f32[2,4] parameter(1)
    a = f32[2,4] sine(param1)
    b = f32[2,4] cosine(param1)
    c = f32[8,3] negate(param0)
    j = f32[2,4] negate(a)
    d = f32[8,3] tanh(param0)
    k = f32[2,4] negate(j)
    l = f32[2,4] add(b, k)
    m = f32[8,3] negate(d)
    n = f32[2,4] sine(l)
    o = f32[8,3] negate(m)
    p = f32[2,4] negate(n)
    q = f32[8,3] negate(m)
    ROOT tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(p, q, o)
  }
  )";

  // Prioritize sine (a), then cosine (b), then tanh (d) so the allocations
  // happen in the order the diagrams above assume; everything else last.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        auto get_opcode_priority = [](const HloOpcode& opcode) {
          switch (opcode) {
            case HloOpcode::kSin:
              return 0;
            case HloOpcode::kCos:
              return 1;
            case HloOpcode::kTanh:
              return 2;
            default:
              return 3;
          }
        };

        return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
               get_opcode_priority(b.buffer->defining_instruction()->opcode());
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  // The fake repacker maps (start_time, old_offset) -> new_offset.
  absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;
  // Move "a" from offset 0 to 32.
  repack_map[{2, 0}] = 32;
  // Move "b" from offset 32 to 0.
  repack_map[{3, 32}] = 0;
  FakeMemorySpaceAssignmentRepacker repacker =
      FakeMemorySpaceAssignmentRepacker(repack_map);
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  // Allow a single repack attempt when an allocation fails.
  options.max_repacks = 1;
  options.repacker = &repacker;
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker,
                    options);

  // If repacking succeeds, we should find the buffer for d in alternate memory.
  const HloInstruction* d =
      module->entry_computation()->GetInstructionWithName("d");
  EXPECT_EQ(d->shape().layout().memory_space(), kAlternateMemorySpace);
}
5435
TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) {
  // This test is that we are correctly exporting aliased offsets for repacking.
  // In this example, the buffer produced at HLO "a" will be allocated first,
  // and will consist of four allocations:
  // 1) a produced in the alternate memory (and then evicted to the default
  // memory). 2) a prefetched to the alternate memory to be used by q and
  // while HLOs. 3) a used within the while loop body. 4) the output of while
  // HLO, used by u.
  //
  // Since a will be allocated first (the test is crafted to prioritize sine
  // HLO), all four allocations should get the same (zero) offsets. However,
  // while allocations 2, 3, and 4 need to be colocated with each other,
  // allocation 1 doesn't need to be colocated with the other three.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  while_condition {
    param1 = (f32[2,4], f32[2,4]) parameter(0)
    ROOT cond = pred[] constant(true)
  }

  while_body {
    param2 = (f32[2,4], f32[2,4]) parameter(0)
    gte2 = f32[2,4] get-tuple-element(param2), index=0
    gte3 = f32[2,4] get-tuple-element(param2), index=1
    add = f32[2,4] add(gte2, gte3)
    ROOT tuple2 = (f32[2,4], f32[2,4]) tuple(add, gte3)
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] sine(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    f = f32[2,4] negate(e)
    g = f32[2,4] negate(f)
    h = f32[2,4] negate(g)
    i = f32[2,4] negate(h)
    j = f32[2,4] negate(i)
    k = f32[2,4] negate(j)
    l = f32[2,4] negate(k)
    m = f32[2,4] negate(l)
    n = f32[2,4] negate(m)
    o = f32[2,4] negate(n)
    p = f32[2,4] negate(o)
    q = f32[2,4] add(p, a)
    tuple = (f32[2,4], f32[2,4]) tuple(q, a)
    while = (f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
    gte0 = f32[2,4] get-tuple-element(while), index=0
    gte1 = f32[2,4] get-tuple-element(while), index=1
    r = f32[2,4] negate(gte0)
    s = f32[2,4] negate(r)
    t = f32[2,4] negate(s)
    constant = f32[] constant(0)
    broadcast = f32[8,4] broadcast(constant), dimensions={}
    cos = f32[8,4] cosine(broadcast)
    u = f32[2,4] add(t, gte1)
    v = f32[2,4] add(u, param0)
    w = f32[8,4] negate(cos)
    ROOT tuple3 = (f32[2,4], f32[8,4]) tuple(v, w)
  }
  )";

  // Prioritize sine first (so "a" is allocated before everything else), then
  // cosine, then tanh; all other opcodes share the lowest priority.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        auto get_opcode_priority = [](const HloOpcode& opcode) {
          switch (opcode) {
            case HloOpcode::kSin:
              return 0;
            case HloOpcode::kCos:
              return 1;
            case HloOpcode::kTanh:
              return 2;
            default:
              return 3;
          }
        };

        return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
               get_opcode_priority(b.buffer->defining_instruction()->opcode());
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;

  // Expect that of the four separate allocations for the "a" buffer, the first
  // and the next three are in separate colocations.
  // NOTE(review): this assumes the repacker sees the four "a" allocations
  // first, in allocation order — which relies on the sine-first priority
  // above; confirm if the priority function changes.
  auto check_fun =
      [](absl::Span<MemorySpaceAssignmentRepacker::AllocationBlock*>
             allocations) {
        EXPECT_TRUE(allocations.at(0)->colocations.size() == 1 ||
                    allocations.at(0)->colocations.size() == 3);
        EXPECT_EQ(allocations.at(1)->colocations.size(), 3);
        EXPECT_EQ(allocations.at(2)->colocations.size(), 3);
        EXPECT_TRUE(allocations.at(3)->colocations.size() == 1 ||
                    allocations.at(3)->colocations.size() == 3);
      };
  FakeMemorySpaceAssignmentRepacker repacker =
      FakeMemorySpaceAssignmentRepacker(repack_map, check_fun);
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.max_repacks = 1;
  options.repacker = &repacker;
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker,
                    options);
}
5550
TEST_P(MemorySpaceAssignmentTest,
       RepackExportsAliasedOffsetsForReservedScopedMemory) {
  // Tests that scoped allocations (memory reserved for the duration of
  // instructions "c" and "d" below) are exported to the repacker with their
  // colocation (aliasing) information attached.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    ROOT f = f32[2,4] add(e, b)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.max_repacks = 1;
  // Make two instructions reserve scoped memory.
  options.reserved_scoped_memory_fn = [&](const HloInstruction* instruction) {
    if (instruction->name() == "c" || instruction->name() == "d") {
      return 100;
    }
    return 0;
  };

  absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;
  bool repacker_ran = false;

  // Expect that the first two value to repack has a colocations size of 2,
  // corresponding to the scoped allocations.
  auto check_fun =
      [&](absl::Span<MemorySpaceAssignmentRepacker::AllocationBlock*>
              allocations) {
        EXPECT_EQ(allocations.at(0)->colocations.size(), 2);
        EXPECT_EQ(allocations.at(1)->colocations.size(), 2);
        repacker_ran = true;
      };
  FakeMemorySpaceAssignmentRepacker repacker =
      FakeMemorySpaceAssignmentRepacker(repack_map, check_fun);
  options.repacker = &repacker;
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
                    options);
  // The EXPECTs inside check_fun are only meaningful if it actually ran.
  EXPECT_TRUE(repacker_ran);
}
5601
TEST_P(MemorySpaceAssignmentTest,
       RepackShouldntEraseRequiredAssignmentForConditionalOutput) {
  // This is a test case for b/171040271. Repacks erase the required assignments
  // (since some required assignments are inserted conditionally based on
  // allocation decisions), including the fact that conditional outputs are
  // always required to get assignments in the default memory. After repacking,
  // this required assignment was never added back, causing conditionals to get
  // alternate-memory allocations.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]) parameter(0)
    gte = f32[3] get-tuple-element(p0), index=0
    neg1 = f32[3] negate(gte)
    ROOT tuple1 = (f32[3]) tuple(neg1)
  }

  false_computation {
    p0 = (f32[3]) parameter(0)
    gte = f32[3] get-tuple-element(p0), index=0
    neg2 = f32[3] negate(gte)
    ROOT tuple2 = (f32[3]) tuple(neg2)
  }

  ENTRY entry {
    p0 = f32[3] parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3] copy(p0)
    tuple = (f32[3]) tuple(copy)
    conditional = (f32[3]) conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
    ROOT gte = f32[3] get-tuple-element(conditional), index=0
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;
  // always_return_modified=true forces the repack code path to run even though
  // this fake repacker does not actually move any allocation.
  FakeMemorySpaceAssignmentRepacker repacker =
      FakeMemorySpaceAssignmentRepacker(repack_map, nullptr,
                                        /*always_return_modified=*/true);
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  // NOTE(review): with verify=true the regression would presumably surface as
  // a verification failure during AssignMemorySpace — confirm.
  options.verify = true;
  options.max_repacks = 10;
  options.repacker = &repacker;
  // Trigger a repack after every allocation.
  options.repack_after_every_allocation = true;
  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*buffer_interval_compare=*/{}, &prefetch_interval_picker,
                    options);
}
5654
TEST_P(MemorySpaceAssignmentTest, Determinism) {
  // Memory space assignment must be deterministic: assigning freshly built
  // copies of the same module should always yield byte-identical output.
  std::unique_ptr<HloModule> reference_module = CreateEvictAndPrefetchModule();
  AssignMemorySpace(reference_module.get());
  const std::string reference_str = reference_module->ToString();

  constexpr int kNumTrials = 10;
  for (int trial = 0; trial < kNumTrials; ++trial) {
    std::unique_ptr<HloModule> trial_module = CreateEvictAndPrefetchModule();
    AssignMemorySpace(trial_module.get());
    EXPECT_EQ(reference_str, trial_module->ToString());
  }
}
5669
TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
  // Tests that in-place ops like DynamicUpdateSlice get the same allocation as
  // its input.
  absl::string_view hlo_string = R"(
HloModule Module, is_scheduled=true

fused_computation {
  param0 = f32[2,3] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast = f32[2,1] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
}

ENTRY main {
  param = f32[2,3] parameter(0)
  negate = f32[2,3] negate(param)
  fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
  ROOT add = f32[2,3] add(fusion, fusion)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto preset_assignments = AssignMemorySpace(module.get());
  // Look up the alternate-memory offsets chosen for the fusion (the in-place
  // DUS) and for its operand "negate".
  HloInstruction* negate_instruction =
      module->entry_computation()->GetInstructionWithName("negate");
  int64_t negate_offset =
      GetAlternateMemoryOffset(*preset_assignments, negate_instruction);
  HloInstruction* fusion_instruction =
      module->entry_computation()->GetInstructionWithName("fusion");
  int64_t fusion_offset =
      GetAlternateMemoryOffset(*preset_assignments, fusion_instruction);
  // We expect negate and fusion to get the same offsets.
  EXPECT_EQ(negate_offset, fusion_offset);
  const bool allocate_across_sequential_calls = GetParam();
  // When the test parameter enables allocating across sequential calls, the
  // shared offset is additionally expected to be a real alternate-memory
  // offset (not -1).
  if (allocate_across_sequential_calls) {
    EXPECT_NE(negate_offset, -1);
  }
}
5710
TEST_P(MemorySpaceAssignmentTest, ConditionalInPlaceOp) {
  // Runs memory space assignment on a module where an in-place op
  // (DynamicUpdateSlice inside a fusion) appears in one branch of a
  // conditional. There are no explicit expectations; the test passes as long
  // as AssignMemorySpace completes successfully.
  absl::string_view hlo_string = R"(
HloModule Module, is_scheduled=true

fused_computation {
  param0 = f32[2,3] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast = f32[2,1] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
}

true_computation {
  p0 = (f32[2,3]) parameter(0)
  gte = f32[2,3] get-tuple-element(p0), index=0
  ROOT neg1 = f32[2,3] negate(gte)
}

false_computation {
  p0 = (f32[2,3]) parameter(0)
  gte = f32[2,3] get-tuple-element(p0), index=0
  neg2 = f32[2,3] negate(gte)
  ROOT fusion = f32[2,3] fusion(neg2), kind=kLoop, calls=fused_computation
}

ENTRY entry {
  p0 = f32[2,3] parameter(0)
  p1 = pred[] parameter(1)
  copy = f32[2,3] copy(p0)
  tuple = (f32[2,3]) tuple(copy)
  ROOT conditional = f32[2,3] conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
}
5749
// Run every MemorySpaceAssignmentTest case with both values of the
// allocate_across_sequential_calls parameter (retrieved via GetParam()).
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
                         MemorySpaceAssignmentTest,
                         ::testing::Values(false, true));

// Plain (non-parameterized) fixture for the AsynchronousCopyResource tests.
using AsynchronousCopyResourceTest = ::testing::Test;
5755
TEST_F(AsynchronousCopyResourceTest, Simple) {
  // time:      0 1 2 3 4 5 6 7 8 9
  // resource:  2 3 1 6 7 1 7 2 2 4
  // -1,3,5    +-----+              OK
  // resource:  0 0 1 6 7 1 7 2 2 4
  //  1,4,4        +---+            OK
  // resource:  0 0 0 3 7 1 7 2 2 4
  //  5,9,10               +-----+
  // resource:  0 0 0 3 7 1 0 0 1 4
  //  4,9,3              +-------+  Violate
  //  4,8,2              +-----+    OK; The 5,9 copy shifts resource to right.
  // resource:  0 0 0 3 7 0 0 0 0 4
  const auto kAlternate = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource(
      {2.0, 3.0, 1.0, 6.0, 7.0, 1.0, 7.0, 2.0, 2.0, 4.0});
  // Each satisfiable (start, end, amount) request is checked and committed.
  struct Request {
    int64_t start;
    int64_t end;
    float amount;
  };
  const std::vector<Request> satisfiable = {
      {-1, 3, 5.0}, {1, 4, 4.0}, {5, 9, 10.0}};
  int64_t copy_id = 0;
  for (const Request& request : satisfiable) {
    EXPECT_TRUE(
        resource.HasEnoughResource(request.start, request.end, request.amount));
    resource.AddCopy(
        {request.start, request.end, request.amount, kAlternate, copy_id++});
  }
  // This request overlaps the committed copies and violates the resource.
  EXPECT_FALSE(resource.HasEnoughResource(4, 9, 3.0));
  // A smaller amount still fits: the (5,9) copy's resource use shifts right.
  EXPECT_TRUE(resource.HasEnoughResource(4, 8, 2.0));
  resource.AddCopy({4, 8, 2.0, kAlternate, 3});
}
5781
TEST_F(AsynchronousCopyResourceTest, Propagate) {
  // time:      0 1 2 3 4 5 6 7 8 9
  // resource:  2 2 2 2 2 2 2 2 2 2
  // 6,10,2                  +-----+ OK
  // resource:  2 2 2 2 2 2 2 0 2 2
  // 5,9,2                 +-----+   OK
  // resource:  2 2 2 2 2 2 0 0 2 2
  // 4,8,2               +-----+     OK
  // resource:  2 2 2 2 2 0 0 0 2 2
  // 3,7,2             +-----+       OK
  // resource:  2 2 2 2 0 0 0 0 2 2
  // 2,6,2           +-----+         OK
  // resource:  2 2 2 0 0 0 0 0 2 2
  // 1,5,2         +-----+           OK
  // resource:  2 2 0 0 0 0 0 0 2 2
  // 0,4,3       +-----+             OK
  // resource:  2 0 0 0 0 0 0 0 1 2
  // 0,4,3       +-----+             OK
  // resource:  2 0 0 0 0 0 0 0 0 0
  // 0,4,1       +-----+             Violate
  const auto kAlternate = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource(
      {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
  auto expect_resources = [&](std::vector<float> expected) {
    EXPECT_EQ(resource.GetCurrentResources(), expected);
  };
  EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0));
  resource.AddCopy({6, 10, 2.0, kAlternate, 0});
  expect_resources({2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0, 2.0});
  // Stack overlapping copies (5,9), (4,8), ..., (1,5); each one's demand is
  // propagated into the still-free earlier time slots.
  for (int64_t start = 5; start >= 1; --start) {
    EXPECT_TRUE(resource.HasEnoughResource(start, start + 4, 2.0));
    resource.AddCopy({start, start + 4, 2.0, kAlternate, 6 - start});
  }
  expect_resources({2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0});
  EXPECT_TRUE(resource.HasEnoughResource(0, 4, 3.0));
  resource.AddCopy({0, 4, 3.0, kAlternate, 6});
  expect_resources({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0});
  EXPECT_TRUE(resource.HasEnoughResource(0, 4, 3.0));
  resource.AddCopy({0, 4, 3.0, kAlternate, 7});
  expect_resources({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
  // Everything after time 0 is now exhausted.
  EXPECT_FALSE(resource.HasEnoughResource(0, 4, 1.0));
}
5835
TEST_F(AsynchronousCopyResourceTest, CantPropagate) {
  // time:      0 1 2 3 4 5 6 7 8 9
  // resource:  2 2 2 2 2 2 2 2 2 2
  // 5,10,2               +-------+  OK
  // resource:  2 2 2 2 2 2 0 2 2 2
  // 4,7,2              +---+        OK
  // resource:  2 2 2 2 2 0 0 2 2 2
  // 4,8,4              +-----+      OK
  // resource:  2 2 2 2 2 0 0 0 0 2
  // 3,6,4            +---+          Violate
  const auto kAlternate = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource(
      {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
  // Compares the currently available per-time resources against `expected`.
  auto expect_resources = [&](std::vector<float> expected) {
    EXPECT_EQ(resource.GetCurrentResources(), expected);
  };
  EXPECT_TRUE(resource.HasEnoughResource(5, 10, 2.0));
  resource.AddCopy({5, 10, 2.0, kAlternate, 0});
  expect_resources({2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0, 2.0, 2.0});
  EXPECT_TRUE(resource.HasEnoughResource(4, 7, 2.0));
  resource.AddCopy({4, 7, 2.0, kAlternate, 1});
  expect_resources({2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 2.0, 2.0, 2.0});
  EXPECT_TRUE(resource.HasEnoughResource(4, 8, 4.0));
  resource.AddCopy({4, 8, 4.0, kAlternate, 2});
  expect_resources({2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 2.0});
  // The (3,6) window is fully consumed, and nothing can be pushed later.
  EXPECT_FALSE(resource.HasEnoughResource(3, 6, 4.0));
}
5866
TEST_F(AsynchronousCopyResourceTest, Nested) {
  // time:      0 1 2 3 4
  // resource:  2 2 2 2 2
  // 1,3,2        +-+      OK
  // resource:  2 2 0 2 2
  // 0,4,4      +-----+    Violate
  const auto kAlternate = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0});
  // The inner copy (1,3) fits and consumes the resource at time 2.
  EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0));
  resource.AddCopy({1, 3, 2.0, kAlternate, 0});
  EXPECT_EQ(resource.GetCurrentResources(),
            std::vector<float>({2.0, 2.0, 0.0, 2.0, 2.0}));
  // A copy (0,4) that nests the previous one can no longer find 4 units.
  EXPECT_FALSE(resource.HasEnoughResource(0, 4, 4.0));
}
5881
TEST_F(AsynchronousCopyResourceTest, Remove) {
  // time:       0 1 2 3 4
  // resource:   2 2 2 2 2
  // add:2,5,2       +---+  OK
  // resource:   2 2 2 0 2
  // add:-1,2,3 +---+       OK
  // resource:   0 1 2 0 2
  // add:0,4,4   +-----+    OK
  // resource:   0 0 0 0 1
  // rem:0,4,4   +-----+
  // resource:   0 1 2 0 2
  // rem:2,5,2       +---+
  // resource:   0 1 2 2 2
  // rem:-1,2,3 +---+
  // resource:   2 2 2 2 2
  const auto kAlternate = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0});
  auto expect_resources = [&](std::vector<float> expected) {
    EXPECT_EQ(resource.GetCurrentResources(), expected);
  };
  const AsynchronousCopy copy1{2, 5, 2.0, kAlternate, 0};
  const AsynchronousCopy copy2{-1, 2, 3.0, kAlternate, 1};
  const AsynchronousCopy copy3{0, 4, 4.0, kAlternate, 2};
  EXPECT_TRUE(resource.HasEnoughResource(2, 5, 2.0));
  resource.AddCopy(copy1);
  expect_resources({2.0, 2.0, 2.0, 0.0, 2.0});
  EXPECT_TRUE(resource.HasEnoughResource(-1, 2, 3.0));
  resource.AddCopy(copy2);
  expect_resources({0.0, 1.0, 2.0, 0.0, 2.0});
  EXPECT_TRUE(resource.HasEnoughResource(0, 4, 4.0));
  resource.AddCopy(copy3);
  expect_resources({0.0, 0.0, 0.0, 0.0, 1.0});
  // Removals (in an order different from insertion) restore exactly the
  // resource each copy consumed.
  resource.RemoveCopy(copy3);
  expect_resources({0.0, 1.0, 2.0, 0.0, 2.0});
  resource.RemoveCopy(copy1);
  expect_resources({0.0, 1.0, 2.0, 2.0, 2.0});
  resource.RemoveCopy(copy2);
  expect_resources({2.0, 2.0, 2.0, 2.0, 2.0});
}
5924
TEST_F(AsynchronousCopyResourceTest, NestedRemove) {
  // Tests that removing a copy restores the resource so that a previously
  // violating (nesting) copy becomes addable, and vice versa.
  //
  // time:      0 1 2 3 4
  // resource:  2 2 2 2 2
  // add:1,3,2    +-+       OK
  // resource:  2 2 0 2 2
  // add:0,4,4  +-----+     Violate
  // rem:1,3,2    +-+
  // resource:  2 2 2 2 2
  // add:0,4,4  +-----+     OK
  // resource:  2 0 0 2 2
  // add:1,3,2    +-+       Violate
  // rem:0,4,4  +-----+
  // resource:  2 2 2 2 2
  // add:1,3,2    +-+       OK
  // resource:  2 2 0 2 2
  auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0});
  AsynchronousCopy copy1{1, 3, 2.0, alternate_mem_space, 0};
  AsynchronousCopy copy2{0, 4, 4.0, alternate_mem_space, 1};
  EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0));
  resource.AddCopy(copy1);
  EXPECT_EQ(resource.GetCurrentResources(),
            std::vector<float>({2.0, 2.0, 0.0, 2.0, 2.0}));
  // copy2 nests copy1; while copy1 is present there isn't enough resource.
  EXPECT_FALSE(resource.HasEnoughResource(0, 4, 4.0));
  resource.RemoveCopy(copy1);
  // (Fixed: a previous version stored GetCurrentResources() in an unused
  // local here; the dead variable has been removed.)
  EXPECT_EQ(resource.GetCurrentResources(),
            std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0}));
  EXPECT_TRUE(resource.HasEnoughResource(0, 4, 4.0));
  resource.AddCopy(copy2);
  EXPECT_EQ(resource.GetCurrentResources(),
            std::vector<float>({2.0, 0.0, 0.0, 2.0, 2.0}));
  // Now the nested copy1 is the one that no longer fits.
  EXPECT_FALSE(resource.HasEnoughResource(1, 3, 2.0));
  resource.RemoveCopy(copy2);
  EXPECT_EQ(resource.GetCurrentResources(),
            std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0}));
  EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0));
}
5963
TEST_F(AsynchronousCopyResourceTest, PropagateRemove) {
  // time:      0 1 2 3 4 5 6 7 8 9
  // resource:  2 2 2 2 2 2 2 2 2 2
  // add:6,10,2              +-----+ OK
  // resource:  2 2 2 2 2 2 2 0 2 2
  // add:5,9,2             +-----+   OK
  // resource:  2 2 2 2 2 2 0 0 2 2
  // add:4,8,2           +-----+     OK
  // resource:  2 2 2 2 2 0 0 0 2 2
  // add:3,7,2         +-----+       OK
  // resource:  2 2 2 2 0 0 0 0 2 2
  // add:2,6,2       +-----+         OK
  // resource:  2 2 2 0 0 0 0 0 2 2
  // add:1,5,2     +-----+           OK
  // resource:  2 2 0 0 0 0 0 0 2 2
  // add:0,4,3   +-----+             OK
  // resource:  2 0 0 0 0 0 0 0 1 2
  // add:0,5,3   +-------+           OK
  // resource:  2 0 0 0 0 0 0 0 0 0
  // rem:0,5,3   +-------+
  // resource:  2 0 0 0 0 0 0 0 1 2
  // rem:0,4,3   +-----+
  // resource:  2 2 0 0 0 0 0 0 2 2
  const auto kAlternate = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource(
      {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
  auto expect_resources = [&](std::vector<float> expected) {
    EXPECT_EQ(resource.GetCurrentResources(), expected);
  };
  EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0));
  resource.AddCopy({6, 10, 2.0, kAlternate, 0});
  expect_resources({2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0, 2.0});
  // Stack overlapping copies (5,9) ... (1,5); each one's demand is propagated
  // into earlier time slots.
  for (int64_t start = 5; start >= 1; --start) {
    EXPECT_TRUE(resource.HasEnoughResource(start, start + 4, 2.0));
    resource.AddCopy({start, start + 4, 2.0, kAlternate, 6 - start});
  }
  expect_resources({2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0});
  const AsynchronousCopy copy1{0, 4, 3.0, kAlternate, 6};
  EXPECT_TRUE(resource.HasEnoughResource(0, 4, 3.0));
  resource.AddCopy(copy1);
  expect_resources({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0});
  const AsynchronousCopy copy2{0, 5, 3.0, kAlternate, 7};
  EXPECT_TRUE(resource.HasEnoughResource(0, 5, 3.0));
  resource.AddCopy(copy2);
  expect_resources({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
  // Removing in reverse order must undo the propagated consumption exactly.
  resource.RemoveCopy(copy2);
  expect_resources({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0});
  resource.RemoveCopy(copy1);
  expect_resources({2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0});
}
6029
TEST_F(AsynchronousCopyResourceTest, StartAtZeroAndRemove) {
  // time:      0 1 2 3 4
  // resource:  0 0 1 1 2
  // add:0,4,2  +-----+    OK
  // resource:  0 0 0 0 2
  // rem:0,4,2  +-----+
  // resource:  0 0 1 1 2
  // add:0,4,2  +-----+    OK
  // resource:  0 0 0 0 2
  const auto kAlternate = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource({0.0, 0.0, 1.0, 1.0, 2.0});
  auto expect_resources = [&](std::vector<float> expected) {
    EXPECT_EQ(resource.GetCurrentResources(), expected);
  };
  const AsynchronousCopy copy{0, 4, 2.0, kAlternate, 0};
  EXPECT_TRUE(resource.HasEnoughResource(0, 4, 2.0));
  resource.AddCopy(copy);
  expect_resources({0.0, 0.0, 0.0, 0.0, 2.0});
  // Removing the copy must restore the initial resources exactly ...
  resource.RemoveCopy(copy);
  expect_resources({0.0, 0.0, 1.0, 1.0, 2.0});
  // ... so that adding the very same copy again succeeds.
  resource.AddCopy(copy);
  expect_resources({0.0, 0.0, 0.0, 0.0, 2.0});
}
6053
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) {
  // Builds dot(lhs, rhs) out of two entry parameters and expects that the rhs
  // parameter (parameter number 1) is chosen for cross-program prefetching at
  // the top-level shape index.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  HloInstruction* lhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
  HloInstruction* rhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "rhs"));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_param, rhs_param, dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {lhs_param, rhs_param, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Exactly one cross-program-prefetch candidate: parameter 1, whole shape.
  const auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 1);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
  }
}
6091
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleTest) {
  // Like CrossProgramPrefetchTest, but the dot operands arrive as elements of
  // a tuple-shaped parameter; the expected candidate is tuple element {1} of
  // parameter 0 (the rhs).
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  const auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }
}
6133
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) {
  // Like CrossProgramPrefetchTest, but the rhs parameter is consumed through a
  // bitcast; the candidate should still be parameter 1 at the top-level index.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature});
  const Shape bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  HloInstruction* lhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
  HloInstruction* rhs_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "rhs"));

  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(bitcast_shape, rhs_param));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_param, bitcast, dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {lhs_param, rhs_param, bitcast, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  const auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 1);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
  }
}
6175
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTupleTest) {
  // Verifies that an element of a tuple-shaped parameter that reaches the dot
  // only through a bitcast is still cross-program prefetched, and that the
  // prefetch is reported at the tuple element's shape index {1}.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature});
  auto bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  // The rhs tuple element is bitcast {kOutput, kFeature} ->
  // {kFeature, kOutput} before the dot consumes it.
  auto bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(bitcast_shape, rhs));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, bitcast, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Expect exactly one cross-program prefetch: parameter 0, tuple index {1}.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }
}
6221
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) {
  // Verifies that a dot operand that lives two tuple levels deep in the
  // parameter (accessed via a GTE of a GTE) does NOT get a cross-program
  // prefetch: the test expects zero prefetches for the nested-tuple case.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  // The parameter is a tuple containing a tuple: ((lhs, rhs)).
  auto tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0"));

  auto gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(tuple_shape, param, 0));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, gte, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, gte, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, gte, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
6263
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) {
  // A parameter with no consuming instructions must not be cross-program
  // prefetched: there is no use that would benefit from alternate memory.
  HloComputation::Builder entry(TestName());

  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape param_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  HloInstruction* param = entry.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "p0"));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(entry.Build());

  // The schedule contains only the (unused) parameter.
  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 0);
}
6286
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) {
  // Verifies that no cross-program prefetch happens when the candidate
  // operand is too large. Unlike the prefetchable tests above, kOutput is 8
  // here, so the rhs is 8x8 f32 — presumably exceeding the alternate-memory
  // budget AssignMemorySpace uses by default (TODO confirm capacity).
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 8;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "rhs"));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
6320
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTupleTest) {
  // Tuple-parameter variant of CrossProgramPrefetchTooBigTest: with
  // kOutput = 8 the rhs tuple element is 8x8 f32 and no cross-program
  // prefetch is expected.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 8;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
6358
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
  // Verifies that fusion operands that are constants (rather than entry
  // parameters) do not get cross-program prefetched.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 2;
  constexpr int kFeature = 2;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});

  auto module = CreateNewVerifiedModule();
  // Fused computation: dot(lhs, rhs) over the two fusion parameters. The dot
  // is the last instruction added and becomes the fusion root; the returned
  // pointer is not needed, so don't bind it (the original held it in an
  // unused local silenced with a (void) cast).
  HloComputation::Builder fusion_builder("fusion");
  {
    HloInstruction* lhs = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
    HloInstruction* rhs = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    fusion_builder.AddInstruction(HloInstruction::CreateDot(
        result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
  }
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  // Both fusion operands are constants — not prefetch candidates.
  auto activations = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  auto weights = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      result_shape, HloInstruction::FusionKind::kCustom, {activations, weights},
      fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {activations, weights, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
6406
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTupleTest) {
  // Tuple-operand variant of CrossProgramPrefetchFusionTest: the fusion takes
  // a tuple of constants, so no cross-program prefetch is expected.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 2;
  constexpr int kFeature = 2;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});

  auto module = CreateNewVerifiedModule();
  // Fused computation: unpack the tuple parameter and dot its elements. The
  // dot is the last instruction added and becomes the fusion root; the
  // returned pointer is not needed, so don't bind it (the original held it
  // in an unused local silenced with a (void) cast).
  HloComputation::Builder fusion_builder("fusion");
  {
    HloInstruction* param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape, "p0"));
    auto lhs = fusion_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
    auto rhs = fusion_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    fusion_builder.AddInstruction(HloInstruction::CreateDot(
        result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
  }
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  // The fusion operand is a tuple built from two constants — not a
  // prefetch candidate.
  auto activations = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  auto weights = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({activations, weights}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      result_shape, HloInstruction::FusionKind::kCustom, {tuple},
      fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {activations, weights, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
6459
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) {
  // Verifies that a parameter whose layout already places it in the alternate
  // memory space (i.e. it is "pinned" there) does not get a cross-program
  // prefetch — there is nothing to copy in.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  // The rhs layout carries kAlternateMemorySpace as its memory space, pinning
  // the buffer in alternate memory from the start.
  auto rhs_shape = ShapeUtil::MakeShapeWithLayout(
      F32, {kFeature, kOutput},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "rhs"));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Allow every value into alternate memory so the only thing preventing a
  // prefetch is the pinned layout itself.
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
    return true;
  };
  std::unique_ptr<PresetAssignments> preset_assignments = AssignMemorySpace(
      module.get(),
      /*max_outstanding_async_copies=*/-1,
      /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2, options);

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
6506
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTupleTest) {
  // Tuple-parameter variant of CrossProgramPrefetchPinnedTest: the rhs tuple
  // element's layout already lives in the alternate memory space, so no
  // cross-program prefetch is expected.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  // The rhs layout carries kAlternateMemorySpace, pinning that tuple element
  // in alternate memory from the start.
  auto rhs_shape = ShapeUtil::MakeShapeWithLayout(
      F32, {kFeature, kOutput},
      /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
      /*element_size_in_bits=*/0, kAlternateMemorySpace);
  auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  auto lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Allow every value into alternate memory so the only thing preventing a
  // prefetch is the pinned layout itself.
  Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
    return true;
  };
  std::unique_ptr<PresetAssignments> preset_assignments = AssignMemorySpace(
      module.get(),
      /*max_outstanding_async_copies=*/-1,
      /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2, options);

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
6557
TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDupMayAlias) {
  // The root dynamic-update-slice may-aliases parameter 0 via the module's
  // input_output_alias config; the test expects the parameter to still be
  // cross-program prefetched.
  absl::string_view hlo_text = R"(
  HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {}: (0, {}, may-alias) }
  ENTRY CrossProgramPrefetch {
    c0 = s32[1,2] constant({{77, 77}})
    c1 = s32[] constant(0)
    p0 = s32[2,2] parameter(0)
    ROOT dup = s32[2,2] dynamic-update-slice(s32[2,2] p0, s32[1,2] c0, s32[] c1, s32[] c1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  auto assignments =
      AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                        /*max_prefetch_interval=*/5,
                        /*min_prefetch_interval=*/2);

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 1);
}
6577
TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDup) {
  // Parameter 0 feeds a dynamic-update-slice that is also the program root,
  // with no input/output aliasing declared; the test expects the parameter
  // to still be cross-program prefetched.
  absl::string_view hlo_text = R"(
  HloModule cross_program_prefetch, is_scheduled=true
  ENTRY CrossProgramPrefetch {
    c0 = s32[1,2] constant({{77, 77}})
    c1 = s32[] constant(0)
    p0 = s32[2,2] parameter(0)
    ROOT dup = s32[2,2] dynamic-update-slice(s32[2,2] p0, s32[1,2] c0, s32[] c1, s32[] c1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  auto assignments =
      AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                        /*max_prefetch_interval=*/5,
                        /*min_prefetch_interval=*/2);

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 1);
}
6597
TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDupDot) {
  // Cross program prefetch since the parameter and the root don't alias.
  // p0 feeds a dynamic-update-slice, but the root is a dot of p1 and the
  // updated value; one cross-program prefetch is expected.
  absl::string_view hlo_string = R"(
  HloModule cross_program_prefetch, is_scheduled=true
  ENTRY CrossProgramPrefetch {
    c0 = s32[1,2] constant({{77, 77}})
    c1 = s32[] constant(0)
    p0 = s32[2,2] parameter(0)
    p1 = s32[2,2] parameter(1)
    dup = s32[2,2] dynamic-update-slice(s32[2,2] p0, s32[1,2] c0, s32[] c1, s32[] c1)
    ROOT dot = s32[2,2] dot(p1, dup), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto preset_assignments = AssignMemorySpace(
      module.get(), /*max_outstanding_async_copies=*/-1,
      /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
}
6620
TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDotMayAlias) {
  // The root dot may-aliases parameter 0 via the module's input_output_alias
  // config; the test expects one cross-program prefetch nonetheless.
  absl::string_view hlo_text = R"(
  HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {}: (0, {}, may-alias) }
  ENTRY CrossProgramPrefetch {
    p0 = s32[2,2] parameter(0)
    p1 = s32[2,2] parameter(1)
    ROOT dot = s32[2,2] dot(p1, p0), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  auto assignments =
      AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                        /*max_prefetch_interval=*/5,
                        /*min_prefetch_interval=*/2);

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 1);
}
6639
TEST_P(MemorySpaceAssignmentTest, CrossProgramRootParameter) {
  // The parameter reaches the root only through a bitcast; the test expects
  // it to still be cross-program prefetched.
  absl::string_view hlo_text = R"(
  HloModule cross_program_prefetch, is_scheduled=true
  ENTRY CrossProgramPrefetch {
    p0 = s32[2,2] parameter(0)
    ROOT bitcast = u32[2,2] bitcast(p0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  auto assignments =
      AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                        /*max_prefetch_interval=*/5,
                        /*min_prefetch_interval=*/2);

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 1);
}
6657
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
  // This test is for checking if the cross-program-prefetched buffer is freed
  // after its last use and there is an end-of-program prefetch.
  // p1 is only used by the dot at the top of the program; the long negate
  // chain afterwards leaves plenty of time where the buffer is dead.
  absl::string_view hlo_string = R"(
  HloModule cross_program_prefetch, is_scheduled=true

  ENTRY CrossProgramPrefetch {
    p0 = f32[8,8]{1,0} parameter(0)
    p1 = f32[8,2]{1,0} parameter(1)
    dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    negate.1 = f32[8,2]{1,0} negate(dot)
    negate.2 = f32[8,2]{1,0} negate(negate.1)
    negate.3 = f32[8,2]{1,0} negate(negate.2)
    negate.4 = f32[8,2]{1,0} negate(negate.3)
    negate.5 = f32[8,2]{1,0} negate(negate.4)
    negate.6 = f32[8,2]{1,0} negate(negate.5)
    negate.7 = f32[8,2]{1,0} negate(negate.6)
    negate.8 = f32[8,2]{1,0} negate(negate.7)
    ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto preset_assignments = AssignMemorySpace(
      module.get(), /*max_outstanding_async_copies=*/-1,
      /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Parameter 1 (top-level index) should be the cross-program prefetch.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 1);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
      HloDataflowAnalysis::Run(*module));
  const HloValue& cross_program_prefetched_value =
      dataflow_analysis->GetValueDefinedAt(
          module->entry_computation()->parameter_instruction(1), {});
  // Expect that there are two prefetches that use this value, one is the
  // cross-program prefetch, the other is the end-of-program prefetch.
  auto is_cross_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           use.instruction->is_cross_program_prefetch();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_cross_program_prefetch),
            1);
  auto is_end_of_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           !use.instruction->is_cross_program_prefetch();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_end_of_program_prefetch),
            1);
  // Also verify that the copy-done for the end-of-program prefetch is the last
  // instruction in schedule.
  const HloInstruction* last_instruction =
      module->schedule()
          .sequence(module->entry_computation())
          .instructions()[module->entry_computation()->instruction_count() - 1];
  EXPECT_THAT(last_instruction, op::CopyDone());
  EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
  // Cross program prefetch would use offset 0 because that's the first
  // assignment. Since we are freeing the cross-program prefetch buffer, we
  // would also expect to see some of the intermediate computations (one of the
  // negate ops) to also get 0 offset allocations.
  bool has_zero_offset_allocations = false;
  for (auto pos_and_chunk : preset_assignments->chunks()) {
    if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate &&
        pos_and_chunk.second.offset == 0) {
      has_zero_offset_allocations = true;
    }
  }
  EXPECT_TRUE(has_zero_offset_allocations);
}
6735
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) {
  // This test is for checking if the cross-program-prefetched buffer is freed
  // after its last use and there is an end-of-program prefetch.
  // Tuple-parameter variant of CrossProgramPrefetchNoReuse: tuple element {1}
  // is only used by the dot at the top of the program.
  absl::string_view hlo_string = R"(
  HloModule cross_program_prefetch, is_scheduled=true

  ENTRY CrossProgramPrefetch {
    p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
    get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
    get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
    dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    negate.1 = f32[8,2]{1,0} negate(dot)
    negate.2 = f32[8,2]{1,0} negate(negate.1)
    negate.3 = f32[8,2]{1,0} negate(negate.2)
    negate.4 = f32[8,2]{1,0} negate(negate.3)
    negate.5 = f32[8,2]{1,0} negate(negate.4)
    negate.6 = f32[8,2]{1,0} negate(negate.5)
    negate.7 = f32[8,2]{1,0} negate(negate.6)
    negate.8 = f32[8,2]{1,0} negate(negate.7)
    ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto preset_assignments = AssignMemorySpace(
      module.get(), /*max_outstanding_async_copies=*/-1,
      /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Parameter 0 at tuple index {1} should be the cross-program prefetch.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
      HloDataflowAnalysis::Run(*module));
  const HloValue& cross_program_prefetched_value =
      dataflow_analysis->GetValueDefinedAt(
          module->entry_computation()->parameter_instruction(0), {1});
  // Expect that there are two prefetches that use this value, one is the
  // cross-program prefetch, the other is the end-of-program prefetch.
  auto is_cross_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           use.instruction->is_cross_program_prefetch();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_cross_program_prefetch),
            1);
  auto is_end_of_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           !use.instruction->is_cross_program_prefetch();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_end_of_program_prefetch),
            1);
  // Also verify that the copy-done for the end-of-program prefetch is the last
  // instruction in schedule.
  const HloInstruction* last_instruction =
      module->schedule()
          .sequence(module->entry_computation())
          .instructions()[module->entry_computation()->instruction_count() - 1];
  EXPECT_THAT(last_instruction, op::CopyDone());
  EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
  // Cross program prefetch would use offset 0 because that's the first
  // assignment. Since we are freeing the cross-program prefetch buffer, we
  // would also expect to see some of the intermediate computations (one of the
  // negate ops) to also get 0 offset allocations.
  bool has_zero_offset_allocations = false;
  for (auto pos_and_chunk : preset_assignments->chunks()) {
    if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate &&
        pos_and_chunk.second.offset == 0) {
      has_zero_offset_allocations = true;
    }
  }
  EXPECT_TRUE(has_zero_offset_allocations);
}
6814
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) {
  // This tests the scenario that the cross-program-prefetched buffer is used
  // again close to the end of the computation. In this case, it is better not
  // to free the buffer.
  // p1 is used both by the first dot and by the root dot after the negate
  // chain, so the buffer stays live and no end-of-program prefetch is needed.
  absl::string_view hlo_string = R"(
  HloModule cross_program_prefetch, is_scheduled=true

  ENTRY CrossProgramPrefetch {
    p0 = f32[8,8]{1,0} parameter(0)
    p1 = f32[8,2]{1,0} parameter(1)
    dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    negate.1 = f32[8,2]{1,0} negate(dot)
    negate.2 = f32[8,2]{1,0} negate(negate.1)
    negate.3 = f32[8,2]{1,0} negate(negate.2)
    negate.4 = f32[8,2]{1,0} negate(negate.3)
    negate.5 = f32[8,2]{1,0} negate(negate.4)
    negate.6 = f32[8,2]{1,0} negate(negate.5)
    negate.7 = f32[8,2]{1,0} negate(negate.6)
    negate.8 = f32[8,2]{1,0} negate(negate.7)
    ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Parameter 1 (top-level index) should be the cross-program prefetch.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 1);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
      HloDataflowAnalysis::Run(*module));
  const HloValue& cross_program_prefetched_value =
      dataflow_analysis->GetValueDefinedAt(
          module->entry_computation()->parameter_instruction(1), {});
  // Expect that there is one prefetch that use this value, the cross-program
  // prefetch. There shouldn't be an end-of-program prefetch.
  auto is_cross_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           use.instruction->is_cross_program_prefetch();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_cross_program_prefetch),
            1);
  auto is_end_of_program_prefetch = [](const HloUse& use) {
    return use.instruction->opcode() == HloOpcode::kCopyStart &&
           !use.instruction->is_cross_program_prefetch();
  };
  EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
                             is_end_of_program_prefetch),
            0);
}
6873
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleReuse) {
  // Same scenario as CrossProgramPrefetchReuse, but the prefetched buffer
  // lives inside a tuple parameter: it is used again close to the end of the
  // computation, so it should be kept rather than freed (no end-of-program
  // prefetch).
  absl::string_view hlo_string = R"(
  HloModule cross_program_prefetch, is_scheduled=true

  ENTRY CrossProgramPrefetch {
    p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
    get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
    get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
    dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    negate.1 = f32[8,2]{1,0} negate(dot)
    negate.2 = f32[8,2]{1,0} negate(negate.1)
    negate.3 = f32[8,2]{1,0} negate(negate.2)
    negate.4 = f32[8,2]{1,0} negate(negate.3)
    negate.5 = f32[8,2]{1,0} negate(negate.4)
    negate.6 = f32[8,2]{1,0} negate(negate.5)
    negate.7 = f32[8,2]{1,0} negate(negate.6)
    negate.8 = f32[8,2]{1,0} negate(negate.7)
    ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Tuple element {1} of parameter 0 is expected to be the single
  // cross-program-prefetched buffer.
  auto prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(prefetches.size(), 1);
  if (!prefetches.empty()) {
    EXPECT_EQ(prefetches[0].first, 0);
    EXPECT_EQ(prefetches[0].second, ShapeIndex({1}));
  }

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                          HloDataflowAnalysis::Run(*module));
  const HloValue& prefetched_value = dataflow->GetValueDefinedAt(
      module->entry_computation()->parameter_instruction(0), {1});
  // Tally the copy-start uses of the value: exactly one cross-program
  // prefetch should use it, and there should be no end-of-program prefetch.
  int num_cross_program_prefetches = 0;
  int num_end_of_program_prefetches = 0;
  for (const HloUse& use : prefetched_value.GetUses()) {
    if (use.instruction->opcode() != HloOpcode::kCopyStart) {
      continue;
    }
    if (use.instruction->is_cross_program_prefetch()) {
      ++num_cross_program_prefetches;
    } else {
      ++num_end_of_program_prefetches;
    }
  }
  EXPECT_EQ(num_cross_program_prefetches, 1);
  EXPECT_EQ(num_end_of_program_prefetches, 0);
}
6933
6934 using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
6935
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY Entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    f = f32[2,4] negate(e)
    g = f32[2,4] negate(f)
    h = f32[2,4] negate(g)
    i = f32[2,4] negate(h)
    j = f32[2,4] negate(i)
    k = f32[2,4] negate(j)
    l = f32[2,4] negate(k)
    m = f32[2,4] negate(l)
    n = f32[2,4] negate(m)
    o = f32[2,4] negate(n)
    p = f32[2,4] negate(o)
    q = f32[2,4] negate(p)
    r = f32[2,4] negate(q)
    s = f32[2,4] negate(r)
    t = f32[2,4] negate(s)
    u = f32[2,4] negate(t)
    ROOT v = f32[2,4] add(u, param0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloCostAnalysis hlo_cost_analysis(ShapeSize);
  Options options;
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              hlo_cost_analysis, *module, options));
  CostAnalysisPrefetchIntervalPicker interval_picker(
      *cost_analysis,
      /*min_overlap_to_async_copy_ratio=*/1.0,
      /*preferred_overlap_to_async_copy_ratio=*/2.0,
      /*max_overlap_to_mem_size_async_copy_ratio=*/4.0,
      /*mem_size_bytes=*/32);

  HloInstruction* root = module->entry_computation()->root_instruction();
  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
  interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22);

  // The first interval should be (15, 22): its elapsed time of 6.0 is twice
  // the async copy elapsed time (3.0). From there, intervals are visited in
  // alternating increasing/decreasing order until the min and max async copy
  // overlap ratios are hit, at intervals (18, 22) and (9, 22) respectively.
  const int64_t expected_order[] = {
      15, 16, 14, 17, 13,
      18,  // Min async overlap ratio reached.
      12, 11, 10,
      9,  // Max async overlap ratio reached.
  };
  for (int64_t expected_start : expected_order) {
    LOG(INFO) << interval_picker.ToDebugString();
    EXPECT_EQ(interval_picker.Next(), expected_start);
  }
  LOG(INFO) << interval_picker.ToDebugString();
  EXPECT_TRUE(interval_picker.Done());

  // With too little time between start_time and end_time, no intervals should
  // be available at all.
  interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22);
  LOG(INFO) << interval_picker.ToDebugString();
  EXPECT_TRUE(interval_picker.Done());
}
7019
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) {
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  while_condition {
    param1 = (f32[2,4]) parameter(0) // 19
    ROOT cond = pred[] constant(true) // 20
  }

  while_body {
    param2 = (f32[2,4]) parameter(0) // 21
    gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22
    add = f32[2,4] add(gte2, gte2) // 23
    ROOT tuple2 = (f32[2,4]) tuple(add) // 24
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0) // 0
    a = f32[2,4] negate(param0) // 1
    b = f32[2,4] negate(a) // 2
    c = f32[2,4] negate(b) // 3
    d = f32[2,4] negate(c) // 4
    e = f32[2,4] negate(d) // 5
    f = f32[2,4] negate(e) // 6
    g = f32[2,4] negate(f) // 7
    h = f32[2,4] negate(g) // 8
    i = f32[2,4] negate(h) // 9
    j = f32[2,4] negate(i) // 10
    k = f32[2,4] negate(j) // 11
    l = f32[2,4] negate(k) // 12
    m = f32[2,4] negate(l) // 13
    n = f32[2,4] negate(m) // 14
    o = f32[2,4] negate(n) // 15
    p = f32[2,4] negate(o) // 16
    q = f32[2,4] negate(p) // 17
    tuple = (f32[2,4]) tuple(q) // 18
    while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25
    gte1 = f32[2,4] get-tuple-element(while), index=0 // 26
    r = f32[2,4] negate(gte1) // 27
    s = f32[2,4] negate(r) // 28
    t = f32[2,4] negate(s) // 29
    u = f32[2,4] negate(t) // 30
    ROOT v = f32[2,4] add(u, param0) // 31
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloCostAnalysis hlo_cost_analysis(ShapeSize);
  Options options;
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              hlo_cost_analysis, *module, options));
  CostAnalysisPrefetchIntervalPicker interval_picker(
      *cost_analysis,
      /*min_overlap_to_async_copy_ratio=*/1.0,
      /*preferred_overlap_to_async_copy_ratio=*/2.0,
      /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
      /*mem_size_bytes=*/32);

  // The cost analysis is configured to assume 5 executions of the while loop.
  EXPECT_EQ(cost_analysis->options()
                .xla_tpu_memory_space_assignment_while_execution_count,
            5);
  HloInstruction* root = module->entry_computation()->root_instruction();
  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
  interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31);

  // The while loop computations occupy logical times [19, 24]; the picker
  // must never propose a start time inside that interval.
  const int64_t expected_order[] = {
      25, 26, 18,
      27,  // Min async overlap ratio reached.
      17,  // Max async overlap ratio reached.
  };
  for (int64_t expected_start : expected_order) {
    LOG(INFO) << interval_picker.ToDebugString();
    EXPECT_EQ(interval_picker.Next(), expected_start);
  }
  LOG(INFO) << interval_picker.ToDebugString();
  EXPECT_TRUE(interval_picker.Done());
}
7102
TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) {
  // Regression test for a bug where while instructions were not assigned a
  // while_nest_level_ and defaulted to 0, which made the prefetch interval
  // logic treat a nested while instruction as if it were at the same nesting
  // level as the outermost computation.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  while_condition.2 {
    param1 = (f32[2,4]) parameter(0) // 11
    ROOT cond = pred[] constant(true) // 12
  }

  while_body.2 {
    param2 = (f32[2,4]) parameter(0) // 13
    gte2 = f32[2,4] get-tuple-element(param2), index=0 // 14
    add = f32[2,4] add(gte2, gte2) // 15
    ROOT tuple2 = (f32[2,4]) tuple(add) // 16
  }

  while_condition.1 {
    param3 = (f32[2,4]) parameter(0) // 5
    ROOT cond = pred[] constant(true) // 6
  }

  while_body.1 {
    param4 = (f32[2,4]) parameter(0) // 7
    gte1 = f32[2,4] get-tuple-element(param4), index=0 // 8
    add1 = f32[2,4] add(gte1, gte1) // 9
    tuple1 = (f32[2,4]) tuple(add1) // 10
    while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2 // 17
    gte2 = f32[2,4] get-tuple-element(while), index=0 // 18
    add2 = f32[2,4] add(gte2, gte2) // 19
    ROOT tuple2 = (f32[2,4]) tuple(add2) // 20
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0) // 0
    a = f32[2,4] negate(param0) // 1
    b = f32[2,4] negate(a) // 2
    c = f32[2,4] negate(b) // 3
    tuple = (f32[2,4]) tuple(c) // 4
    while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1 // 21
    gte1 = f32[2,4] get-tuple-element(while), index=0 // 22
    ROOT root = f32[2,4] add(gte1, param0) // 23
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloCostAnalysis hlo_cost_analysis(ShapeSize);
  Options options;
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              hlo_cost_analysis, *module, options));
  CostAnalysisPrefetchIntervalPicker interval_picker(
      *cost_analysis,
      /*min_overlap_to_async_copy_ratio=*/1.0,
      /*preferred_overlap_to_async_copy_ratio=*/2.0,
      /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
      /*mem_size_bytes=*/32);

  HloInstruction* root = module->entry_computation()->root_instruction();
  const Shape& operand_shape = root->operand(1)->shape();
  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};

  // The latest prefetch start time for the root's operand must precede the
  // outer while loop (logical time 4).
  EXPECT_EQ(
      interval_picker.LatestPrefetchStartTime(operand_shape, /*start_time=*/0,
                                              /*end_time=*/23, &use),
      4);
}
7175
TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) {
  // Regression test for b/170668492: when two conditionals are scheduled
  // back-to-back, the prefetch for the second conditional's operand could be
  // placed inside the first conditional's called computations.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  true_computation.0 {
    p0 = (f32[3]{0}) parameter(0) // 5
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 6
    ROOT neg1 = f32[3]{0} negate(gte) // 7
  }

  false_computation.0 {
    p0 = (f32[3]{0}) parameter(0) // 8
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 9
    ROOT neg2 = f32[3]{0} negate(gte) // 10
  }

  true_computation.1 {
    p0 = (f32[3]{0}) parameter(0) // 12
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 13
    ROOT neg1 = f32[3]{0} negate(gte) // 14
  }

  false_computation.1 {
    p0 = (f32[3]{0}) parameter(0) // 15
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 16
    ROOT neg2 = f32[3]{0} negate(gte) // 17
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0) // 0
    p1 = f32[3]{0} parameter(1) // 1
    p2 = pred[] parameter(2) // 2
    tuple0 = (f32[3]{0}) tuple(p0) // 3
    tuple1 = (f32[3]{0}) tuple(p1) // 4
    conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11
    conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18
    ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloCostAnalysis hlo_cost_analysis(ShapeSize);
  Options options;
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              hlo_cost_analysis, *module, options));
  CostAnalysisPrefetchIntervalPicker interval_picker(
      *cost_analysis,
      /*min_overlap_to_async_copy_ratio=*/1.0,
      /*preferred_overlap_to_async_copy_ratio=*/2.0,
      /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
      /*mem_size_bytes=*/32);

  LOG(INFO) << module->ToString();

  HloInstruction* cond1 =
      module->entry_computation()->GetInstructionWithName("conditional1");
  const HloUse use{cond1, /*operand_number=*/1, /*operand_index=*/{0}};
  const Shape& param_shape =
      module->entry_computation()->parameter_instruction(0)->shape();

  // The latest prefetch start time must come before conditional0's called
  // computations, which begin at logical time 5.
  EXPECT_LT(
      interval_picker.LatestPrefetchStartTime(param_shape, /*start_time=*/0,
                                              /*end_time=*/11, &use),
      5);
}
7247
TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) {
  // A long-running op (tanh here, via the elapsed-time override below) can
  // swallow both the earliest and latest prefetch start times. Even then the
  // picker should still return one valid prefetch interval just before the
  // long-running op.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY Entry {
    param0 = f32[2,4] parameter(0)
    negate = f32[2,4] negate(param0)
    tanh = f32[2,4] tanh(param0)
    ROOT add = f32[2,4] add(tanh, negate)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloCostAnalysis hlo_cost_analysis(ShapeSize);
  Options options;
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              hlo_cost_analysis, *module, options));
  // Make tanh 20x more expensive than every other instruction.
  cost_analysis->SetOverrideForGetInstructionElapsed(
      [](const HloInstruction& hlo) {
        return hlo.opcode() == HloOpcode::kTanh ? 20.0 : 1.0;
      });
  CostAnalysisPrefetchIntervalPicker interval_picker(
      *cost_analysis,
      /*min_overlap_to_async_copy_ratio=*/1.0,
      /*preferred_overlap_to_async_copy_ratio=*/2.0,
      /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
      /*mem_size_bytes=*/32);

  HloInstruction* root = module->entry_computation()->root_instruction();
  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
  interval_picker.Begin(use, /*start_time=*/1, /*end_time=*/3);

  // Exactly one interval (start time 1) should be offered.
  LOG(INFO) << interval_picker.ToDebugString();
  EXPECT_FALSE(interval_picker.Done());
  EXPECT_EQ(interval_picker.Next(), 1);
  EXPECT_TRUE(interval_picker.Done());
}
7294
7295 } // namespace
7296 } // namespace xla
7297