1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/instruction_hoister.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 
23 namespace xla {
24 namespace {
25 
26 namespace op = xla::testing::opcode_matchers;
27 using memory_space_assignment::AsynchronousCopy;
28 using memory_space_assignment::AsynchronousCopyResource;
29 using memory_space_assignment::CostAnalysisPrefetchIntervalPicker;
30 using memory_space_assignment::InstructionCountPrefetchIntervalPicker;
31 using memory_space_assignment::MemorySpaceAssignment;
32 using memory_space_assignment::MemorySpaceAssignmentCostAnalysis;
33 using memory_space_assignment::Options;
34 using memory_space_assignment::PrefetchIntervalPicker;
35 using memory_space_assignment::PresetAssignments;
36 
37 constexpr int64_t kPointerSize = 8;
38 constexpr float kAsyncCopyBandwidth = 100;
39 constexpr float kAlternateMemBandwidth = 1000;
40 constexpr float kBytesPerSecond = 100;
41 constexpr float kFlopsPerSecond = 1000;
42 constexpr float kTranscendentalsPerSecond = 10;
43 
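// Returns the size of the given shape in bytes, using kPointerSize for any
// pointers (e.g. tuple elements).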
44 int64_t ShapeSize(const Shape& shape) {
45   return ShapeUtil::ByteSizeOf(shape, kPointerSize);
46 }
47 
48 class MemorySpaceAssignmentTest : public HloTestBase,
49                                   public ::testing::WithParamInterface<bool> {
50  protected:
51   // We use the following two memory space values to describe the default (slow
52   // and large) and alternate (fast and small) memory spaces.
53   const int64_t kDefaultMemorySpace = 0;
54   const int64_t kAlternateMemorySpace = 1;
55 
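  // Runs memory space assignment on the module using a cost-analysis-driven
  // prefetch interval picker and a memory-boundedness-based buffer ordering.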
56   std::unique_ptr<PresetAssignments> AssignMemorySpaceUsingCostAnalysis(
57       HloModule* module,
58       std::optional<Options> memory_space_assignment_options = std::nullopt) {
59     HloCostAnalysis::Options cost_options{ShapeSize};
60     cost_options.set_flops_per_second(kFlopsPerSecond);
61     cost_options.set_bytes_per_second(kBytesPerSecond);
62     cost_options.set_transcendentals_per_second(kTranscendentalsPerSecond);
63     HloCostAnalysis hlo_cost_analysis(cost_options);
64     for (HloComputation* computation : module->MakeNonfusionComputations()) {
65       TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
66     }
67     auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
68 
69     Options options;
70     if (memory_space_assignment_options.has_value()) {
71       options = *memory_space_assignment_options;
72     } else {
73       options.async_copy_bandwidth_bytes_per_second = kAsyncCopyBandwidth;
74       options.alternate_mem_bandwidth_bytes_per_second = kAlternateMemBandwidth;
75     }
76     auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create(
77                              hlo_cost_analysis, options, *module)
78                              .ValueOrDie();
79     CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
80         CostAnalysisPrefetchIntervalPicker(
81             *cost_analysis, /*min_overlap_to_async_copy_ratio=*/0.8,
82             /*preferred_overlap_to_async_copy_ratio=*/1.5,
83             /*max_overlap_to_mem_size_async_copy_ratio=*/10.0,
84             /*mem_size_bytes=*/128));
85     return AssignMemorySpace(
86         module, /*max_outstanding_async_copies=*/-1,
87         MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
88             *cost_analysis, &cache_),
89         &prefetch_interval_picker, memory_space_assignment_options);
90   }
91 
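  // Convenience overload that hoists instructions first and uses an
  // instruction-count-based prefetch interval picker.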
92   std::unique_ptr<PresetAssignments> AssignMemorySpace(
93       HloModule* module, int64_t max_outstanding_async_copies = -1,
94       int64_t max_prefetch_interval = 10, int64_t min_prefetch_interval = 2,
95       std::optional<Options> options = std::nullopt) {
96     InstructionHoister instruction_hoister;
97     TF_CHECK_OK(instruction_hoister.Run(module).status());
98     InstructionCountPrefetchIntervalPicker prefetch_interval_picker(
99         min_prefetch_interval, max_prefetch_interval);
100     return AssignMemorySpace(module, max_outstanding_async_copies,
101                              /*buffer_interval_compare=*/{},
102                              &prefetch_interval_picker, options);
103   }
104 
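  // Runs memory space assignment with the given prefetch interval picker and
  // buffer interval comparator, then verifies parameter, root, and preset
  // assignment placements.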
105   std::unique_ptr<PresetAssignments> AssignMemorySpace(
106       HloModule* module, int64_t max_outstanding_async_copies,
107       std::optional<MemorySpaceAssignment::BufferIntervalCompare>
108           buffer_interval_compare,
109       PrefetchIntervalPicker* prefetch_interval_picker,
110       std::optional<Options> memory_space_assignment_options = std::nullopt) {
111     auto size_fn = [](const BufferValue& buffer) {
112       return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
113     };
114 
115     auto is_allowed_in_alternate_mem = [](const HloValue& value) {
116       // Disallow entry computation parameters from being placed in the
117       // alternate memory.
117       HloInstruction* instruction = value.instruction();
118       HloComputation* computation = instruction->parent();
119       bool in_entry_computation =
120           (computation == computation->parent()->entry_computation());
121       if (in_entry_computation &&
122           instruction->opcode() == HloOpcode::kParameter) {
123         return false;
124       }
125       return true;
126     };
127 
128     // Only check parameters in default memory if the original module didn't
129     // have the parameters in alternate memory.
130     bool check_parameters_in_default_memory = true;
131     for (const HloInstruction* parameter :
132          module->entry_computation()->parameter_instructions()) {
133       ShapeUtil::ForEachSubshape(
134           parameter->shape(),
135           [&](const Shape& subshape, const ShapeIndex& /*index*/) {
136             if (subshape.has_layout() &&
137                 subshape.layout().memory_space() == kAlternateMemorySpace) {
138               check_parameters_in_default_memory = false;
139             }
140           });
141     }
142 
143     Options options;
144     if (memory_space_assignment_options) {
145       options = *memory_space_assignment_options;
146     } else {
147       options.max_size_in_bytes = 128;
148       options.alignment_in_bytes = 8;
149       options.verify = true;
150     }
151 
152     options.alternate_memory_space = kAlternateMemorySpace;
153     options.buffer_interval_compare = buffer_interval_compare;
154     options.prefetch_interval_picker = prefetch_interval_picker;
155     options.size_fn = size_fn;
156     if (options.is_allowed_in_alternate_mem_fn == nullptr) {
157       options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
158     }
159     options.max_outstanding_prefetches = max_outstanding_async_copies;
160     options.max_outstanding_evictions = max_outstanding_async_copies;
161     options.allocate_across_sequential_calls = GetParam();
162 
163     auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
164     std::unique_ptr<HloLiveRange> hlo_live_range =
165         HloLiveRange::Run(module->schedule(), *alias_analysis,
166                           module->entry_computation())
167             .ValueOrDie();
168 
169     std::unique_ptr<PresetAssignments> preset_assignments =
170         MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
171                                    options)
172             .ValueOrDie();
173     if (check_parameters_in_default_memory) {
174       CheckParametersInDefaultMemory(module);
175     }
176     CheckRootInDefaultMemory(module);
177     CheckPresetAssignments(preset_assignments.get());
178     return preset_assignments;
179   }
180 
181   void CheckPresetAssignments(const PresetAssignments* preset_assignments) {
182     // Ensure that the exported preset assignments point to layouts in the
183     // alternate memory.  Also ensure that the positions are unique. Note that
184     // we're using a std::set instead of absl::flat_hash_set because we can make
185     // use of HloPosition's comparator logic instead of providing a hasher.
186     std::set<HloPosition> positions_in_preset_assignments;
187     for (auto& position_and_chunk : preset_assignments->chunks()) {
188       HloPosition position = position_and_chunk.first;
189       EXPECT_EQ(positions_in_preset_assignments.find(position),
190                 positions_in_preset_assignments.end());
191       positions_in_preset_assignments.insert(position);
192       const Shape& subshape =
193           ShapeUtil::GetSubshape(position.instruction->shape(), position.index);
194       EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace)
195           << "Exported position is not in alternate mem: "
196           << position.ToString();
197     }
198   }
199 
200   void CheckParametersInDefaultMemory(const HloModule* module) {
201     // Check that all the entry parameter subshapes are placed in default
202     // memory.
203     const HloComputation* entry_computation = module->entry_computation();
204     for (const HloInstruction* parameter :
205          entry_computation->parameter_instructions()) {
206       ShapeUtil::ForEachSubshape(
207           parameter->shape(),
208           [&](const Shape& subshape, const ShapeIndex& /*index*/) {
209             if (subshape.has_layout()) {
210               EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace)
211                   << "Parameter not in default memory: "
212                   << parameter->ToString();
213             }
214           });
215     }
216   }
217 
218   void CheckRootInDefaultMemory(const HloModule* module) {
219     const HloInstruction* root =
220         module->entry_computation()->root_instruction();
221     if (root->shape().IsArray()) {
222       EXPECT_EQ(root->shape().layout().memory_space(), kDefaultMemorySpace);
223     }
224   }
225 
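  // Maximum number of asynchronous copies outstanding at the same time,
  // broken down into prefetches and evictions.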
226   struct OutstandingAsyncCopies {
227     int64_t max_copies;
228     int64_t max_prefetches;
229     int64_t max_evictions;
230   };
231 
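  // Walks the entry computation's schedule and tracks in-flight
  // CopyStart/CopyDone pairs, classifying each copy as a prefetch or an
  // eviction based on its destination memory space.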
232   /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies(
233       const HloModule& module) {
234     OutstandingAsyncCopies copies{0, 0, 0};
235     int64_t current_copies = 0;
236     int64_t current_prefetches = 0;
237     int64_t current_evictions = 0;
238     for (HloInstruction* instruction : module.schedule()
239                                            .sequence(module.entry_computation())
240                                            .instructions()) {
241       if (instruction->opcode() == HloOpcode::kCopyStart) {
242         current_copies++;
243         if (ShapeUtil::GetSubshape(instruction->shape(), {0})
244                 .layout()
245                 .memory_space() == kAlternateMemorySpace) {
246           current_prefetches++;
247         } else {
248           current_evictions++;
249         }
250       } else if (instruction->opcode() == HloOpcode::kCopyDone) {
251         current_copies--;
252         if (instruction->shape().layout().memory_space() ==
253             kAlternateMemorySpace) {
254           current_prefetches--;
255         } else {
256           current_evictions--;
257         }
258       }
259       copies.max_copies = std::max(copies.max_copies, current_copies);
260       copies.max_prefetches =
261           std::max(copies.max_prefetches, current_prefetches);
262       copies.max_evictions = std::max(copies.max_evictions, current_evictions);
263     }
264     return copies;
265   }
266 
267   int64_t GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
268                                    const HloInstruction* instruction,
269                                    const ShapeIndex& index = {}) const {
270     // Returns the offset of the assignment, -1 if it's not in the alternate
271     // memory.
272     const HloModule* module = instruction->parent()->parent();
273     auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
274     HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
275     for (auto& pos_and_chunk : preset_assignments.chunks()) {
276       for (auto& value : buffer.values()) {
277         if (pos_and_chunk.first == value->defining_position()) {
278           return pos_and_chunk.second.offset;
279         }
280       }
281     }
282     return -1;
283   }
284 
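  // Builds a module where tanh is used both early and at the root; the
  // intervening contention is meant to force tanh to be evicted to default
  // memory and prefetched back for its last use.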
285   std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
286     HloComputation::Builder builder(TestName());
287     Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
288     HloInstruction* p0 =
289         builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
290     HloInstruction* p1 =
291         builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
292     HloInstruction* tanh = builder.AddInstruction(
293         HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
294     // tanh should be placed in the alternate memory since there isn't much
295     // contention in the beginning. However, tanh has another consumer at the
296     // end. So it should be kicked out to default memory and prefetched back in.
297     // The graph below is meant to increase the contention to force
298     // eviction/prefetch behavior.
299     HloInstruction* a = builder.AddInstruction(
300         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
301     HloInstruction* b = builder.AddInstruction(
302         HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
303     HloInstruction* c = builder.AddInstruction(
304         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
305     HloInstruction* d = builder.AddInstruction(
306         HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
307     HloInstruction* e = builder.AddInstruction(
308         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
309     HloInstruction* f = builder.AddInstruction(
310         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
311     HloInstruction* g = builder.AddInstruction(
312         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
313     HloInstruction* h = builder.AddInstruction(
314         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
315     HloInstruction* i = builder.AddInstruction(
316         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
317     HloInstruction* j = builder.AddInstruction(
318         HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
319     HloInstruction* k = builder.AddInstruction(
320         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
321     HloInstruction* l = builder.AddInstruction(
322         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
323     HloInstruction* m = builder.AddInstruction(
324         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
325     HloInstruction* n = builder.AddInstruction(
326         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
327     HloInstruction* o = builder.AddInstruction(
328         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
329     // tanh is being used at the root instruction, and this should be
330     // prefetched.
331     HloInstruction* add = builder.AddInstruction(
332         HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
333 
334     auto module = CreateNewVerifiedModule();
335     HloComputation* computation = module->AddEntryComputation(builder.Build());
336 
337     HloSchedule schedule(module.get());
338     schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
339                                         j, k, l, m, n, o, add});
340     TF_CHECK_OK(module->set_schedule(schedule));
341     return module;
342   }
343 
344   MemorySpaceAssignmentCostAnalysis::Cache cache_;
345 };
346 
347 // For testing purposes, we define a cost analysis where we can control the
348 // elapsed times of each HLO and asynchronous copy.
349 class FakeMemorySpaceAssignmentCostAnalysis
350     : public MemorySpaceAssignmentCostAnalysis {
351  public:
352   static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
353   Create(const HloCostAnalysis& cost_analysis, const HloModule& module,
354          const Options& options) {
355     TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
356     TF_ASSIGN_OR_RETURN(auto hlo_live_range,
357                         HloLiveRange::Run(module.schedule(), *alias_analysis,
358                                           module.entry_computation()));
359     auto call_graph = CallGraph::Build(&module);
360     return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
361         cost_analysis, options, std::move(alias_analysis),
362         std::move(hlo_live_range), std::move(call_graph)));
363   }
364 
365   float GetInstructionElapsed(
366       const HloInstruction& instruction) const override {
367     if (get_instruction_elapsed_override_) {
368       return get_instruction_elapsed_override_(instruction);
369     }
370     return 1.0;
371   }
372 
373   float GetInstructionElapsedInAlternateMemory(
374       const HloInstruction& instruction,
375       absl::Span<const std::pair<int64_t, ShapeIndex>>
376           operands_in_alternate_mem,
377       absl::Span<const ShapeIndex> outputs_in_alternate_mem) const override {
378     if (get_instruction_elapsed_in_alternate_memory_override_) {
379       return get_instruction_elapsed_in_alternate_memory_override_(
380           instruction, operands_in_alternate_mem, outputs_in_alternate_mem);
381     }
382     if (!operands_in_alternate_mem.empty()) {
383       return 0.5;
384     } else {
385       return 1.0;
386     }
387   }
388 
389   float GetAsyncCopyElapsed(const Shape& shape) const override {
390     if (get_async_copy_elapsed_override_) {
391       return get_async_copy_elapsed_override_(shape);
392     }
393     return 3.0;
394   }
395 
396   // The following methods can be used to override what the above API calls
397   // return.
398   void SetOverrideForGetInstructionElapsed(
399       std::function<float(const HloInstruction&)> function) {
400     get_instruction_elapsed_override_ = function;
401   }
402   void SetOverrideForGetInstructionElapsedInAlternateMemory(
403       std::function<float(const HloInstruction&,
404                           absl::Span<const std::pair<int64_t, ShapeIndex>>,
405                           absl::Span<const ShapeIndex>)>
406           function) {
407     get_instruction_elapsed_in_alternate_memory_override_ = function;
408   }
409   void SetOverrideForGetAsyncCopyElapsed(
410       std::function<float(const Shape&)> function) {
411     get_async_copy_elapsed_override_ = function;
412   }
413 
414  protected:
415   FakeMemorySpaceAssignmentCostAnalysis(
416       const HloCostAnalysis& cost_analysis, const Options& options,
417       std::unique_ptr<HloAliasAnalysis> alias_analysis,
418       std::unique_ptr<HloLiveRange> hlo_live_range,
419       std::unique_ptr<CallGraph> call_graph)
420       : MemorySpaceAssignmentCostAnalysis(
421             cost_analysis, options, std::move(alias_analysis),
422             std::move(hlo_live_range), std::move(call_graph)) {}
423 
424  private:
425   std::function<float(const HloInstruction&)>
426       get_instruction_elapsed_override_ = nullptr;
427   std::function<float(const HloInstruction&,
428                       absl::Span<const std::pair<int64_t, ShapeIndex>>,
429                       absl::Span<const ShapeIndex>)>
430       get_instruction_elapsed_in_alternate_memory_override_ = nullptr;
431   std::function<float(const Shape&)> get_async_copy_elapsed_override_ = nullptr;
432 };
433 
434 TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
435   // A module consisting of a single parameter. Inputs/outputs are currently
436   // excluded from memory space assignment.
437   HloComputation::Builder builder(TestName());
438   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
439   HloInstruction* p0 =
440       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
441 
442   auto module = CreateNewVerifiedModule();
443   HloComputation* computation = module->AddEntryComputation(builder.Build());
444 
445   HloSchedule schedule(module.get());
446   schedule.set_sequence(computation, {p0});
447   TF_CHECK_OK(module->set_schedule(schedule));
448 
449   AssignMemorySpace(module.get());
450 
451   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
452 }
453 
454 TEST_P(MemorySpaceAssignmentTest, Simple) {
455   // A simple module with a few simple instructions. Expect this to be
456   // transformed with CopyStart and CopyDone instructions inserted after inputs
457   // and before outputs.
458   HloComputation::Builder builder(TestName());
459   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
460   HloInstruction* p0 =
461       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
462   HloInstruction* p1 =
463       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
464   HloInstruction* add = builder.AddInstruction(
465       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1));
466   HloInstruction* sub = builder.AddInstruction(
467       HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
468   HloInstruction* mul = builder.AddInstruction(
469       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, sub));
470 
471   auto module = CreateNewVerifiedModule();
472   HloComputation* computation = module->AddEntryComputation(builder.Build());
473 
474   HloSchedule schedule(module.get());
475   schedule.set_sequence(computation, {p0, p1, add, sub, mul});
476   TF_CHECK_OK(module->set_schedule(schedule));
477 
478   auto preset_assignments = AssignMemorySpace(module.get());
479 
480   // Inputs and outputs are currently placed in the default memory. Everything
481   // else should be in the alternate memory.
482   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
483       F32, {2, 3},
484       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{},
485       /*tiles=*/{},
486       /*element_size_in_bits=*/0, kAlternateMemorySpace);
487   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
488   EXPECT_THAT(p1, op::ShapeWithLayout(shape));
489   EXPECT_THAT(mul, op::ShapeWithLayout(shape));
490   EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
491   EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem));
492 
493   // Make sure the preset assignments are sane.
494   EXPECT_EQ(preset_assignments->chunks().size(), 3);
495   EXPECT_EQ(preset_assignments->assignment_informations().size(), 1);
496   // Ensure the offset assigned to add and sub are different.
497   EXPECT_NE(preset_assignments->chunks()[0].second.offset,
498             preset_assignments->chunks()[1].second.offset);
499 }
500 
501 TEST_P(MemorySpaceAssignmentTest, NegateChain) {
502   // The negate chain is long enough for an asynchronous copy to be inserted
503   // between p1 and add.
504   HloComputation::Builder builder(TestName());
505   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
506   HloInstruction* p0 =
507       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
508   HloInstruction* p1 =
509       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
510   HloInstruction* negate0 = builder.AddInstruction(
511       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
512   HloInstruction* negate1 = builder.AddInstruction(
513       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
514   HloInstruction* negate2 = builder.AddInstruction(
515       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
516   HloInstruction* negate3 = builder.AddInstruction(
517       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
518   HloInstruction* negate4 = builder.AddInstruction(
519       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
520   HloInstruction* negate5 = builder.AddInstruction(
521       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
522   HloInstruction* negate6 = builder.AddInstruction(
523       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
524   HloInstruction* add = builder.AddInstruction(
525       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
526 
527   auto module = CreateNewVerifiedModule();
528   HloComputation* computation = module->AddEntryComputation(builder.Build());
529 
530   HloSchedule schedule(module.get());
531   schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
532                                       negate3, negate4, negate5, negate6, add});
533   TF_CHECK_OK(module->set_schedule(schedule));
534 
535   AssignMemorySpace(module.get());
536 
537   EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
538                                                        kDefaultMemorySpace,
539                                                        op::Parameter(1))));
540   // Parameters are in the default memory space.
541   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
542   EXPECT_THAT(p1, op::ShapeWithLayout(shape));
543   // Negate instructions are in the alternate memory space (1).
544   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
545       F32, {2, 3},
546       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{},
547       /*tiles=*/{},
548       /*element_size_in_bits=*/0, kAlternateMemorySpace);
549   EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
550   EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
551   EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
552   EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
553   EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
554   EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
555   EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
556   // Ensure the CopyStart/CopyDone are scheduled at the expected positions.
557   const HloInstructionSequence& sequence =
558       module->schedule().sequence(computation);
559   EXPECT_THAT(sequence.instructions()[0], op::Parameter(0));
560   EXPECT_THAT(sequence.instructions()[1], op::Parameter(1));
561   EXPECT_THAT(sequence.instructions()[2], op::CopyStart());
562   EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
563 }
564 
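// Expects tanh to be evicted to default memory and prefetched back into the
// alternate memory for its use at the root.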
565 TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) {
566   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
567 
568   AssignMemorySpace(module.get());
569 
570   EXPECT_THAT(
571       module->entry_computation()->root_instruction(),
572       op::Add(op::Add(),
573               op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
574                             op::AsyncCopy(kDefaultMemorySpace,
575                                           kAlternateMemorySpace, op::Tanh()))));
576 }
577 
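// The following tests verify that the number of outstanding prefetches and
// evictions never exceeds the requested limit.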
578 TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
579   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
580 
581   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0);
582 
583   EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 0);
584   EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 0);
585 }
586 
587 TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
588   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
589 
590   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);
591 
592   EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 1);
593   EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 1);
594 }
595 
596 TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) {
597   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
598 
599   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2);
600 
601   EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 2);
602   EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 2);
603 }
604 
605 // TODO(berkin): This test is broken with some prefetch timing improvements.
606 TEST_P(MemorySpaceAssignmentTest,
607        DISABLED_DontEvictWhenThereIsDefaultMemAllocation) {
608   // This test is the same as EvictAndPrefetchLimitAsyncCopies1, except we check
609   // that there is no eviction if not necessary (due to an existing allocation
610   // in default memory).
611   HloComputation::Builder builder(TestName());
612   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
613   HloInstruction* p0 =
614       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
615   HloInstruction* p1 =
616       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
617   HloInstruction* tanh = builder.AddInstruction(
618       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
619   // tanh should be placed in the alternate memory since there isn't much
620   // contention in the beginning. However, tanh has another consumer at the end.
621   // So it should be kicked out to default memory and prefetched back in.  The
622   // graph below is meant to increase the contention to force eviction/prefetch
623   // behavior.
624   HloInstruction* a = builder.AddInstruction(
625       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
626   HloInstruction* b = builder.AddInstruction(
627       HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
628   HloInstruction* c = builder.AddInstruction(
629       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
630   HloInstruction* d = builder.AddInstruction(
631       HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
632   HloInstruction* e = builder.AddInstruction(
633       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
634   HloInstruction* f = builder.AddInstruction(
635       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
636   HloInstruction* g = builder.AddInstruction(
637       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
638   HloInstruction* h = builder.AddInstruction(
639       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
640   HloInstruction* i = builder.AddInstruction(
641       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
642   HloInstruction* j = builder.AddInstruction(
643       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
644   HloInstruction* k = builder.AddInstruction(
645       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
646   HloInstruction* l = builder.AddInstruction(
647       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
648   HloInstruction* m = builder.AddInstruction(
649       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
650   HloInstruction* n = builder.AddInstruction(
651       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
652   HloInstruction* o = builder.AddInstruction(
653       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
654   // tanh is being used at the root instruction, and this should be
655   // prefetched.
656   HloInstruction* add = builder.AddInstruction(
657       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
658 
659   auto module = CreateNewVerifiedModule();
660   HloComputation* computation = module->AddEntryComputation(builder.Build());
661 
662   HloSchedule schedule(module.get());
663   schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
664                                       j, k, l, m, n, o, add});
665   TF_CHECK_OK(module->set_schedule(schedule));
666 
667   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);
668 
669   // We expect the second argument of this multiply to be the prefetched c.
670   EXPECT_THAT(f, op::Multiply(op::Add(), op::CopyDone()));
671   // Make sure that the second argument of this multiply is not a CopyDone
672   // from an eviction but the original c.
673   EXPECT_THAT(h, op::Multiply(op::Subtract(), op::Multiply()));
674 }
675 
676 TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchAndPrefetch) {
677   // Test for a memory corruption bug involving an evict/prefetch/prefetch
678   // pattern, where the last prefetch copied from the original buffer in the
679   // alternate memory instead of from the evicted buffer.
680   HloComputation::Builder builder(TestName());
681   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
682   HloInstruction* p0 =
683       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
684   HloInstruction* p1 =
685       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
686   HloInstruction* tanh = builder.AddInstruction(
687       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
688   HloInstruction* a = builder.AddInstruction(
689       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
690   HloInstruction* b = builder.AddInstruction(
691       HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
692   HloInstruction* c = builder.AddInstruction(
693       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
694   HloInstruction* d = builder.AddInstruction(
695       HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
696   HloInstruction* e = builder.AddInstruction(
697       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
698   HloInstruction* f = builder.AddInstruction(
699       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
700   HloInstruction* g = builder.AddInstruction(
701       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
702   HloInstruction* h = builder.AddInstruction(
703       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
704   HloInstruction* i = builder.AddInstruction(
705       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
706   HloInstruction* j = builder.AddInstruction(
707       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
708   HloInstruction* k = builder.AddInstruction(
709       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
710   HloInstruction* l = builder.AddInstruction(
711       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
712   HloInstruction* m = builder.AddInstruction(
713       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
714   HloInstruction* n = builder.AddInstruction(
715       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
716   HloInstruction* o = builder.AddInstruction(
717       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
718   HloInstruction* add0 = builder.AddInstruction(
719       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
720   HloInstruction* negate0 = builder.AddInstruction(
721       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add0));
722   HloInstruction* negate1 = builder.AddInstruction(
723       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
724   HloInstruction* negate2 = builder.AddInstruction(
725       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
726   HloInstruction* negate3 = builder.AddInstruction(
727       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
728   HloInstruction* negate4 = builder.AddInstruction(
729       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
730   HloInstruction* negate5 = builder.AddInstruction(
731       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
732   HloInstruction* negate6 = builder.AddInstruction(
733       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
734   HloInstruction* negate7 = builder.AddInstruction(
735       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
736   HloInstruction* negate8 = builder.AddInstruction(
737       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
738   HloInstruction* negate9 = builder.AddInstruction(
739       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
740   HloInstruction* add1 = builder.AddInstruction(
741       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate9, tanh));
742 
743   auto module = CreateNewVerifiedModule();
744   HloComputation* computation = module->AddEntryComputation(builder.Build());
745 
746   HloSchedule schedule(module.get());
747   schedule.set_sequence(
748       computation,
749       {p0,      p1,      tanh,    a,       b,       c,       d,       e,
750        f,       g,       h,       i,       j,       k,       l,       m,
751        n,       o,       add0,    negate0, negate1, negate2, negate3, negate4,
752        negate5, negate6, negate7, negate8, negate9, add1});
753   TF_CHECK_OK(module->set_schedule(schedule));
754 
755   AssignMemorySpace(module.get());
756 
757   // Check that both prefetches (add0 and add1) prefetch from the eviction
758   // instead of tanh, which will be placed in the alternate memory directly.
759   EXPECT_THAT(
760       add0,
761       op::Add(op::Add(),
762               op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
763                             op::AsyncCopy(kDefaultMemorySpace,
764                                           kAlternateMemorySpace, op::Tanh()))));
765   EXPECT_THAT(
766       add1,
767       op::Add(op::Negate(),
768               op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
769                             op::AsyncCopy(kDefaultMemorySpace,
770                                           kAlternateMemorySpace, op::Tanh()))));
771 }
772 
773 TEST_P(MemorySpaceAssignmentTest, While) {
774   auto module = CreateNewVerifiedModule();
775   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
776   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
777   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
778 
779   auto cond_builder = HloComputation::Builder("WhileCond");
780   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
781   HloInstruction* cond_param = cond_builder.AddInstruction(
782       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
783   HloInstruction* cond_iter = cond_builder.AddInstruction(
784       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
785   HloInstruction* cond_limit = cond_builder.AddInstruction(
786       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
787   // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
788   HloInstruction* cond_lt = cond_builder.AddInstruction(
789       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
790                                     cond_limit, ComparisonDirection::kLt));
791   HloComputation* cond_computation =
792       module->AddEmbeddedComputation(cond_builder.Build());
793 
794   auto body_builder = HloComputation::Builder("WhileBody");
795   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
796   HloInstruction* body_param = body_builder.AddInstruction(
797       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
798   HloInstruction* body_iter = body_builder.AddInstruction(
799       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
800   HloInstruction* body_data = body_builder.AddInstruction(
801       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
802   HloInstruction* body_iter_increment = body_builder.AddInstruction(
803       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
804   HloInstruction* body_iter_next =
805       body_builder.AddInstruction(HloInstruction::CreateBinary(
806           scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
807   HloInstruction* body_data_increment =
808       body_builder.AddInstruction(HloInstruction::CreateConstant(
809           LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
810   HloInstruction* body_data_mul =
811       body_builder.AddInstruction(HloInstruction::CreateBinary(
812           shape, HloOpcode::kMultiply, body_data, body_data));
813   HloInstruction* body_data_add =
814       body_builder.AddInstruction(HloInstruction::CreateBinary(
815           shape, HloOpcode::kAdd, body_data, body_data_increment));
816   HloInstruction* body_data_next =
817       body_builder.AddInstruction(HloInstruction::CreateBinary(
818           shape, HloOpcode::kAdd, body_data_add, body_data_mul));
819   HloInstruction* body_out = body_builder.AddInstruction(
820       HloInstruction::CreateTuple({body_data_next, body_iter_next}));
821   HloComputation* body_computation =
822       module->AddEmbeddedComputation(body_builder.Build());
823 
824   auto builder = HloComputation::Builder(TestName());
825   HloInstruction* data = builder.AddInstruction(
826       HloInstruction::CreateParameter(0, shape, "param_iter"));
827   HloInstruction* iter = builder.AddInstruction(
828       HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
829   HloInstruction* tuple =
830       builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
831   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
832       tuple_shape, cond_computation, body_computation, tuple));
833   HloComputation* entry_computation =
834       module->AddEntryComputation(builder.Build());
835 
836   HloSchedule schedule(module.get());
837   schedule.set_sequence(cond_computation,
838                         {cond_param, cond_iter, cond_limit, cond_lt});
839   schedule.set_sequence(body_computation,
840                         {body_param, body_iter, body_data, body_iter_increment,
841                          body_iter_next, body_data_increment, body_data_mul,
842                          body_data_add, body_data_next, body_out});
843   schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
844   TF_CHECK_OK(module->set_schedule(schedule));
845 
846   AssignMemorySpace(module.get());
847 
848   // Ensure the tuple value and buffers used in the while instruction are
849   // exempted from using the alternate memory when allocating across sequential
850   // calls is disabled. However, body_data_mul is independent and can safely
851   // be placed in the alternate memory.
852   const bool allocate_across_sequential_calls = GetParam();
853   if (!allocate_across_sequential_calls) {
854     EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
855     EXPECT_THAT(data, op::ShapeWithLayout(shape));
856     EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
857     EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
858     EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
859     EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
860   }
861   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
862       F32, {2, 3},
863       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
864       /*element_size_in_bits=*/0, kAlternateMemorySpace);
865   EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
866 }
867 
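// Expects tuple elements (including an element of a nested tuple) to be
// prefetched from default to alternate memory for their uses.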
868 TEST_P(MemorySpaceAssignmentTest, Tuple) {
869   HloComputation::Builder builder(TestName());
870   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
871   Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
872   Shape tuple_shape =
873       ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape});
874   HloInstruction* p = builder.AddInstruction(
875       HloInstruction::CreateParameter(0, tuple_shape, "p"));
876   HloInstruction* p0 = builder.AddInstruction(
877       HloInstruction::CreateGetTupleElement(shape, p, 0));
878   HloInstruction* negate0 = builder.AddInstruction(
879       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
880   HloInstruction* negate1 = builder.AddInstruction(
881       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
882   HloInstruction* negate2 = builder.AddInstruction(
883       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
884   HloInstruction* negate3 = builder.AddInstruction(
885       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
886   HloInstruction* negate4 = builder.AddInstruction(
887       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
888   HloInstruction* negate5 = builder.AddInstruction(
889       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
890   HloInstruction* negate6 = builder.AddInstruction(
891       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
892   HloInstruction* p1 = builder.AddInstruction(
893       HloInstruction::CreateGetTupleElement(shape, p, 1));
894   HloInstruction* add = builder.AddInstruction(
895       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
896   HloInstruction* p2 = builder.AddInstruction(
897       HloInstruction::CreateGetTupleElement(inner_tuple_shape, p, 2));
898   HloInstruction* p2_0 = builder.AddInstruction(
899       HloInstruction::CreateGetTupleElement(shape, p2, 0));
900   HloInstruction* mul = builder.AddInstruction(
901       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, p2_0));
902 
903   auto module = CreateNewVerifiedModule();
904   HloComputation* computation = module->AddEntryComputation(builder.Build());
905 
906   HloSchedule schedule(module.get());
907   schedule.set_sequence(
908       computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
909                     negate6, p1, add, p2, p2_0, mul});
910   TF_CHECK_OK(module->set_schedule(schedule));
911 
912   AssignMemorySpace(module.get());
913 
914   EXPECT_THAT(
915       mul,
916       op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
917                                                        kDefaultMemorySpace,
918                                                        op::GetTupleElement())),
919                    op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
920                                  op::GetTupleElement(op::GetTupleElement()))));
921 }
922 
923 TEST_P(MemorySpaceAssignmentTest, Bitcast) {
924   // Bitcasts can cause the position in the alternate memory to appear multiple
925   // times in the preset assignments. This test ensures the preset assignments
926   // refer to unique positions.
927   HloComputation::Builder builder(TestName());
928   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
929   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
930   HloInstruction* p0 =
931       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
932   HloInstruction* p1 = builder.AddInstruction(
933       HloInstruction::CreateParameter(1, param_shape, "p1"));
934   HloInstruction* negate = builder.AddInstruction(
935       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
936   HloInstruction* bitcast = builder.AddInstruction(
937       HloInstruction::CreateBitcast(param_shape, negate));
938   HloInstruction* add = builder.AddInstruction(
939       HloInstruction::CreateBinary(param_shape, HloOpcode::kAdd, bitcast, p1));
940 
941   auto module = CreateNewVerifiedModule();
942   HloComputation* computation = module->AddEntryComputation(builder.Build());
943 
944   HloSchedule schedule(module.get());
945   schedule.set_sequence(computation, {p0, p1, negate, bitcast, add});
946   TF_CHECK_OK(module->set_schedule(schedule));
947 
948   AssignMemorySpace(module.get());
949 
950   bitcast = add->mutable_operand(0);
951   EXPECT_EQ(bitcast->opcode(), HloOpcode::kBitcast);
952   EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
953 }
954 
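// Verifies that the bitcast operand of the add is placed in the alternate
// memory.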
955 TEST_P(MemorySpaceAssignmentTest, Bitcast2) {
956   HloComputation::Builder builder(TestName());
957   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
958   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
959   HloInstruction* p0 =
960       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
961   HloInstruction* p1 = builder.AddInstruction(
962       HloInstruction::CreateParameter(1, param_shape, "p1"));
963   HloInstruction* negate0 = builder.AddInstruction(
964       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
965   HloInstruction* negate1 = builder.AddInstruction(
966       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
967   HloInstruction* negate2 = builder.AddInstruction(
968       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
969   HloInstruction* negate3 = builder.AddInstruction(
970       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
971   HloInstruction* negate4 = builder.AddInstruction(
972       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
973   HloInstruction* bitcast =
974       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
975   HloInstruction* add = builder.AddInstruction(
976       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4));
977 
978   auto module = CreateNewVerifiedModule();
979   HloComputation* computation = module->AddEntryComputation(builder.Build());
980 
981   HloSchedule schedule(module.get());
982   schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
983                                       negate3, negate4, bitcast, add});
984   TF_CHECK_OK(module->set_schedule(schedule));
985 
986   AssignMemorySpace(module.get());
987 
988   EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
989             kAlternateMemorySpace);
990 }
991 
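// Exercises chains of bitcasts around a prefetched value;
// bitcast(bitcast(x)) is expected to be folded into a single bitcast.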
992 TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
993   HloComputation::Builder builder(TestName());
994   Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
995   Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
996   Shape shape3 = ShapeUtil::MakeShape(F32, {1, 6});
997   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
998   HloInstruction* p0 =
999       builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
1000   HloInstruction* p1 = builder.AddInstruction(
1001       HloInstruction::CreateParameter(1, param_shape, "p1"));
1002   HloInstruction* negate0 = builder.AddInstruction(
1003       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p0));
1004   HloInstruction* negate1 = builder.AddInstruction(
1005       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
1006   HloInstruction* negate2 = builder.AddInstruction(
1007       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
1008   HloInstruction* negate3 = builder.AddInstruction(
1009       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
1010   HloInstruction* negate4 = builder.AddInstruction(
1011       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
1012   HloInstruction* bitcast1 =
1013       builder.AddInstruction(HloInstruction::CreateBitcast(shape1, p1));
1014   HloInstruction* add = builder.AddInstruction(
1015       HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, bitcast1, negate4));
1016   HloInstruction* bitcast2 =
1017       builder.AddInstruction(HloInstruction::CreateBitcast(shape3, p1));
1018   HloInstruction* bitcast3 =
1019       builder.AddInstruction(HloInstruction::CreateBitcast(shape2, bitcast2));
1020   HloInstruction* bitcast4 =
1021       builder.AddInstruction(HloInstruction::CreateBitcast(shape2, add));
1022   HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
1023       shape2, HloOpcode::kMultiply, bitcast3, bitcast4));
1024 
1025   auto module = CreateNewVerifiedModule();
1026   HloComputation* computation = module->AddEntryComputation(builder.Build());
1027 
1028   HloSchedule schedule(module.get());
1029   schedule.set_sequence(computation,
1030                         {p0, p1, negate0, negate1, negate2, negate3, negate4,
1031                          bitcast1, add, bitcast2, bitcast3, bitcast4, mul});
1032   TF_CHECK_OK(module->set_schedule(schedule));
1033 
1034   AssignMemorySpace(module.get());
1035 
1036   // We expect one bitcast on the LHS of multiply since bitcast(bitcast(foo)) is
1037   // converted to bitcast(foo).
1038   EXPECT_THAT(
1039       mul,
1040       op::Multiply(
1041           op::Bitcast(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
1042                                     op::Parameter(1))),
1043           op::Bitcast(op::Add(
1044               op::Bitcast(op::AsyncCopy(kAlternateMemorySpace,
1045                                         kDefaultMemorySpace, op::Parameter(1))),
1046               op::Negate()))));
1047   EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
1048             kAlternateMemorySpace);
1049   EXPECT_EQ(add->shape().layout().memory_space(), kAlternateMemorySpace);
1050   // bitcast2 will no longer have a consumer and should get DCE'd, so we don't
1051   // care about its memory space.
1052   EXPECT_EQ(mul->operand(0)->shape().layout().memory_space(),
1053             kAlternateMemorySpace);
1054   EXPECT_EQ(mul->operand(1)->shape().layout().memory_space(),
1055             kAlternateMemorySpace);
1056 }
1057 
1058 TEST_P(MemorySpaceAssignmentTest, BitcastTuple) {
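  // Tests that a bitcast feeding a tuple, which is in turn consumed by a
  // fusion, is handled by memory space assignment without errors.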
1059   HloComputation::Builder builder(TestName());
1060   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
1061   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
1062   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
1063 
1064   auto module = CreateNewVerifiedModule();
1065   HloComputation::Builder fusion_builder("fusion");
1066   HloInstruction* fusion_param = fusion_builder.AddInstruction(
1067       HloInstruction::CreateParameter(0, tuple_shape, "p"));
1068   HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
1069       HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
1070   HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
1071       HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
1072   fusion_builder.AddInstruction(HloInstruction::CreateBinary(
1073       shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
1074   HloComputation* fusion_computation =
1075       module->AddEmbeddedComputation(fusion_builder.Build());
1076 
1077   HloInstruction* p0 =
1078       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
1079   HloInstruction* p1 = builder.AddInstruction(
1080       HloInstruction::CreateParameter(1, param_shape, "p1"));
1081   HloInstruction* negate0 = builder.AddInstruction(
1082       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
1083   HloInstruction* negate1 = builder.AddInstruction(
1084       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
1085   HloInstruction* negate2 = builder.AddInstruction(
1086       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
1087   HloInstruction* negate3 = builder.AddInstruction(
1088       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
1089   HloInstruction* negate4 = builder.AddInstruction(
1090       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
1091   HloInstruction* bitcast =
1092       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
1093   HloInstruction* tuple =
1094       builder.AddInstruction(HloInstruction::CreateTuple({bitcast, p0}));
1095   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
1096       shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));
1097 
1098   HloComputation* computation = module->AddEntryComputation(builder.Build());
1099 
1100   HloSchedule schedule(module.get());
1101   schedule.set_sequence(computation,
1102                         {p0, p1, negate0, negate1, negate2, negate3, negate4,
1103                          bitcast, tuple, fusion});
1104   TF_CHECK_OK(module->set_schedule(schedule));
1105 
1106   AssignMemorySpace(module.get());
1107 }
1108 
1109 TEST_P(MemorySpaceAssignmentTest, BitcastGetTupleElementTuple) {
1110   // This test pattern was encountered in
1111   // //third_party/tensorflow/compiler/xla/tests:slice_test and was causing a
1112   // breakage when there is a GetTupleElement(Tuple(Bitcast())) pattern. Also
1113   // added a GetTupleElement(GetTupleElement(Tuple(Tuple(Bitcast())))) pattern.
1114   absl::string_view hlo_string = R"(
1115   HloModule DoIt_S64_10_0_5_1.3, is_scheduled=true
1116 
1117   ENTRY %DoIt_S64_10_0_5_1.3 (p0.1: (u32[10], u32[10])) -> (u32[5], u32[5]) {
1118     %p0.1 = (u32[10]{0:T(128)}, u32[10]{0:T(128)}) parameter(0)
1119     %get-tuple-element.1 = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=1
1120     %bitcast.1 = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element.1)
1121     %get-tuple-element = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=0
1122     %bitcast = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element)
1123     %tuple.1 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %bitcast, u32[5]{0:T(128)} %bitcast.1)
1124     %tuple.3 = ((u32[5]{0:T(128)}, u32[5]{0:T(128)}), (u32[5]{0:T(128)}, u32[5]{0:T(128)})) tuple(%tuple.1, %tuple.1)
1125     %get-tuple-element.4 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %tuple.1), index=0
1126     %get-tuple-element.5 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) get-tuple-element(%tuple.3), index=0
1127     %get-tuple-element.6 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %get-tuple-element.5), index=1
1128     %copy.2 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.4)
1129     %copy.3 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.6)
1130     ROOT %tuple.2 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %copy.2, u32[5]{0:T(128)} %copy.3)
1131   }
1132   )";
1133 
1134   TF_ASSERT_OK_AND_ASSIGN(auto module,
1135                           ParseAndReturnVerifiedModule(hlo_string));
1136   AssignMemorySpace(module.get());
1137 }
1138 
1139 TEST_P(MemorySpaceAssignmentTest, GetSimplifiedOperandBug) {
1140   // Test case for a bug in finding Bitcasts in the GTE(Tuple(...)) pattern.
1141   absl::string_view hlo_string = R"(
1142   HloModule sort.16, is_scheduled=true
1143 
1144   ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
1145     %param.3.4 = s32[1]{0:T(128)} parameter(3)
1146     %param.2.3 = u32[1]{0:T(128)} parameter(2)
1147     %param.1.2 = f32[1]{0:T(128)} parameter(1)
1148     %param.0.1 = s32[1]{0:T(128)} parameter(0)
1149     %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4)
1150     %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
1151     %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
1152     %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
1153     %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
1154     %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
1155     %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
1156     %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
1157     %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
1158     ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
1159 }
1160   )";
1161 
1162   TF_ASSERT_OK_AND_ASSIGN(auto module,
1163                           ParseAndReturnVerifiedModule(hlo_string));
1164   AssignMemorySpace(module.get());
1165 }
1166 
1167 TEST_P(MemorySpaceAssignmentTest, BitcastMultiUse) {
1168   // When a bitcast has multiple uses (negate0 and add), with one use in the
1169   // default memory and the other in the alternate memory, each use needs its
1170   // own bitcast.
1171   HloComputation::Builder builder(TestName());
1172   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
1173   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
1174   HloInstruction* p0 = builder.AddInstruction(
1175       HloInstruction::CreateParameter(0, param_shape, "p1"));
1176   HloInstruction* bitcast =
1177       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
1178   HloInstruction* negate0 = builder.AddInstruction(
1179       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
1180   HloInstruction* negate1 = builder.AddInstruction(
1181       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
1182   HloInstruction* negate2 = builder.AddInstruction(
1183       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
1184   HloInstruction* negate3 = builder.AddInstruction(
1185       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
1186   HloInstruction* negate4 = builder.AddInstruction(
1187       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
1188   HloInstruction* add = builder.AddInstruction(
1189       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4));
1190 
1191   auto module = CreateNewVerifiedModule();
1192   HloComputation* computation = module->AddEntryComputation(builder.Build());
1193 
1194   HloSchedule schedule(module.get());
1195   schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
1196                                       negate3, negate4, add});
1197   TF_CHECK_OK(module->set_schedule(schedule));
1198 
1199   AssignMemorySpace(module.get());
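  // Build the expected layout of the [2,3] shape in the alternate memory space
  // to compare against the operand shapes below.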
1200   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
1201       F32, {2, 3},
1202       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
1203       /*element_size_in_bits=*/0, kAlternateMemorySpace);
1204   EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
1205   EXPECT_THAT(add->operand(0), op::ShapeWithLayout(shape_in_alternate_mem));
1206 }
1207 
1208 TEST_P(MemorySpaceAssignmentTest, BitcastMultiUseTuple) {
1209   // Same as BitcastMultiUse but the second use is a tuple.
1210   HloComputation::Builder builder(TestName());
1211   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
1212   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
1213   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
1214 
1215   auto module = CreateNewVerifiedModule();
1216   HloComputation::Builder fusion_builder("fusion");
1217   HloInstruction* fusion_param = fusion_builder.AddInstruction(
1218       HloInstruction::CreateParameter(0, tuple_shape, "p"));
1219   HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
1220       HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
1221   HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
1222       HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
1223   fusion_builder.AddInstruction(HloInstruction::CreateBinary(
1224       shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
1225   HloComputation* fusion_computation =
1226       module->AddEmbeddedComputation(fusion_builder.Build());
1227 
1228   HloInstruction* p0 = builder.AddInstruction(
1229       HloInstruction::CreateParameter(0, param_shape, "p1"));
1230   HloInstruction* bitcast =
1231       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
1232   HloInstruction* negate0 = builder.AddInstruction(
1233       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
1234   HloInstruction* negate1 = builder.AddInstruction(
1235       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
1236   HloInstruction* negate2 = builder.AddInstruction(
1237       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
1238   HloInstruction* negate3 = builder.AddInstruction(
1239       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
1240   HloInstruction* negate4 = builder.AddInstruction(
1241       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
1242   HloInstruction* tuple =
1243       builder.AddInstruction(HloInstruction::CreateTuple({bitcast, negate4}));
1244   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
1245       shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));
1246 
1247   HloComputation* computation = module->AddEntryComputation(builder.Build());
1248 
1249   HloSchedule schedule(module.get());
1250   schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
1251                                       negate3, negate4, tuple, fusion});
1252   TF_CHECK_OK(module->set_schedule(schedule));
1253 
1254   AssignMemorySpace(module.get());
1255   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
1256       F32, {2, 3},
1257       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
1258       /*element_size_in_bits=*/0, kAlternateMemorySpace);
1259   EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
1260   EXPECT_THAT(fusion->operand(0)->operand(0),
1261               op::ShapeWithLayout(shape_in_alternate_mem));
1262 }
1263 
1264 TEST_P(MemorySpaceAssignmentTest, BitcastScheduleBug) {
1265   // Bitcasts can force asynchronous copies to be scheduled too early, possibly
1266   // leading to memory corruption.
1267   //  Bug:
1268   //    p0------------------>neg-->neg-->neg ... -->neg-->neg-->neg->add
1269   //                                                                 /
1270   //    p1->cs->cd->bitcast-----------------------------------------+
1271   //
1272   //  Expected:
1273   //    p0-->neg-->neg-->neg ... -->neg-->neg-->neg------------->add
1274   //                                                             /
1275   //    p1--------------------->cs----------------->cd->bitcast-+
1276   HloComputation::Builder builder(TestName());
1277   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
1278   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
1279   HloInstruction* p0 =
1280       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
1281   HloInstruction* p1 = builder.AddInstruction(
1282       HloInstruction::CreateParameter(1, param_shape, "p1"));
1283   HloInstruction* bitcast =
1284       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
1285   HloInstruction* negate0 = builder.AddInstruction(
1286       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
1287   HloInstruction* negate1 = builder.AddInstruction(
1288       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
1289   HloInstruction* negate2 = builder.AddInstruction(
1290       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
1291   HloInstruction* negate3 = builder.AddInstruction(
1292       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
1293   HloInstruction* negate4 = builder.AddInstruction(
1294       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
1295   HloInstruction* negate5 = builder.AddInstruction(
1296       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
1297   HloInstruction* negate6 = builder.AddInstruction(
1298       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
1299   HloInstruction* negate7 = builder.AddInstruction(
1300       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
1301   HloInstruction* negate8 = builder.AddInstruction(
1302       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
1303   HloInstruction* negate9 = builder.AddInstruction(
1304       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
1305   HloInstruction* add = builder.AddInstruction(
1306       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate9));
1307 
1308   auto module = CreateNewVerifiedModule();
1309   HloComputation* computation = module->AddEntryComputation(builder.Build());
1310 
1311   HloSchedule schedule(module.get());
1312   schedule.set_sequence(
1313       computation, {p0, p1, bitcast, negate0, negate1, negate2, negate3,
1314                     negate4, negate5, negate6, negate7, negate8, negate9, add});
1315   TF_CHECK_OK(module->set_schedule(schedule));
1316 
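  // Constrain the prefetch interval (min 4, max 5) so the asynchronous copy is
  // expected to be scheduled within the chain of negates rather than right
  // after the parameters; this is verified by the schedule inspection below.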
1317   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
1318                     /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/4);
1319 
1320   EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
1321             kAlternateMemorySpace);
1322   const auto& instructions =
1323       module->schedule().sequence(module->entry_computation()).instructions();
1324   for (int i = 0; i < instructions.size(); ++i) {
1325     // Expect that there is a negate before and after the CopyStart and there is
1326     // a negate before CopyDone.
1327     if (instructions.at(i)->opcode() == HloOpcode::kCopyStart) {
1328       EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
1329       EXPECT_EQ(instructions.at(i + 1)->opcode(), HloOpcode::kNegate);
1330     } else if (instructions.at(i)->opcode() == HloOpcode::kCopyDone) {
1331       EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
1332     }
1333   }
1334 }
1335 
1336 TEST_P(MemorySpaceAssignmentTest, AddDependency) {
1337   // Make sure add-dependency is not optimized away.
1338   absl::string_view hlo_string = R"(
1339   HloModule AddDependency, is_scheduled=true
1340 
1341   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
1342     %p = f32[3]{0} parameter(0)
1343     %neg0 = f32[3]{0} negate(f32[3]{0} %p)
1344     %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
1345     %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
1346     %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
1347     %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
1348     %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
1349     %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
1350     %token0 = token[] after-all()
1351     %add_dep = f32[3]{0} add-dependency(f32[3]{0} %p, token[] %token0)
1352     ROOT %add = f32[3]{0} add(f32[3]{0} %add_dep, f32[3]{0} %neg6)
1353   }
1354   )";
1355 
1356   TF_ASSERT_OK_AND_ASSIGN(auto module,
1357                           ParseAndReturnVerifiedModule(hlo_string));
1358   AssignMemorySpace(module.get());
1359 
1360   EXPECT_THAT(module->entry_computation()->root_instruction(),
1361               op::Add(op::AddDependency(), op::Negate()));
1362 }
1363 
1364 TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) {
1365   // This test is carefully crafted to include two multiply ops sized [4,3] in a
1366   // while body. For testing purposes, we have provided a BufferIntervalCompare
1367   // such that the multiplies are allocated first, then the tanh, and then the
1368   // other HloValues. The memory is sized just enough to fit two [4,3] buffers.
1369   // Because the multiplies in the while body are going to be allocated in the
1370   // alternate memory first, the tanh that is fed into the while loop should not
1371   // be placed in the alternate memory. Otherwise, we will corrupt memory.
1372   absl::string_view hlo_string = R"(
1373   HloModule WhileAllocationBug, is_scheduled=true
1374 
1375   %WhileBody (body_param: (f32[4,3], f32[])) -> (f32[4,3], f32[]) {
1376     %body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
1377     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=1
1378     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=0
1379     %constant.1 = f32[] constant(1)
1380     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1381     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1382     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.2)
1383     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1384     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
1385     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1386     ROOT %tuple = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[] %add)
1387   }
1388 
1389   %WhileCond (cond_param: (f32[4,3], f32[])) -> pred[] {
1390     %cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
1391     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %cond_param), index=1
1392     %constant = f32[] constant(50)
1393     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1394   }
1395 
1396   ENTRY %Entry (param_iter: f32[4,3], param_data: f32[], p2: f32[4,3]) -> f32[4,3] {
1397     %param_data = f32[] parameter(1)
1398     %param_iter = f32[4,3]{1,0} parameter(0)
1399     %p2 = f32[4,3]{1,0} parameter(2)
1400     %tanh = f32[4,3]{1,0} tanh(f32[4,3]{1,0} %param_iter)
1401     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1402     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1403     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1404     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1405     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1406     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1407     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1408     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %tanh)
1409     %tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %tanh, f32[] %param_data)
1410     %while = (f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1411     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %while), index=0
1412     ROOT %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.3, f32[4,3]{1,0} %add.4)
1413   }
1414   )";
1415 
1416   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
1417       [](const MemorySpaceAssignment::BufferInterval& a,
1418          const MemorySpaceAssignment::BufferInterval& b) {
1419         bool a_is_mul =
1420             a.buffer->defining_instruction()->opcode() == HloOpcode::kMultiply;
1421         bool b_is_mul =
1422             b.buffer->defining_instruction()->opcode() == HloOpcode::kMultiply;
1423         if (a_is_mul && !b_is_mul) {
1424           return true;
1425         }
1426         if (!a_is_mul && b_is_mul) {
1427           return false;
1428         }
1429         bool a_is_tanh =
1430             a.buffer->defining_instruction()->opcode() == HloOpcode::kTanh;
1431         bool b_is_tanh =
1432             b.buffer->defining_instruction()->opcode() == HloOpcode::kTanh;
1433         if (a_is_tanh && !b_is_tanh) {
1434           return true;
1435         }
1436         if (!a_is_tanh && b_is_tanh) {
1437           return false;
1438         }
1439         return a.buffer->id() < b.buffer->id();
1440       };
1441   TF_ASSERT_OK_AND_ASSIGN(auto module,
1442                           ParseAndReturnVerifiedModule(hlo_string));
1443 
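  // Use an instruction-count-based prefetch interval picker; the (2, 10)
  // arguments are assumed to bound how far before its use a prefetch may start.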
1444   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
1445   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
1446                     buffer_interval_compare, &prefetch_interval_picker);
1447 
1448   for (const HloInstruction* instruction :
1449        module->entry_computation()->instructions()) {
1450     if (instruction->opcode() == HloOpcode::kWhile) {
1451       const Shape& while_subshape =
1452           ShapeUtil::GetSubshape(instruction->shape(), {0});
1453       // We expect the {0} subshape either to be in default memory for the entire
1454       // while loop, or for there to be an eviction within the while loop.
1455       if (while_subshape.layout().memory_space() == kAlternateMemorySpace) {
1456         const HloInstruction* body_param =
1457             instruction->while_body()->parameter_instruction(0);
1458         const HloInstruction* gte = nullptr;
1459         for (const HloInstruction* user : body_param->users()) {
1460           if (user->opcode() == HloOpcode::kGetTupleElement &&
1461               user->tuple_index() == 0) {
1462             gte = user;
1463             break;
1464           }
1465         }
1466         EXPECT_NE(gte, nullptr);
1467         const HloInstruction* copy_start = nullptr;
1468         for (const HloInstruction* user : gte->users()) {
1469           if (user->opcode() == HloOpcode::kCopyStart) {
1470             copy_start = user;
1471             break;
1472           }
1473         }
1474         EXPECT_NE(copy_start, nullptr);
1475         const Shape& copy_start_subshape =
1476             ShapeUtil::GetSubshape(copy_start->shape(), {0});
1477 
1478         EXPECT_NE(copy_start_subshape.layout().memory_space(),
1479                   kAlternateMemorySpace);
1480       }
1481     }
1482   }
1483 }
1484 
1485 TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) {
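  // Tests that two consecutive while loops, where the second loop consumes
  // values produced by the first, are assigned memory spaces without errors.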
1486   absl::string_view hlo_string = R"(
1487   HloModule WhileAllocationBug, is_scheduled=true
1488 
1489   %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1490     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1491     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1492     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1493     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1494     %constant.1 = f32[] constant(1)
1495     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1496     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1497     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
1498     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1499     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
1500     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1501     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1502   }
1503 
1504   %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1505     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1506     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1507     %constant = f32[] constant(50)
1508     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1509   }
1510 
1511   %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1512     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1513     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1514     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1515     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1516     %constant.1 = f32[] constant(1)
1517     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1518     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1519     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
1520     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1521     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
1522     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1523     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1524   }
1525 
1526   %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1527     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1528     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1529     %constant = f32[] constant(50)
1530     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1531   }
1532 
1533   ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
1534     %param_iter = f32[] parameter(1)
1535     %param_data = f32[4,3]{1,0} parameter(0)
1536     %p2 = f32[4,3]{1,0} parameter(2)
1537     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1538     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1539     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1540     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1541     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1542     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1543     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1544     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
1545     %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
1546     %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1547     %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
1548     %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
1549     %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
1550     %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter)
1551     %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
1552     %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
1553     ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3)
1554   }
1555   )";
1556 
1557   TF_ASSERT_OK_AND_ASSIGN(auto module,
1558                           ParseAndReturnVerifiedModule(hlo_string));
1559   AssignMemorySpace(module.get());
1560 }
1561 
1562 TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) {
1563   // Tests against while live ranges being incorrect and the verifier
1564   // complaining about a conflict.
1565   absl::string_view hlo_string = R"(
1566   HloModule WhileAllocationBug, is_scheduled=true
1567 
1568   %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1569     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1570     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1571     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1572     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1573     %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
1574     %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
1575     %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
1576     %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
1577     %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
1578     %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
1579     %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
1580     %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
1581     %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
1582     %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
1583     %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
1584     %constant.1 = f32[] constant(1)
1585     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1586     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1587     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
1588     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1589     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
1590     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1591     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1592   }
1593 
1594   %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1595     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1596     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1597     %constant = f32[] constant(50)
1598     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1599   }
1600 
1601   ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
1602     %param_iter = f32[] parameter(1)
1603     %param_data = f32[4,3]{1,0} parameter(0)
1604     %p2 = f32[4,3]{1,0} parameter(2)
1605     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1606     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1607     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1608     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1609     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1610     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1611     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1612     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
1613     %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
1614     %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1615     %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
1616     %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
1617     %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
1618     ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3)
1619   }
1620   )";
1621 
1622   TF_ASSERT_OK_AND_ASSIGN(auto module,
1623                           ParseAndReturnVerifiedModule(hlo_string));
1624   AssignMemorySpace(module.get());
1625 }
1626 
1627 TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
1628   // Tests against a bug where, when there are consecutive while loops with one
1629   // buffer (the value doesn't change in the buffer), the parameter can be
1630   // colored in the alternate memory space.
1631   absl::string_view hlo_string = R"(
1632   HloModule WhileAllocationBug, is_scheduled=true
1633 
1634   %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1635     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1636     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1637     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1638     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1639     %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
1640     %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
1641     %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
1642     %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
1643     %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
1644     %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
1645     %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
1646     %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
1647     %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
1648     %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
1649     %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
1650     %constant.1 = f32[] constant(1)
1651     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1652     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1653     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
1654     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1655     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
1656     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1657     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1658   }
1659 
1660   %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1661     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1662     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1663     %constant = f32[] constant(50)
1664     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1665   }
1666 
1667   %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1668     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1669     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1670     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1671     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1672     %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
1673     %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
1674     %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
1675     %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
1676     %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
1677     %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
1678     %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
1679     %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
1680     %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
1681     %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
1682     %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
1683     %constant.1 = f32[] constant(1)
1684     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1685     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1686     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
1687     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1688     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
1689     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1690     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1691   }
1692 
1693   %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1694     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1695     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1696     %constant = f32[] constant(50)
1697     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1698   }
1699 
1700   ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
1701     %param_iter = f32[] parameter(1)
1702     %param_data = f32[4,3]{1,0} parameter(0)
1703     %p2 = f32[4,3]{1,0} parameter(2)
1704     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1705     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1706     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1707     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1708     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1709     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1710     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1711     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
1712     %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
1713     %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1714     %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
1715     %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
1716     %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter)
1717     %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
1718     %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
1719     %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1
1720     ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6)
1721   }
1722   )";
1723 
1724   TF_ASSERT_OK_AND_ASSIGN(auto module,
1725                           ParseAndReturnVerifiedModule(hlo_string));
1726   AssignMemorySpace(module.get());
1727 }
1728 
1729 TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
1730   // The while loop is the root of the entry computation. We should ensure that
1731   // the output of the entry computation remains in the default memory space.
1732   // Test from //third_party/tensorflow/compiler/xla/tests:while_test
1733   // WhileTest.WhileWithPrngScalarResult.
1734   absl::string_view hlo_string = R"(
1735   HloModule WhileWithPrngScalarResult.18, is_scheduled=true
1736 
1737   %fused_computation (param_0.1: s32[6], param_1.3: s32[1], param_2.3: s32[5]) -> s32[6] {
1738     %param_1.3 = s32[1]{0:T(128)} parameter(1)
1739     %constant.2 = s32[]{:T(128)} constant(-2147483648)
1740     %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
1741     %param_2.3 = s32[5]{0:T(128)} parameter(2)
1742     %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
1743     %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
1744     %param_0.1 = s32[6]{0:T(128)} parameter(0)
1745     ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
1746   }
1747 
1748   %body.3 (prev.4: s32[6]) -> s32[6] {
1749     %constant.7 = s32[]{:T(128)} constant(100)
1750     %constant.6 = s32[]{:T(128)} constant(0)
1751     %constant.5 = s32[1]{0:T(128)} constant({1})
1752     %prev.4 = s32[6]{0:T(128)} parameter(0)
1753     %rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
1754     %neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
1755     ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
1756   }
1757 
1758   %WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
1759     %constant.15 = s32[]{:T(128)} constant(1)
1760     %prev.12 = s32[6]{0:T(128)} parameter(0)
1761     %bitcast.1 = s32[1]{0:T(128)} bitcast(s32[6]{0:T(128)} %prev.12)
1762     %bitcast = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %bitcast.1)
1763     ROOT %compare.16 = pred[]{:T(128)E(32)} compare(s32[]{:T(128)} %constant.15, s32[]{:T(128)} %bitcast), direction=GT
1764   }
1765 
1766   ENTRY %WhileWithPrngScalarResult.18 () -> s32[6] {
1767     %constant.1 = s32[]{:T(128)} constant(0)
1768     %broadcast.2 = s32[6]{0:T(128)} broadcast(s32[]{:T(128)} %constant.1), dimensions={}
1769     ROOT %while.17 = s32[6]{0:T(128)} while(s32[6]{0:T(128)} %broadcast.2), condition=%WhileWithPrngScalarResult.11, body=%body.3
1770   }
1771   )";
1772 
1773   TF_ASSERT_OK_AND_ASSIGN(auto module,
1774                           ParseAndReturnVerifiedModule(hlo_string));
1775   AssignMemorySpace(module.get());
1776 }
1777 
1778 TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
1779   // Ensure that a dynamic update slice within a while loop is able to get an
1780   // alternate memory allocation.
1781   absl::string_view hlo_string = R"(
1782   HloModule Module, is_scheduled=true
1783 
1784   fused_computation {
1785     param0 = f32[2,3] parameter(0)
1786     constant.1 = f32[] constant(0)
1787     broadcast = f32[2,1] broadcast(constant.1), dimensions={}
1788     constant.3 = s32[] constant(0)
1789     ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
1790   }
1791 
1792   %WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
1793     %body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
1794     %get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
1795     %get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
1796     %get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
1797     %fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
1798     %multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
1799     ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
1800   }
1801 
1802   %WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
1803     %cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
1804     %get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
1805     %constant = f32[] constant(50)
1806     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1807   }
1808 
1809   ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
1810     %param_iter = f32[] parameter(1)
1811     %param_data = f32[2,3]{1,0} parameter(0)
1812     %p2 = f32[2,3]{1,0} parameter(2)
1813     %copy1 = f32[2,3]{1,0} copy(param_data)
1814     %copy2 = f32[2,3]{1,0} copy(p2)
1815     %tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
1816     %while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1817     %get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
1818     ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
1819   }
1820   )";
1821 
1822   TF_ASSERT_OK_AND_ASSIGN(auto module,
1823                           ParseAndReturnVerifiedModule(hlo_string));
1824   AssignMemorySpace(module.get());
1825   const HloInstruction* while_op =
1826       module->entry_computation()->GetInstructionWithName("while");
1827   if (GetParam()) {
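    // The dynamic-update-slice buffer lives at tuple index {1} of the while;
    // check that it received an alternate memory allocation.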
1828     EXPECT_EQ(
1829         ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(),
1830         kAlternateMemorySpace);
1831   }
1832 }
1833 
1834 TEST_P(MemorySpaceAssignmentTest, WhileSharedBufferVerificationBug) {
1835   // Tests a spurious verification failure when a while has the same value
1836   // passed in twice (copy0) and that value is evicted within the while loop.
1837   absl::string_view hlo_string = R"(
1838   HloModule module, is_scheduled=true
1839 
1840   while_cond {
1841     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1842     ROOT gte = pred[] get-tuple-element(p0), index=3
1843   }
1844 
1845   while_body {
1846     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1847     gte0 = f32[3]{0} get-tuple-element(p0), index=0
1848     gte1 = f32[3]{0} get-tuple-element(p0), index=1
1849     gte2 = f32[3]{0} get-tuple-element(p0), index=2
1850     gte3 = pred[] get-tuple-element(p0), index=3
1851     add = f32[3]{0} add(gte0, gte0)
1852     negate0 = f32[3]{0} negate(add)
1853     negate1 = f32[3]{0} negate(negate0)
1854     negate2 = f32[3]{0} negate(negate1)
1855     negate3 = f32[3]{0} negate(negate2)
1856     negate4 = f32[3]{0} negate(negate3)
1857     negate5 = f32[3]{0} negate(negate4)
1858     negate6 = f32[3]{0} negate(negate5)
1859     negate7 = f32[3]{0} negate(negate6)
1860     negate8 = f32[3]{0} negate(negate7)
1861     negate9 = f32[3]{0} negate(negate8)
1862     negate10 = f32[3]{0} negate(negate9)
1863     negate11 = f32[3]{0} negate(negate10)
1864     negate12 = f32[3]{0} negate(negate11)
1865     negate13 = f32[3]{0} negate(negate12)
1866     negate14 = f32[3]{0} negate(negate13)
1867     negate15 = f32[3]{0} negate(negate14)
1868     negate16 = f32[3]{0} negate(negate15)
1869     ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, gte0, negate16, gte3)
1870   }
1871 
1872   ENTRY entry {
1873     p0 = f32[3]{0} parameter(0)
1874     p1 = pred[] parameter(1)
1875     copy0 = f32[3]{0} copy(p0)
1876     copy1 = f32[3]{0} copy(p0)
1877     tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
1878     while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
1879     ROOT gte = f32[3]{0} get-tuple-element(while), index=2
1880   }
1881   )";
1882   TF_ASSERT_OK_AND_ASSIGN(auto module,
1883                           ParseAndReturnVerifiedModule(hlo_string));
1884   AssignMemorySpace(module.get());
1885 }
1886 
1887 TEST_P(MemorySpaceAssignmentTest, b228599972) {
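  // Regression test for b/228599972: the fusion produces a tuple whose
  // elements feed only negates that are not used by the root instruction.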
1888   absl::string_view hlo_string = R"(
1889 HloModule entry, is_scheduled=true
1890 
1891 fused_computation {
1892   %p0 = f32[2,3]{1,0} parameter(0)
1893   %result0 = f32[2,3]{1,0} copy(%p0)
1894   %result1 = f32[2,3]{1,0} copy(%p0)
1895   ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}) tuple(%result0, %result1)
1896 }
1897 
1898 ENTRY entry {
1899   %p0 = f32[2,3]{1,0} parameter(0)
1900   %p1 = f32[2,3]{1,0} parameter(1)
1901   %unused = (f32[2,3]{1,0}, f32[2,3]{1,0}) fusion(%p0), kind=kLoop, calls=%fused_computation
1902   %unused.0 = f32[2,3]{1,0} get-tuple-element(%unused), index=0
1903   %unused.1 = f32[2,3]{1,0} get-tuple-element(%unused), index=1
1904   %negate.0 = f32[2,3]{1,0} negate(f32[2,3]{1,0} %unused.0)
1905   %negate.1 = f32[2,3]{1,0} negate(f32[2,3]{1,0} %unused.1)
1906 
1907   ROOT %result = f32[2,3]{1,0} negate(%p1)
1908 }
1909   )";
1910   TF_ASSERT_OK_AND_ASSIGN(auto module,
1911                           ParseAndReturnVerifiedModule(hlo_string));
1912   AssignMemorySpace(module.get());
1913 }
1914 
1915 TEST_P(MemorySpaceAssignmentTest, b172243149) {
1916   // Tests for the failure in b/172243149, where skipping the processing of
1917   // non-copy allocations that are in default memory can actually cause
1918   // failures. In this case, the problem tensor is copy0, which is fed to
1919   // negate, while, and add0. The copy0->negate dependency can be allocated
1920   // in the alternate memory. Then the algorithm attempts to place the
1921   // copy0->while edge in the alternate memory, but since this value isn't used
1922   // in the while loop, it won't get an alternate memory allocation. Finally,
1923   // for the copy0->add0 edge, the algorithm will actually replace it with
1924   // while{0}->add0, since this is equivalent and while is defined later than
1925   // copy0. However, if we skip processing this while{0}->add0 allocation, we
1926   // won't replace this edge and will end up with the copy0->add0 edge, which
1927   // illegally extends the lifetime of the alternate memory buffer in
1928   // copy0.
1929   absl::string_view hlo_string = R"(
1930   HloModule module, is_scheduled=true
1931 
1932   while_cond {
1933     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1934     ROOT gte = pred[] get-tuple-element(p0), index=3
1935   }
1936 
1937   while_body {
1938     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1939     gte0 = f32[3]{0} get-tuple-element(p0), index=0
1940     gte1 = f32[3]{0} get-tuple-element(p0), index=1
1941     gte2 = f32[3]{0} get-tuple-element(p0), index=2
1942     gte3 = pred[] get-tuple-element(p0), index=3
1943     add = f32[3]{0} add(gte1, gte2)
1944     negate0 = f32[3]{0} negate(add)
1945     negate1 = f32[3]{0} negate(negate0)
1946     negate2 = f32[3]{0} negate(negate1)
1947     negate3 = f32[3]{0} negate(negate2)
1948     negate4 = f32[3]{0} negate(negate3)
1949     negate5 = f32[3]{0} negate(negate4)
1950     negate6 = f32[3]{0} negate(negate5)
1951     negate7 = f32[3]{0} negate(negate6)
1952     negate8 = f32[3]{0} negate(negate7)
1953     negate9 = f32[3]{0} negate(negate8)
1954     negate10 = f32[3]{0} negate(negate9)
1955     negate11 = f32[3]{0} negate(negate10)
1956     negate12 = f32[3]{0} negate(negate11)
1957     negate13 = f32[3]{0} negate(negate12)
1958     negate14 = f32[3]{0} negate(negate13)
1959     negate15 = f32[3]{0} negate(negate14)
1960     negate16 = f32[3]{0} negate(negate15)
1961     ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, negate16, gte3)
1962   }
1963 
1964   ENTRY entry {
1965     p0 = f32[3]{0} parameter(0)
1966     p1 = pred[] parameter(1)
1967     copy0 = f32[3]{0} copy(p0)
1968     copy1 = f32[3]{0} copy(p0)
1969     copy2 = f32[3]{0} copy(p0)
1970     negate = f32[3]{0} negate(copy0)
1971     tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, copy2, p1)
1972     while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
1973     gte = f32[3]{0} get-tuple-element(while), index=2
1974     add0 = f32[3]{0} add(negate, copy0)
1975     ROOT add1 = f32[3]{0} add(add0, gte)
1976   }
1977   )";
1978   TF_ASSERT_OK_AND_ASSIGN(auto module,
1979                           ParseAndReturnVerifiedModule(hlo_string));
1980   AssignMemorySpace(module.get());
1981 }
1982 
1983 TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
1984   // Having control_predecessors on an HLO was preventing us from DCEing an op
1985   // that doesn't have any users (tuple.1). The scheduler assumes the graph is
1986   // fully DCEed, which causes some instructions not to be scheduled.
1987   absl::string_view hlo_string = R"(
1988   HloModule sort.16, is_scheduled=true
1989 
1990   ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
1991     %param.3.4 = s32[1]{0:T(128)} parameter(3)
1992     %param.2.3 = u32[1]{0:T(128)} parameter(2)
1993     %param.1.2 = f32[1]{0:T(128)} parameter(1)
1994     %param.0.1 = s32[1]{0:T(128)} parameter(0)
1995     %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4), control-predecessors={%param.0.1}
1996     %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
1997     %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
1998     %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
1999     %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
2000     %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
2001     %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
2002     %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
2003     %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
2004     ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
2005 }
2006   )";
2007 
2008   TF_ASSERT_OK_AND_ASSIGN(auto module,
2009                           ParseAndReturnVerifiedModule(hlo_string));
2010   AssignMemorySpace(module.get());
2011 }
2012 
2013 TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) {
2014   // Checks if simple conditionals get alternate memory allocations.
2015   absl::string_view hlo_string = R"(
2016   HloModule CondAllocation, is_scheduled=true
2017 
2018   true_computation {
2019     p0 = (f32[3]{0}) parameter(0)
2020     gte = f32[3]{0} get-tuple-element(p0), index=0
2021     ROOT neg1 = f32[3]{0} negate(gte)
2022   }
2023 
2024   false_computation {
2025     p0 = (f32[3]{0}) parameter(0)
2026     gte = f32[3]{0} get-tuple-element(p0), index=0
2027     ROOT neg2 = f32[3]{0} negate(gte)
2028   }
2029 
2030   ENTRY entry {
2031     p0 = f32[3]{0} parameter(0)
2032     p1 = pred[] parameter(1)
2033     copy = f32[3]{0} copy(p0)
2034     tuple = (f32[3]{0}) tuple(copy)
2035     ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
2036   }
2037   )";
2038   TF_ASSERT_OK_AND_ASSIGN(auto module,
2039                           ParseAndReturnVerifiedModule(hlo_string));
2040   AssignMemorySpace(module.get());
2041 
2042   if (GetParam()) {
2043     // Check that copy and gtes got alternate memory allocations.
2044     auto copy =
2045         module->GetComputationWithName("entry")->GetInstructionWithName("copy");
2046     EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
2047     auto neg1 = module->GetComputationWithName("true_computation")
2048                     ->GetInstructionWithName("neg1");
2049     auto neg1_operand = neg1->operand(0);
2050     EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
2051               kAlternateMemorySpace);
2052     auto neg2 = module->GetComputationWithName("false_computation")
2053                     ->GetInstructionWithName("neg2");
2054     auto neg2_operand = neg2->operand(0);
2055     EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
2056               kAlternateMemorySpace);
2057   }
2058 }
2059 
2060 TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) {
2061   // Checks if we avoid unnecessary allocation in alternate memory if the input
2062   // won't be used in the computation for a long time.
2063   absl::string_view hlo_string = R"(
2064   HloModule CondAllocation, is_scheduled=true
2065 
2066   true_computation {
2067     p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
2068     gte0 = f32[3]{0} get-tuple-element(p0), index=0
2069     neg0 = f32[3]{0} negate(gte0)
2070     neg1 = f32[3]{0} negate(neg0)
2071     neg2 = f32[3]{0} negate(neg1)
2072     neg3 = f32[3]{0} negate(neg2)
2073     neg4 = f32[3]{0} negate(neg3)
2074     neg5 = f32[3]{0} negate(neg4)
2075     neg6 = f32[3]{0} negate(neg5)
2076     neg7 = f32[3]{0} negate(neg6)
2077     neg8 = f32[3]{0} negate(neg7)
2078     neg9 = f32[3]{0} negate(neg8)
2079     gte1 = f32[3]{0} get-tuple-element(p0), index=1
2080     ROOT add = f32[3]{0} add(neg9, gte1)
2081   }
2082 
2083   false_computation {
2084     p0 = (f32[3]{0}) parameter(0)
2085     gte = f32[3]{0} get-tuple-element(p0), index=0
2086     ROOT neg = f32[3]{0} negate(gte)
2087   }
2088 
2089   ENTRY entry {
2090     p0 = f32[3]{0} parameter(0)
2091     p1 = pred[] parameter(1)
2092     copy0 = f32[3]{0} copy(p0)
2093     copy1 = f32[3]{0} copy(p0)
2094     tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
2095     tuple1 = (f32[3]{0}) tuple(copy0)
2096     ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
2097   }
2098   )";
2099   TF_ASSERT_OK_AND_ASSIGN(auto module,
2100                           ParseAndReturnVerifiedModule(hlo_string));
2101   AssignMemorySpace(module.get());
2102 
2103   if (GetParam()) {
2104     // Check that copy1 doesn't get unnecessarily allocated in alternate mem
2105     // (due to long negate chain in true_computation) but is prefetched before
2106     // add.
2107     auto copy0 =
2108         module->GetComputationWithName("entry")->GetInstructionWithName(
2109             "copy0");
2110     EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
2111     auto copy1 =
2112         module->GetComputationWithName("entry")->GetInstructionWithName(
2113             "copy1");
2114     EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace);
2115     auto add = module->GetComputationWithName("true_computation")
2116                    ->GetInstructionWithName("add");
2117     auto add_operand = add->operand(1);
2118     EXPECT_EQ(add_operand->shape().layout().memory_space(),
2119               kAlternateMemorySpace);
2120   }
2121 }
2122 
2123 TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) {
2124   // Make sure there is an eviction when a conditional use is followed by
2125   // another use.
2126   absl::string_view hlo_string = R"(
2127   HloModule CondAllocation, is_scheduled=true
2128 
2129   true_computation {
2130     p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
2131     gte0 = f32[3]{0} get-tuple-element(p0), index=0
2132     gte1 = f32[3]{0} get-tuple-element(p0), index=1
2133     add0 = f32[3]{0} add(gte0, gte1)
2134     neg0 = f32[3]{0} negate(add0)
2135     neg1 = f32[3]{0} negate(neg0)
2136     neg2 = f32[3]{0} negate(neg1)
2137     neg3 = f32[3]{0} negate(neg2)
2138     neg4 = f32[3]{0} negate(neg3)
2139     neg5 = f32[3]{0} negate(neg4)
2140     neg6 = f32[3]{0} negate(neg5)
2141     neg7 = f32[3]{0} negate(neg6)
2142     neg8 = f32[3]{0} negate(neg7)
2143     ROOT neg9 = f32[3]{0} negate(neg8)
2144   }
2145 
2146   false_computation {
2147     p0 = (f32[3]{0}) parameter(0)
2148     gte = f32[3]{0} get-tuple-element(p0), index=0
2149     ROOT neg = f32[3]{0} negate(gte)
2150   }
2151 
2152   ENTRY entry {
2153     p0 = f32[3]{0} parameter(0)
2154     p1 = pred[] parameter(1)
2155     copy0 = f32[3]{0} copy(p0)
2156     copy1 = f32[3]{0} copy(p0)
2157     tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
2158     tuple1 = (f32[3]{0}) tuple(copy0)
2159     conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
2160     ROOT add1 = f32[3]{0} add(copy1, conditional)
2161   }
2162   )";
2163   TF_ASSERT_OK_AND_ASSIGN(auto module,
2164                           ParseAndReturnVerifiedModule(hlo_string));
2165   AssignMemorySpace(module.get());
2166 
2167   if (GetParam()) {
2168     // Make sure the copy1->add0 edge is in alternate memory. Before the
2169     // conditional, copy1 should be evicted to default memory, so add1 in the
2170     // entry computation reads the evicted (copy-done) value from default memory.
2171     auto copy1 =
2172         module->GetComputationWithName("entry")->GetInstructionWithName(
2173             "copy1");
2174     EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace);
2175     auto add0 = module->GetComputationWithName("true_computation")
2176                     ->GetInstructionWithName("add0");
2177     auto add0_operand = add0->operand(1);
2178     EXPECT_EQ(add0_operand->shape().layout().memory_space(),
2179               kAlternateMemorySpace);
2180     auto add1 =
2181         module->GetComputationWithName("entry")->GetInstructionWithName("add1");
2182     auto add1_operand = add1->operand(0);
2183     EXPECT_EQ(add1_operand->shape().layout().memory_space(),
2184               kDefaultMemorySpace);
2185     EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone);
2186   }
2187 }
2188 
2189 TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) {
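  // Tests a conditional nested inside a while body. The value that feeds the
  // conditional (copy0 / while{0} / cond_tuple{0}) should get an alternate
  // memory allocation, forcing an eviction and a prefetch for the while body
  // root (see the expectations below).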
2190   absl::string_view hlo_string = R"(
2191   HloModule CondAllocation, is_scheduled=true
2192 
2193   true_computation {
2194     p0 = (f32[3]{0}) parameter(0)
2195     gte = f32[3]{0} get-tuple-element(p0), index=0
2196     ROOT neg1 = f32[3]{0} negate(gte)
2197   }
2198 
2199   false_computation {
2200     p0 = (f32[3]{0}) parameter(0)
2201     gte = f32[3]{0} get-tuple-element(p0), index=0
2202     ROOT neg2 = f32[3]{0} negate(gte)
2203   }
2204 
2205   while_cond {
2206     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
2207     ROOT gte = pred[] get-tuple-element(p0), index=2
2208   }
2209 
2210   while_body {
2211     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
2212     gte0 = f32[3]{0} get-tuple-element(p0), index=0
2213     gte1 = f32[3]{0} get-tuple-element(p0), index=1
2214     gte2 = pred[] get-tuple-element(p0), index=2
2215     cond_tuple = (f32[3]{0}) tuple(gte0)
2216     conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation
2217     add = f32[3]{0} add(conditional, gte1)
2218     neg0 = f32[3]{0} negate(add)
2219     neg1 = f32[3]{0} negate(neg0)
2220     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2)
2221   }
2222 
2223   ENTRY entry {
2224     p0 = f32[3]{0} parameter(0)
2225     p1 = pred[] parameter(1)
2226     copy0 = f32[3]{0} copy(p0)
2227     copy1 = f32[3]{0} copy(p0)
2228     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1)
2229     while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
2230     ROOT gte = f32[3]{0} get-tuple-element(while), index=1
2231   }
2232   )";
2233   TF_ASSERT_OK_AND_ASSIGN(auto module,
2234                           ParseAndReturnVerifiedModule(hlo_string));
2235   AssignMemorySpace(module.get());
2236 
2237   if (GetParam()) {
2238     // Make sure copy0/while{0}/cond_tuple{0} gets an alternate memory allocation.
2239     // This will force an eviction and a prefetch for the while body root.
2240     auto copy0 =
2241         module->GetComputationWithName("entry")->GetInstructionWithName(
2242             "copy0");
2243     EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
2244     auto conditional = module->GetComputationWithName("while_body")
2245                            ->GetInstructionWithName("conditional");
2246     auto conditional_operand = conditional->operand(1);
2247     EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0})
2248                   .layout()
2249                   .memory_space(),
2250               kAlternateMemorySpace);
2251     auto while_root =
2252         module->GetComputationWithName("while_body")->root_instruction();
2253     auto while_root_operand = while_root->operand(0);
2254     EXPECT_THAT(
2255         while_root_operand,
2256         op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
2257                       op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
2258                                     op::GetTupleElement(op::Parameter(0)))));
2259   }
2260 }
2261 
2262 TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
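  // Tests that the alternate memory allocation made for the conditional input
  // (copy) is propagated through both levels of the nested conditionals
  // (checked on the operands of neg1, neg2, and neg3 below).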
2263   absl::string_view hlo_string = R"(
2264   HloModule CondAllocation, is_scheduled=true
2265 
2266   true_computation2 {
2267     p0 = (f32[3]{0}) parameter(0)
2268     gte = f32[3]{0} get-tuple-element(p0), index=0
2269     ROOT neg1 = f32[3]{0} negate(gte)
2270   }
2271 
2272   false_computation2 {
2273     p0 = (f32[3]{0}) parameter(0)
2274     gte = f32[3]{0} get-tuple-element(p0), index=0
2275     ROOT neg2 = f32[3]{0} negate(gte)
2276   }
2277 
2278   true_computation1 {
2279     p0 = (f32[3]{0}) parameter(0)
2280     gte = f32[3]{0} get-tuple-element(p0), index=0
2281     slice = f32[1]{0} slice(gte), slice={[0:1]}
2282     bitcast = f32[] bitcast(slice)
2283     constant = f32[] constant(0.0)
2284     compare = pred[] compare(bitcast, constant), direction=GT
2285     ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2
2286   }
2287 
2288   false_computation1 {
2289     p0 = (f32[3]{0}) parameter(0)
2290     gte = f32[3]{0} get-tuple-element(p0), index=0
2291     ROOT neg3 = f32[3]{0} negate(gte)
2292   }
2293 
2294 
2295   ENTRY entry {
2296     p0 = f32[3]{0} parameter(0)
2297     p1 = pred[] parameter(1)
2298     copy = f32[3]{0} copy(p0)
2299     tuple = (f32[3]{0}) tuple(copy)
2300     ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
2301   }
2302   )";
2303   TF_ASSERT_OK_AND_ASSIGN(auto module,
2304                           ParseAndReturnVerifiedModule(hlo_string));
2305   AssignMemorySpace(module.get());
2306 
2307   if (GetParam()) {
2308     // Make sure the alternate memory allocation gets propagated into both
2309     // levels of the nested conditional.
2310     auto copy =
2311         module->GetComputationWithName("entry")->GetInstructionWithName("copy");
2312     EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
2313     auto neg1_operand = module->GetComputationWithName("true_computation2")
2314                             ->GetInstructionWithName("neg1")
2315                             ->operand(0);
2316     auto neg2_operand = module->GetComputationWithName("false_computation2")
2317                             ->GetInstructionWithName("neg2")
2318                             ->operand(0);
2319     auto neg3_operand = module->GetComputationWithName("false_computation1")
2320                             ->GetInstructionWithName("neg3")
2321                             ->operand(0);
2322     EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
2323               kAlternateMemorySpace);
2324     EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
2325               kAlternateMemorySpace);
2326     EXPECT_EQ(neg3_operand->shape().layout().memory_space(),
2327               kAlternateMemorySpace);
2328   }
2329 }
2330 
2331 TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) {
2332   // Tests a spurious verification failure when there are nested conditionals
2333   // and the innermost conditional computation reuses the buffer. Here, both the
2334   // parameter of true_computation2 and neg2 will get the same buffer. Make sure
2335   // that verification doesn't claim a failure in this case.
2336   absl::string_view hlo_string = R"(
2337   HloModule CondAllocation, is_scheduled=true
2338 
2339   true_computation2 {
2340     p0 = (f32[3]{0}) parameter(0)
2341     gte = f32[3]{0} get-tuple-element(p0), index=0
2342     neg1 = f32[3]{0} negate(gte)
2343     neg2 = f32[3]{0} negate(neg1)
2344     ROOT neg3 = f32[3]{0} negate(neg2)
2345   }
2346 
2347   false_computation2 {
2348     p0 = (f32[3]{0}) parameter(0)
2349     gte = f32[3]{0} get-tuple-element(p0), index=0
2350     ROOT neg4 = f32[3]{0} negate(gte)
2351   }
2352 
2353   true_computation1 {
2354     p0 = (f32[3]{0}) parameter(0)
2355     gte = f32[3]{0} get-tuple-element(p0), index=0
2356     slice = f32[1]{0} slice(gte), slice={[0:1]}
2357     bitcast = f32[] bitcast(slice)
2358     constant = f32[] constant(0.0)
2359     compare = pred[] compare(bitcast, constant), direction=GT
2360     tuple = (f32[3]{0}) tuple(gte)
2361     ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2
2362   }
2363 
2364   false_computation1 {
2365     p0 = (f32[3]{0}) parameter(0)
2366     gte = f32[3]{0} get-tuple-element(p0), index=0
2367     ROOT neg5 = f32[3]{0} negate(gte)
2368   }
2369 
2370   ENTRY entry {
2371     p0 = f32[3]{0} parameter(0)
2372     p1 = pred[] parameter(1)
2373     copy = f32[3]{0} copy(p0)
2374     tuple = (f32[3]{0}) tuple(copy)
2375     ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
2376   }
2377   )";
2378   TF_ASSERT_OK_AND_ASSIGN(auto module,
2379                           ParseAndReturnVerifiedModule(hlo_string));
2380   AssignMemorySpace(module.get());
2381 }
2382 
2383 TEST_P(MemorySpaceAssignmentTest,
2384        ConditionalComputationBufferOverlapBeforeParam) {
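  // Tests that a value that is live across the conditional (copy) does not
  // share its alternate memory offset with a buffer (neg0) that is produced
  // inside false_computation before the conditional parameter is first used.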
2385   absl::string_view hlo_string = R"(
2386   HloModule CondAllocation, is_scheduled=true
2387 
2388   true_computation {
2389     p0 = (f32[3]{0}) parameter(0)
2390     gte = f32[3]{0} get-tuple-element(p0), index=0
2391     ROOT neg2 = f32[3]{0} negate(gte)
2392   }
2393 
2394   false_computation {
2395     c = f32[3]{0} constant({0.0, 1.0, 2.0})
2396     neg0 = f32[3]{0} negate(c)
2397     neg1 = f32[3]{0} negate(neg0)
2398     p0 = (f32[3]{0}) parameter(0)
2399     gte = f32[3]{0} get-tuple-element(p0), index=0
2400     ROOT add = f32[3]{0} add(gte, neg1)
2401   }
2402 
2403   ENTRY entry {
2404     p0 = f32[3]{0} parameter(0)
2405     p1 = pred[] parameter(1)
2406     copy = f32[3]{0} copy(p0)
2407     tuple = (f32[3]{0}) tuple(copy)
2408     ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
2409   }
2410   )";
2411   TF_ASSERT_OK_AND_ASSIGN(auto module,
2412                           ParseAndReturnVerifiedModule(hlo_string));
2413   auto preset_assignments = AssignMemorySpace(module.get());
2414 
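  // Returns the alternate memory offset assigned to the named instruction, or
  // -1 if the preset assignments contain no chunk for it.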
2415   auto get_offset = [&](absl::string_view hlo_name) {
2416     for (const auto& chunk : preset_assignments->chunks()) {
2417       if (chunk.first.instruction->name() == hlo_name) {
2418         return chunk.second.offset;
2419       }
2420     }
2421     return static_cast<int64_t>(-1);
2422   };
2423 
2424   int64_t copy_offset = get_offset("copy");
2425   int64_t neg0_offset = get_offset("neg0");
2426   EXPECT_NE(copy_offset, -1);
2427   EXPECT_NE(neg0_offset, -1);
2428   EXPECT_NE(copy_offset, neg0_offset);
2429 }
2430 
2431 TEST_P(MemorySpaceAssignmentTest,
2432        RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
2433   // Ensure that the request identifiers returned by Send/Recv HLOs are not
2434   // allocated in the alternate memory.
2435   absl::string_view hlo_string = R"(
2436   HloModule SendRecv, is_scheduled=true
2437 
2438   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
2439     %p = f32[3]{0} parameter(0)
2440     %after-all = token[] after-all()
2441     %recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
2442     %recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
2443     %token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
2444     %data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
2445     %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
2446     %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
2447     ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
2448   }
2449   )";
2450 
2451   TF_ASSERT_OK_AND_ASSIGN(auto module,
2452                           ParseAndReturnVerifiedModule(hlo_string));
2453   AssignMemorySpace(module.get());
2454 
2455   for (const HloInstruction* instruction :
2456        module->entry_computation()->instructions()) {
2457     if (instruction->opcode() == HloOpcode::kSend ||
2458         instruction->opcode() == HloOpcode::kRecv) {
2459       const Shape& request_identifier_shape =
2460           ShapeUtil::GetSubshape(instruction->shape(), {1});
2461       EXPECT_NE(request_identifier_shape.layout().memory_space(),
2462                 kAlternateMemorySpace);
2463     }
2464   }
2465 }
2466 
2467 TEST_P(MemorySpaceAssignmentTest, SendDoneShouldHaveSendOperand) {
2468   // Ensure that the only operand of SendDone is the Send.
2469   absl::string_view hlo_string = R"(
2470   HloModule SendRecv, is_scheduled=true
2471 
2472   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
2473     %p0 = f32[3]{0} parameter(0)
2474     %p1 = f32[3]{0} parameter(1)
2475     %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
2476     %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
2477     %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
2478     %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
2479     %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
2480     %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
2481     %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
2482     %after-all = token[] after-all()
2483     %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
2484     %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
2485     ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
2486   }
2487   )";
2488 
2489   TF_ASSERT_OK_AND_ASSIGN(auto module,
2490                           ParseAndReturnVerifiedModule(hlo_string));
2491   AssignMemorySpace(module.get());
2492 }
2493 
2494 TEST_P(MemorySpaceAssignmentTest, SendAndSendDoneShouldGetSameAllocation) {
2495   // Ensure that Send and SendDone have the same allocation.
2496   absl::string_view hlo_string = R"(
2497   HloModule SendRecv, is_scheduled=true
2498 
2499   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
2500     %p0 = f32[3]{0} parameter(0)
2501     %p1 = f32[3]{0} parameter(1)
2502     %after-all = token[] after-all()
2503     %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
2504     %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
2505     %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
2506     %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
2507     %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
2508     %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
2509     %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
2510     %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
2511     %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
2512     ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
2513   }
2514   )";
2515 
2516   TF_ASSERT_OK_AND_ASSIGN(auto module,
2517                           ParseAndReturnVerifiedModule(hlo_string));
2518   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
2519                     /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/4);
2520 }
2521 
2522 TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
2523   // Test that checks the last use optimization. It uses two buffers that should
2524   // be placed in alternate memory.
2525   //
2526   //      +-------+
2527   //     /         \
2528   // add1--->sub1   +-------->mul2
2529   //              mul1===>add2
2530   //
2531   // Without the last use optimization, the mul1 buffer will be assigned first
2532   // (because it is larger) to offset 0. Then, add1 will be scheduled for the
2533   // add1 to sub1 segment. Because offset 0 is available, it will get that
2534   // offset. But because offset 0 is not available in the sub1 to mul2 segment,
2535   // this results in unnecessary copies. With the last use optimization, these
2536   // copies can be optimized away.
2537   HloComputation::Builder builder(TestName());
2538   Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
2539   Shape shape2 = ShapeUtil::MakeShape(F32, {2, 4});
2540   PaddingConfig padding_config = MakeEdgePaddingConfig({{0, 0}, {0, 1}});
2541   HloInstruction* p0 =
2542       builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
2543   HloInstruction* p1 =
2544       builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
2545   HloInstruction* add1 = builder.AddInstruction(
2546       HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, p0, p0));
2547   HloInstruction* sub1 = builder.AddInstruction(
2548       HloInstruction::CreateBinary(shape1, HloOpcode::kSubtract, p0, add1));
2549   HloInstruction* mul1 = builder.AddInstruction(
2550       HloInstruction::CreateBinary(shape2, HloOpcode::kMultiply, p1, p1));
2551   HloInstruction* add2 = builder.AddInstruction(
2552       HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, mul1, p1));
2553   HloInstruction* mul2 = builder.AddInstruction(
2554       HloInstruction::CreateBinary(shape1, HloOpcode::kMultiply, add1, sub1));
2555   HloInstruction* padding_value = builder.AddInstruction(
2556       HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
2557   HloInstruction* padded_mul2 = builder.AddInstruction(
2558       HloInstruction::CreatePad(shape2, mul2, padding_value, padding_config));
2559   HloInstruction* add3 = builder.AddInstruction(
2560       HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, add2, padded_mul2));
2561 
2562   auto module = CreateNewVerifiedModule();
2563   HloComputation* computation = module->AddEntryComputation(builder.Build());
2564 
2565   HloSchedule schedule(module.get());
2566   schedule.set_sequence(computation, {p0, p1, add1, sub1, mul1, add2, mul2,
2567                                       padding_value, padded_mul2, add3});
2568   TF_CHECK_OK(module->set_schedule(schedule));
2569 
2570   AssignMemorySpace(module.get());
2571 
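  // sub1's parameter operand should have been prefetched into alternate memory
  // (checked by the AsyncCopy matcher below).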
2572   EXPECT_THAT(
2573       mul2,
2574       op::Multiply(
2575           op::Add(op::Parameter(0), op::Parameter(0)),
2576           op::Subtract(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
2577                                      op::Parameter(0)),
2578                        op::Add(op::Parameter(0), op::Parameter(0)))));
2579 }
2580 
2581 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
2582   // Test to ensure CopyStart/CopyDone are placed only in the entry computation.
2583   auto module = CreateNewVerifiedModule();
2584   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2585   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
2586   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
2587 
2588   auto cond_builder = HloComputation::Builder("WhileCond");
2589   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
2590   HloInstruction* cond_param = cond_builder.AddInstruction(
2591       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
2592   HloInstruction* cond_iter = cond_builder.AddInstruction(
2593       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
2594   HloInstruction* cond_limit = cond_builder.AddInstruction(
2595       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
2596   // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
2597   HloInstruction* cond_lt = cond_builder.AddInstruction(
2598       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
2599                                     cond_limit, ComparisonDirection::kLt));
2600   HloComputation* cond_computation =
2601       module->AddEmbeddedComputation(cond_builder.Build());
2602 
2603   auto body_builder = HloComputation::Builder("WhileBody");
2604   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
2605   HloInstruction* body_param = body_builder.AddInstruction(
2606       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
2607   HloInstruction* body_iter = body_builder.AddInstruction(
2608       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
2609   HloInstruction* body_data = body_builder.AddInstruction(
2610       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
2611   HloInstruction* body_iter_increment = body_builder.AddInstruction(
2612       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
2613   HloInstruction* body_iter_next =
2614       body_builder.AddInstruction(HloInstruction::CreateBinary(
2615           scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
2616   HloInstruction* body_data_increment =
2617       body_builder.AddInstruction(HloInstruction::CreateConstant(
2618           LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
2619   HloInstruction* body_data_mul =
2620       body_builder.AddInstruction(HloInstruction::CreateBinary(
2621           shape, HloOpcode::kMultiply, body_data, body_data));
2622   HloInstruction* body_data_add =
2623       body_builder.AddInstruction(HloInstruction::CreateBinary(
2624           shape, HloOpcode::kAdd, body_data, body_data_increment));
2625   HloInstruction* body_data_next =
2626       body_builder.AddInstruction(HloInstruction::CreateBinary(
2627           shape, HloOpcode::kAdd, body_data_add, body_data_mul));
2628   HloInstruction* body_out = body_builder.AddInstruction(
2629       HloInstruction::CreateTuple({body_data_next, body_iter_next}));
2630   HloComputation* body_computation =
2631       module->AddEmbeddedComputation(body_builder.Build());
2632 
2633   auto builder = HloComputation::Builder(TestName());
2634   HloInstruction* data = builder.AddInstruction(
2635       HloInstruction::CreateParameter(0, shape, "param_data"));
2636   HloInstruction* iter = builder.AddInstruction(
2637       HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
2638   HloInstruction* p2 =
2639       builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2"));
2640   HloInstruction* tuple =
2641       builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
2642   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
2643       tuple_shape, cond_computation, body_computation, tuple));
2644   HloInstruction* while_data = builder.AddInstruction(
2645       HloInstruction::CreateGetTupleElement(shape, while_op, 0));
2646   HloInstruction* add = builder.AddInstruction(
2647       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, while_data, p2));
2648   HloComputation* entry_computation =
2649       module->AddEntryComputation(builder.Build());
2650 
2651   HloSchedule schedule(module.get());
2652   schedule.set_sequence(cond_computation,
2653                         {cond_param, cond_iter, cond_limit, cond_lt});
2654   schedule.set_sequence(body_computation,
2655                         {body_param, body_iter, body_data, body_iter_increment,
2656                          body_iter_next, body_data_increment, body_data_mul,
2657                          body_data_add, body_data_next, body_out});
2658   schedule.set_sequence(entry_computation,
2659                         {iter, data, p2, tuple, while_op, while_data, add});
2660   TF_CHECK_OK(module->set_schedule(schedule));
2661 
2662   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);
2663 }
2664 
2665 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
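  // Exercises memory space assignment across a call to a non-entry
  // computation, using a small max prefetch interval.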
2666   auto module = CreateNewVerifiedModule();
2667   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2668   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
2669 
2670   auto call_builder = HloComputation::Builder("Call");
2671   HloInstruction* call_param = call_builder.AddInstruction(
2672       HloInstruction::CreateParameter(0, shape, "call_param"));
2673   HloInstruction* call_param2 = call_builder.AddInstruction(
2674       HloInstruction::CreateParameter(1, shape2, "call_param2"));
2675   HloInstruction* slice = call_builder.AddInstruction(
2676       HloInstruction::CreateSlice(shape, call_param2, {0, 0}, {2, 3}, {1, 1}));
2677   HloInstruction* mul =
2678       call_builder.AddInstruction(HloInstruction::CreateBinary(
2679           shape, HloOpcode::kMultiply, call_param, slice));
2680   HloInstruction* negate0 = call_builder.AddInstruction(
2681       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
2682   HloInstruction* negate1 = call_builder.AddInstruction(
2683       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2684   HloInstruction* negate2 = call_builder.AddInstruction(
2685       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2686   HloInstruction* negate3 = call_builder.AddInstruction(
2687       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2688   HloInstruction* negate4 = call_builder.AddInstruction(
2689       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
2690   HloInstruction* negate5 = call_builder.AddInstruction(
2691       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
2692   HloInstruction* negate6 = call_builder.AddInstruction(
2693       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
2694   HloInstruction* negate7 = call_builder.AddInstruction(
2695       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
2696   HloInstruction* add0 =
2697       call_builder.AddInstruction(HloInstruction::CreateBinary(
2698           shape, HloOpcode::kAdd, call_param, negate7));
2699   HloComputation* call_computation =
2700       module->AddEmbeddedComputation(call_builder.Build());
2701 
2702   auto builder = HloComputation::Builder(TestName());
2703   HloInstruction* p0 =
2704       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
2705   HloInstruction* p1 =
2706       builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
2707   HloInstruction* add1 = builder.AddInstruction(
2708       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
2709   HloInstruction* add2 = builder.AddInstruction(
2710       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
2711   HloInstruction* negate8 = builder.AddInstruction(
2712       HloInstruction::CreateUnary(shape2, HloOpcode::kNegate, p1));
2713   HloInstruction* call = builder.AddInstruction(
2714       HloInstruction::CreateCall(shape, {add1, negate8}, call_computation));
2715   HloInstruction* add3 = builder.AddInstruction(
2716       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, add1));
2717   HloInstruction* add4 = builder.AddInstruction(
2718       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add3));
2719   HloInstruction* add5 = builder.AddInstruction(
2720       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add2, add4));
2721   HloComputation* entry_computation =
2722       module->AddEntryComputation(builder.Build());
2723 
2724   HloSchedule schedule(module.get());
2725   schedule.set_sequence(
2726       call_computation,
2727       {call_param, call_param2, slice, mul, negate0, negate1, negate2, negate3,
2728        negate4, negate5, negate6, negate7, add0});
2729   schedule.set_sequence(entry_computation,
2730                         {p0, p1, add1, add2, negate8, call, add3, add4, add5});
2731   TF_CHECK_OK(module->set_schedule(schedule));
2732 
2733   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
2734 }
2735 
2736 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
2737   auto module = CreateNewVerifiedModule();
2738   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2739   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
2740 
2741   auto call_builder = HloComputation::Builder("Call");
2742   HloInstruction* call_param = call_builder.AddInstruction(
2743       HloInstruction::CreateParameter(0, shape, "call_param"));
2744   // Use shape2 here, which is larger (scheduled earlier), to occupy alternate
2745   // memory at the beginning. This should cause a situation where the prefetch
2746   // of add1 later in the function body gets the wrong offset, which cannot be
2747   // communicated outside the function.
2748   HloInstruction* iota =
2749       call_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
2750   HloInstruction* slice = call_builder.AddInstruction(
2751       HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
2752   HloInstruction* mul =
2753       call_builder.AddInstruction(HloInstruction::CreateBinary(
2754           shape, HloOpcode::kMultiply, call_param, slice));
2755   HloInstruction* negate0 = call_builder.AddInstruction(
2756       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
2757   HloInstruction* negate1 = call_builder.AddInstruction(
2758       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2759   HloInstruction* negate2 = call_builder.AddInstruction(
2760       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2761   HloInstruction* negate3 = call_builder.AddInstruction(
2762       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2763   HloInstruction* negate4 = call_builder.AddInstruction(
2764       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
2765   HloInstruction* negate5 = call_builder.AddInstruction(
2766       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
2767   HloInstruction* negate6 = call_builder.AddInstruction(
2768       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
2769   HloInstruction* negate7 = call_builder.AddInstruction(
2770       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
2771   HloInstruction* add0 =
2772       call_builder.AddInstruction(HloInstruction::CreateBinary(
2773           shape, HloOpcode::kAdd, call_param, negate7));
2774   HloComputation* call_computation =
2775       module->AddEmbeddedComputation(call_builder.Build());
2776 
2777   auto builder = HloComputation::Builder(TestName());
2778   HloInstruction* p0 =
2779       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
2780   HloInstruction* add1 = builder.AddInstruction(
2781       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
2782   HloInstruction* add2 = builder.AddInstruction(
2783       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
2784   HloInstruction* call = builder.AddInstruction(
2785       HloInstruction::CreateCall(shape, {add1}, call_computation));
2786   HloInstruction* add3 = builder.AddInstruction(
2787       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add1));
2788   HloComputation* entry_computation =
2789       module->AddEntryComputation(builder.Build());
2790 
2791   HloSchedule schedule(module.get());
2792   schedule.set_sequence(
2793       call_computation,
2794       {call_param, iota, slice, mul, negate0, negate1, negate2, negate3,
2795        negate4, negate5, negate6, negate7, add0});
2796   schedule.set_sequence(entry_computation, {p0, add1, add2, call, add3});
2797   TF_CHECK_OK(module->set_schedule(schedule));
2798 
2799   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
2800 }
2801 
2802 // TODO(berkin): This might be an incorrect input graph, investigate.
2803 TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) {
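  // Builds a conditional whose branch computations take different operands
  // (add1 for the true computation, add2 for the false computation); currently
  // disabled, see the TODO above.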
2804   auto module = CreateNewVerifiedModule();
2805   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2806   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
2807 
2808   auto true_builder = HloComputation::Builder("True");
2809   HloInstruction* true_param = true_builder.AddInstruction(
2810       HloInstruction::CreateParameter(0, shape, "true_param"));
2811   HloInstruction* iota =
2812       true_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
2813   HloInstruction* slice = true_builder.AddInstruction(
2814       HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
2815   HloInstruction* mul =
2816       true_builder.AddInstruction(HloInstruction::CreateBinary(
2817           shape, HloOpcode::kMultiply, true_param, slice));
2818   HloInstruction* negate0 = true_builder.AddInstruction(
2819       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
2820   HloInstruction* negate1 = true_builder.AddInstruction(
2821       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2822   HloInstruction* negate2 = true_builder.AddInstruction(
2823       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2824   HloInstruction* negate3 = true_builder.AddInstruction(
2825       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2826   HloInstruction* negate4 = true_builder.AddInstruction(
2827       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
2828   HloInstruction* negate5 = true_builder.AddInstruction(
2829       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
2830   HloInstruction* negate6 = true_builder.AddInstruction(
2831       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
2832   HloInstruction* negate7 = true_builder.AddInstruction(
2833       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
2834   HloInstruction* add0 =
2835       true_builder.AddInstruction(HloInstruction::CreateBinary(
2836           shape, HloOpcode::kAdd, true_param, negate7));
2837   HloComputation* true_computation =
2838       module->AddEmbeddedComputation(true_builder.Build());
2839 
2840   auto false_builder = HloComputation::Builder("False");
2841   HloInstruction* false_param = false_builder.AddInstruction(
2842       HloInstruction::CreateParameter(0, shape, "false_param"));
2843   HloComputation* false_computation =
2844       module->AddEmbeddedComputation(false_builder.Build());
2845 
2846   auto builder = HloComputation::Builder(TestName());
2847   HloInstruction* p0 =
2848       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
2849   HloInstruction* add1 = builder.AddInstruction(
2850       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
2851   HloInstruction* add2 = builder.AddInstruction(
2852       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
2853   HloInstruction* pred = builder.AddInstruction(
2854       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2855   HloInstruction* conditional =
2856       builder.AddInstruction(HloInstruction::CreateConditional(
2857           shape, pred, add1, true_computation, add2, false_computation));
2858   HloInstruction* add3 = builder.AddInstruction(
2859       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, conditional, add1));
2860   HloComputation* entry_computation =
2861       module->AddEntryComputation(builder.Build());
2862 
2863   HloSchedule schedule(module.get());
2864   schedule.set_sequence(
2865       true_computation,
2866       {true_param, iota, slice, mul, negate0, negate1, negate2, negate3,
2867        negate4, negate5, negate6, negate7, add0});
2868   schedule.set_sequence(false_computation, {false_param});
2869   schedule.set_sequence(entry_computation,
2870                         {p0, add1, add2, pred, conditional, add3});
2871   TF_CHECK_OK(module->set_schedule(schedule));
2872 
2873   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
2874 }
2875 
2876 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
2877   // This test reproduces the failure in b/143288178.  Given a graph like the
2878   // following:
2879   //
2880   // ... = foo(a)
2881   // tuple = tuple(..., a)
2882   // ... = while(tuple) {
2883   //   p = param(0)
2884   //   a1 = get-tuple-element(p), index=n-1
2885   //   ...
2886   //   ROOT tuple(..., a1)
2887   // }
2888   //
2889   // If a copy to alternate memory is inserted before foo, and if the size of
2890   // the while body is less than max prefetch interval so that the copy-done is
2891   // kept in the alternate memory, then we end up referring to the copy-done in
2892   // the root instruction of the while loop body. I.e.,
2893   //
2894   // cs = copy-start(a)
2895   // ...
2896   // cd = copy-done(cs)
2897   // ... = foo(cd)
2898   // tuple = tuple(..., cd)
2899   // ... = while(tuple) {
2900   //   p = param(0)
2901   //   a1 = get-tuple-element(p), index=n-1
2902   //   ...
2903   //   ROOT tuple(..., cd)  <-- Error: cd belongs to the outer computation.
2904   // }
2905   //
2906   auto module = CreateNewVerifiedModule();
2907   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2908   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
2909   Shape tuple_shape =
2910       ShapeUtil::MakeTupleShape({shape, scalar_shape, scalar_shape});
2911 
2912   auto cond_builder = HloComputation::Builder("WhileCond");
2913   HloInstruction* cond_param = cond_builder.AddInstruction(
2914       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
2915   HloInstruction* cond_iter = cond_builder.AddInstruction(
2916       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
2917   HloInstruction* cond_limit = cond_builder.AddInstruction(
2918       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
2919   HloInstruction* cond_lt = cond_builder.AddInstruction(
2920       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
2921                                     cond_limit, ComparisonDirection::kLt));
2922   HloComputation* cond_computation =
2923       module->AddEmbeddedComputation(cond_builder.Build());
2924 
2925   auto body_builder = HloComputation::Builder("WhileBody");
2926   HloInstruction* body_param = body_builder.AddInstruction(
2927       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
2928   HloInstruction* body_iter = body_builder.AddInstruction(
2929       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
2930   HloInstruction* body_data = body_builder.AddInstruction(
2931       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
2932   HloInstruction* body_iter_increment = body_builder.AddInstruction(
2933       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
2934   HloInstruction* body_iter_next =
2935       body_builder.AddInstruction(HloInstruction::CreateBinary(
2936           scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
2937   HloInstruction* body_data2 = body_builder.AddInstruction(
2938       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 2));
2939   HloInstruction* body_out = body_builder.AddInstruction(
2940       HloInstruction::CreateTuple({body_data, body_iter_next, body_data2}));
2941   HloComputation* body_computation =
2942       module->AddEmbeddedComputation(body_builder.Build());
2943 
2944   auto builder = HloComputation::Builder(TestName());
2945   HloInstruction* data = builder.AddInstruction(
2946       HloInstruction::CreateParameter(0, shape, "param_data"));
2947   HloInstruction* iter = builder.AddInstruction(
2948       HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
2949   HloInstruction* data2 = builder.AddInstruction(
2950       HloInstruction::CreateParameter(2, scalar_shape, "param_data2"));
2951   HloInstruction* negate0 = builder.AddInstruction(
2952       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
2953   HloInstruction* negate1 = builder.AddInstruction(
2954       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2955   HloInstruction* negate2 = builder.AddInstruction(
2956       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2957   HloInstruction* negate3 = builder.AddInstruction(
2958       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2959   HloInstruction* negate4 = builder.AddInstruction(
2960       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
2961   HloInstruction* negate5 = builder.AddInstruction(
2962       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
2963   HloInstruction* negate6 = builder.AddInstruction(
2964       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
2965   HloInstruction* negate7 = builder.AddInstruction(
2966       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
2967   HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
2968       scalar_shape, HloOpcode::kSubtract, iter, data2));
2969   HloInstruction* tuple = builder.AddInstruction(
2970       HloInstruction::CreateTuple({negate7, iter, data2}));
2971   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
2972       tuple_shape, cond_computation, body_computation, tuple));
2973   HloInstruction* while_data = builder.AddInstruction(
2974       HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
2975   HloInstruction* root =
2976       builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
2977   HloComputation* entry_computation =
2978       module->AddEntryComputation(builder.Build());
2979 
2980   HloSchedule schedule(module.get());
2981   schedule.set_sequence(cond_computation,
2982                         {cond_param, cond_iter, cond_limit, cond_lt});
2983   schedule.set_sequence(body_computation,
2984                         {body_param, body_iter, body_data, body_iter_increment,
2985                          body_iter_next, body_data2, body_out});
2986   schedule.set_sequence(
2987       entry_computation,
2988       {iter, data, data2, negate0, negate1, negate2, negate3, negate4, negate5,
2989        negate6, negate7, sub, tuple, while_op, while_data, root});
2990   TF_CHECK_OK(module->set_schedule(schedule));
2991 
2992   // Set a large max prefetch interval so that the buffer can be kept in
2993   // alternate memory.
2994   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/20);
2995 }
2996 
2997 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
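  // Tests that the layouts (including memory spaces) assigned to the while
  // loop's aliased buffers -- the while operand, the condition and body
  // parameters, and the body root -- are consistent with the while op itself.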
2998   auto module = CreateNewVerifiedModule();
2999   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
3000   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
3001   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});
3002 
3003   auto cond_builder = HloComputation::Builder("WhileCond");
3004   HloInstruction* cond_param = cond_builder.AddInstruction(
3005       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
3006   HloInstruction* cond_iter = cond_builder.AddInstruction(
3007       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
3008   HloInstruction* cond_limit = cond_builder.AddInstruction(
3009       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
3010   HloInstruction* cond_lt = cond_builder.AddInstruction(
3011       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
3012                                     cond_limit, ComparisonDirection::kLt));
3013   HloComputation* cond_computation =
3014       module->AddEmbeddedComputation(cond_builder.Build());
3015 
3016   auto body_builder = HloComputation::Builder("WhileBody");
3017   HloInstruction* body_param = body_builder.AddInstruction(
3018       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
3019   HloInstruction* body_iter = body_builder.AddInstruction(
3020       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
3021   HloInstruction* body_data = body_builder.AddInstruction(
3022       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
3023   HloInstruction* body_negate0 = body_builder.AddInstruction(
3024       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
3025   HloInstruction* body_negate1 = body_builder.AddInstruction(
3026       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
3027   HloInstruction* body_negate2 = body_builder.AddInstruction(
3028       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
3029   HloInstruction* body_negate3 = body_builder.AddInstruction(
3030       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
3031   HloInstruction* body_negate4 = body_builder.AddInstruction(
3032       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
3033   HloInstruction* body_negate5 = body_builder.AddInstruction(
3034       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
3035   HloInstruction* body_negate6 = body_builder.AddInstruction(
3036       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
3037   HloInstruction* body_negate7 = body_builder.AddInstruction(
3038       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
3039   HloInstruction* body_iter_increment = body_builder.AddInstruction(
3040       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
3041   HloInstruction* body_iter_next =
3042       body_builder.AddInstruction(HloInstruction::CreateBinary(
3043           scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
3044   HloInstruction* body_out = body_builder.AddInstruction(
3045       HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
3046   HloComputation* body_computation =
3047       module->AddEmbeddedComputation(body_builder.Build());
3048 
3049   auto builder = HloComputation::Builder(TestName());
3050   HloInstruction* data = builder.AddInstruction(
3051       HloInstruction::CreateParameter(0, shape, "param_data"));
3052   HloInstruction* iter = builder.AddInstruction(
3053       HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
3054   HloInstruction* negate0 = builder.AddInstruction(
3055       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
3056   HloInstruction* negate1 = builder.AddInstruction(
3057       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3058   HloInstruction* negate2 = builder.AddInstruction(
3059       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3060   HloInstruction* negate3 = builder.AddInstruction(
3061       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3062   HloInstruction* negate4 = builder.AddInstruction(
3063       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3064   HloInstruction* negate5 = builder.AddInstruction(
3065       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3066   HloInstruction* negate6 = builder.AddInstruction(
3067       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3068   HloInstruction* negate7 = builder.AddInstruction(
3069       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
3070   HloInstruction* tuple = builder.AddInstruction(
3071       HloInstruction::CreateTuple({data, iter, negate7}));
3072   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
3073       tuple_shape, cond_computation, body_computation, tuple));
3074   HloInstruction* while_data = builder.AddInstruction(
3075       HloInstruction::CreateGetTupleElement(shape, while_op, 0));
3076   HloInstruction* while_data2 = builder.AddInstruction(
3077       HloInstruction::CreateGetTupleElement(shape, while_op, 2));
3078   HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
3079       shape, HloOpcode::kAdd, while_data, while_data2));
3080   HloComputation* entry_computation =
3081       module->AddEntryComputation(builder.Build());
3082 
3083   HloSchedule schedule(module.get());
3084   schedule.set_sequence(cond_computation,
3085                         {cond_param, cond_iter, cond_limit, cond_lt});
3086   schedule.set_sequence(
3087       body_computation,
3088       {body_param, body_iter, body_data, body_negate0, body_negate1,
3089        body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
3090        body_negate7, body_iter_increment, body_iter_next, body_out});
3091   schedule.set_sequence(
3092       entry_computation,
3093       {iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
3094        negate6, negate7, tuple, while_op, while_data, while_data2, root});
3095   TF_CHECK_OK(module->set_schedule(schedule));
3096 
3097   // Pick a large max prefetch interval to ensure all the while inputs are
3098   // allocated in the alternate memory.
3099   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
3100                     /*max_prefetch_interval=*/25);
3101 
3102   // Index {0} of the while loop argument is not written inside the while loop,
3103   // so it can be trivially placed in the alternate memory space.
3104   *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
3105       LayoutUtil::MakeLayout(
3106           /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
3107           /*element_size_in_bits=*/0, kAlternateMemorySpace);
3108   // Index {1} is a scalar, so it is always placed in the default memory.
3109   *ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
3110       LayoutUtil::MakeLayout(
3111           /*minor_to_major=*/{}, /*dim_level_types=*/{}, /*tiles=*/{},
3112           /*element_size_in_bits=*/0, kDefaultMemorySpace);
3113   // Index {2} of the while loop is placed in the default memory.
3114   *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
3115       LayoutUtil::MakeLayout(
3116           /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
3117           /*element_size_in_bits=*/0, kDefaultMemorySpace);
3118 
3119   // Expect the layout for the while loop and its aliased buffers.
3120   EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
3121   EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
3122   EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
3123   EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
3124   EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
3125 }
3126 
3127 TEST_P(MemorySpaceAssignmentTest, DanglingCopy) {
3128   // This situation was encountered in vss, where there was a mismatch between
3129   // the memory space in the preset assignments and the one in the output graph.
3130   HloComputation::Builder builder(TestName());
3131   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3132   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3133 
3134   HloInstruction* p = builder.AddInstruction(
3135       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3136   HloInstruction* p0 = builder.AddInstruction(
3137       HloInstruction::CreateGetTupleElement(shape, p, 0));
3138   HloInstruction* p1a = builder.AddInstruction(
3139       HloInstruction::CreateGetTupleElement(shape, p, 1));
3140   HloInstruction* copy = builder.AddInstruction(
3141       HloInstruction::CreateUnary(shape, HloOpcode::kCopy, p1a));
3142   HloInstruction* negate0 = builder.AddInstruction(
3143       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3144   HloInstruction* negate1 = builder.AddInstruction(
3145       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3146   HloInstruction* negate2 = builder.AddInstruction(
3147       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3148   HloInstruction* negate3 = builder.AddInstruction(
3149       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3150   HloInstruction* negate4 = builder.AddInstruction(
3151       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3152   HloInstruction* negate5 = builder.AddInstruction(
3153       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3154   HloInstruction* negate6 = builder.AddInstruction(
3155       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3156   HloInstruction* p1b = builder.AddInstruction(
3157       HloInstruction::CreateGetTupleElement(shape, p, 1));
3158   HloInstruction* add = builder.AddInstruction(
3159       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1b));
3160 
3161   auto module = CreateNewVerifiedModule();
3162   HloComputation* computation = module->AddEntryComputation(builder.Build());
3163 
3164   HloSchedule schedule(module.get());
3165   schedule.set_sequence(
3166       computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
3167                     negate6, p1a, copy, p1b, add});
3168   TF_CHECK_OK(module->set_schedule(schedule));
3169 
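  // This is a regression test for the mismatch described above; memory space
  // assignment is expected to run cleanly on this graph.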
3170   AssignMemorySpace(module.get());
3171 }
3172 
3173 TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) {
3174   HloComputation::Builder builder(TestName());
3175   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3176   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3177   auto module = CreateNewVerifiedModule();
3178 
3179   HloComputation::Builder fusion_builder("fusion");
3180   HloInstruction* fusion_param0 = fusion_builder.AddInstruction(
3181       HloInstruction::CreateParameter(0, shape, "p0"));
3182   HloInstruction* fusion_param1 = fusion_builder.AddInstruction(
3183       HloInstruction::CreateParameter(1, shape, "p1"));
3184   fusion_builder.AddInstruction(
3185       HloInstruction::CreateTuple({fusion_param0, fusion_param1}));
3186   HloComputation* fusion_computation =
3187       module->AddEmbeddedComputation(fusion_builder.Build());
3188 
3189   HloInstruction* p0 =
3190       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3191   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
3192       tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3193       fusion_computation));
3194   HloInstruction* element0 = builder.AddInstruction(
3195       HloInstruction::CreateGetTupleElement(shape, fusion, 0));
3196   HloInstruction* element1 = builder.AddInstruction(
3197       HloInstruction::CreateGetTupleElement(shape, fusion, 1));
3198   HloInstruction* add = builder.AddInstruction(
3199       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
3200 
3201   HloComputation* computation = module->AddEntryComputation(builder.Build());
3202 
3203   HloSchedule schedule(module.get());
3204   schedule.set_sequence(computation, {p0, fusion, element0, element1, add});
3205   TF_CHECK_OK(module->set_schedule(schedule));
3206 
3207   AssignMemorySpace(module.get());
3208 }
3209 
3210 TEST_P(MemorySpaceAssignmentTest, TupleInput) {
3211   HloComputation::Builder builder(TestName());
3212   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3213   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3214   auto module = CreateNewVerifiedModule();
3215 
3216   HloComputation::Builder fusion_builder("fusion");
3217   HloInstruction* fusion_param = fusion_builder.AddInstruction(
3218       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3219   HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
3220       HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
3221   HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
3222       HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
3223   fusion_builder.AddInstruction(HloInstruction::CreateBinary(
3224       shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
3225   HloComputation* fusion_computation =
3226       module->AddEmbeddedComputation(fusion_builder.Build());
3227 
3228   HloInstruction* p0 =
3229       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3230   HloInstruction* p1 =
3231       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
3232   HloInstruction* negate0 = builder.AddInstruction(
3233       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3234   HloInstruction* negate1 = builder.AddInstruction(
3235       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
3236   HloInstruction* tuple =
3237       builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
3238   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
3239       shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));
3240 
3241   HloComputation* computation = module->AddEntryComputation(builder.Build());
3242 
3243   HloSchedule schedule(module.get());
3244   schedule.set_sequence(computation, {p0, p1, negate0, negate1, tuple, fusion});
3245   TF_CHECK_OK(module->set_schedule(schedule));
3246 
3247   AssignMemorySpace(module.get());
3248 }
3249 
3250 TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
3251   HloComputation::Builder builder(TestName());
3252   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3253   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3254   auto module = CreateNewVerifiedModule();
3255 
3256   HloComputation::Builder fusion0_builder("fusion0");
3257   HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
3258       HloInstruction::CreateParameter(0, shape, "p0"));
3259   HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
3260       HloInstruction::CreateParameter(1, shape, "p1"));
3261   fusion0_builder.AddInstruction(
3262       HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
3263   HloComputation* fusion0_computation =
3264       module->AddEmbeddedComputation(fusion0_builder.Build());
3265 
3266   HloComputation::Builder fusion1_builder("fusion1");
3267   HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
3268       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3269   HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
3270       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
3271   HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
3272       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 1));
3273   fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
3274       shape, HloOpcode::kAdd, fusion1_element0, fusion1_element1));
3275   HloComputation* fusion1_computation =
3276       module->AddEmbeddedComputation(fusion1_builder.Build());
3277 
3278   HloInstruction* p0 =
3279       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3280   HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
3281       tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3282       fusion0_computation));
3283   HloInstruction* element0 = builder.AddInstruction(
3284       HloInstruction::CreateGetTupleElement(shape, fusion0, 0));
3285   HloInstruction* element1 = builder.AddInstruction(
3286       HloInstruction::CreateGetTupleElement(shape, fusion0, 1));
3287   HloInstruction* negate0 = builder.AddInstruction(
3288       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3289   HloInstruction* negate1 = builder.AddInstruction(
3290       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3291   HloInstruction* negate2 = builder.AddInstruction(
3292       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3293   HloInstruction* negate3 = builder.AddInstruction(
3294       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3295   HloInstruction* negate4 = builder.AddInstruction(
3296       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3297   HloInstruction* negate5 = builder.AddInstruction(
3298       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3299   HloInstruction* negate6 = builder.AddInstruction(
3300       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3301   HloInstruction* add0 = builder.AddInstruction(
3302       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
3303   HloInstruction* add1 = builder.AddInstruction(
3304       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, negate6));
3305   HloInstruction* fusion1 = builder.AddInstruction(
3306       HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
3307                                    {fusion0}, fusion1_computation));
3308   HloInstruction* mul = builder.AddInstruction(
3309       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add1, fusion1));
3310 
3311   HloComputation* computation = module->AddEntryComputation(builder.Build());
3312 
3313   HloSchedule schedule(module.get());
3314   schedule.set_sequence(
3315       computation,
3316       {p0, fusion0, element0, element1, negate0, negate1, negate2, negate3,
3317        negate4, negate5, negate6, add0, add1, fusion1, mul});
3318   TF_CHECK_OK(module->set_schedule(schedule));
3319 
3320   AssignMemorySpace(module.get(), -1, 5);
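  // Both elements of fusion0's output tuple are expected to be prefetched into
  // the alternate memory before fusion1 consumes them.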
3321   EXPECT_THAT(fusion1,
3322               op::Fusion(op::Tuple(
3323                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3324                                 op::GetTupleElement(op::Fusion(), 0)),
3325                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3326                                 op::GetTupleElement(op::Fusion(), 1)))));
3327 }
3328 
3329 TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
3330   HloComputation::Builder builder(TestName());
3331   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3332   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3333   Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({shape, tuple_shape});
3334   auto module = CreateNewVerifiedModule();
3335 
3336   HloComputation::Builder fusion0_builder("fusion0");
3337   HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
3338       HloInstruction::CreateParameter(0, shape, "p0"));
3339   HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
3340       HloInstruction::CreateParameter(1, shape, "p1"));
3341   HloInstruction* fusion0_tuple = fusion0_builder.AddInstruction(
3342       HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
3343   fusion0_builder.AddInstruction(
3344       HloInstruction::CreateTuple({fusion0_param0, fusion0_tuple}));
3345   HloComputation* fusion0_computation =
3346       module->AddEmbeddedComputation(fusion0_builder.Build());
3347 
3348   HloComputation::Builder fusion1_builder("fusion1");
3349   HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
3350       HloInstruction::CreateParameter(0, nested_tuple_shape, "p"));
3351   HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
3352       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
3353   HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
3354       HloInstruction::CreateGetTupleElement(tuple_shape, fusion1_param, 1));
3355   HloInstruction* fusion1_element2 = fusion1_builder.AddInstruction(
3356       HloInstruction::CreateGetTupleElement(shape, fusion1_element1, 1));
3357   fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
3358       shape, HloOpcode::kAdd, fusion1_element0, fusion1_element2));
3359   HloComputation* fusion1_computation =
3360       module->AddEmbeddedComputation(fusion1_builder.Build());
3361 
3362   HloInstruction* p0 =
3363       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3364   HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
3365       nested_tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3366       fusion0_computation));
3367   HloInstruction* negate0 = builder.AddInstruction(
3368       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3369   HloInstruction* negate1 = builder.AddInstruction(
3370       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3371   HloInstruction* negate2 = builder.AddInstruction(
3372       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3373   HloInstruction* negate3 = builder.AddInstruction(
3374       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3375   HloInstruction* negate4 = builder.AddInstruction(
3376       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3377   HloInstruction* negate5 = builder.AddInstruction(
3378       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3379   HloInstruction* negate6 = builder.AddInstruction(
3380       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3381   HloInstruction* fusion1 = builder.AddInstruction(
3382       HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
3383                                    {fusion0}, fusion1_computation));
3384 
3385   HloComputation* computation = module->AddEntryComputation(builder.Build());
3386 
3387   HloSchedule schedule(module.get());
3388   schedule.set_sequence(
3389       computation, {p0, fusion0, negate0, negate1, negate2, negate3, negate4,
3390                     negate5, negate6, fusion1});
3391   TF_CHECK_OK(module->set_schedule(schedule));
3392 
3393   AssignMemorySpace(module.get(), -1, 5);
3394 
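  // fusion1's nested tuple operand is expected to be rebuilt from async copies:
  // the top-level element and both elements of the inner tuple are prefetched
  // into the alternate memory.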
3395   EXPECT_THAT(
3396       fusion1,
3397       op::Fusion(op::Tuple(
3398           op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3399                         op::GetTupleElement(op::Fusion(), 0)),
3400           op::Tuple(
3401               op::AsyncCopy(
3402                   kAlternateMemorySpace, kDefaultMemorySpace,
3403                   op::GetTupleElement(op::GetTupleElement(op::Fusion(), 1), 0)),
3404               op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3405                             op::GetTupleElement(
3406                                 op::GetTupleElement(op::Fusion(), 1), 1))))));
3407 }
3408 
3409 TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) {
3410   HloComputation::Builder builder(TestName());
3411   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3412   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3413   auto module = CreateNewVerifiedModule();
3414 
3415   HloComputation::Builder fusion0_builder("fusion0");
3416   HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
3417       HloInstruction::CreateParameter(0, shape, "p0"));
3418   HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
3419       HloInstruction::CreateParameter(1, shape, "p1"));
3420   fusion0_builder.AddInstruction(
3421       HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
3422   HloComputation* fusion0_computation =
3423       module->AddEmbeddedComputation(fusion0_builder.Build());
3424 
3425   HloComputation::Builder fusion1_builder("fusion1");
3426   HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
3427       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3428   HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
3429       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
3430   HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
3431       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 1));
3432   fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
3433       shape, HloOpcode::kAdd, fusion1_element0, fusion1_element1));
3434   HloComputation* fusion1_computation =
3435       module->AddEmbeddedComputation(fusion1_builder.Build());
3436 
3437   HloInstruction* p0 =
3438       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3439   HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
3440       tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3441       fusion0_computation));
3442   HloInstruction* fusion1 = builder.AddInstruction(
3443       HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
3444                                    {fusion0}, fusion1_computation));
3445 
3446   HloComputation* computation = module->AddEntryComputation(builder.Build());
3447 
3448   HloSchedule schedule(module.get());
3449   schedule.set_sequence(computation, {p0, fusion0, fusion1});
3450   TF_CHECK_OK(module->set_schedule(schedule));
3451 
3452   AssignMemorySpace(module.get());
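  // fusion1 immediately follows fusion0 in the schedule, so fusion0's output is
  // expected to feed fusion1 directly, without any copies in between.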
3453   EXPECT_THAT(fusion1, op::Fusion(op::Fusion()));
3454 }
3455 
3456 TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
3457   HloComputation::Builder builder(TestName());
3458   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3459   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3460   HloInstruction* p = builder.AddInstruction(
3461       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3462   HloInstruction* p0 = builder.AddInstruction(
3463       HloInstruction::CreateGetTupleElement(shape, p, 0));
3464   HloInstruction* negate0 = builder.AddInstruction(
3465       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3466   HloInstruction* negate1 = builder.AddInstruction(
3467       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3468   HloInstruction* negate2 = builder.AddInstruction(
3469       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3470   HloInstruction* negate3 = builder.AddInstruction(
3471       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3472   HloInstruction* negate4 = builder.AddInstruction(
3473       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3474   HloInstruction* negate5 = builder.AddInstruction(
3475       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3476   HloInstruction* negate6 = builder.AddInstruction(
3477       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3478   HloInstruction* p1 = builder.AddInstruction(
3479       HloInstruction::CreateGetTupleElement(shape, p, 1));
3480   HloInstruction* add = builder.AddInstruction(
3481       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
3482   HloInstruction* negate7 = builder.AddInstruction(
3483       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add));
3484   HloInstruction* tuple =
3485       builder.AddInstruction(HloInstruction::CreateTuple({p0, add}));
3486 
3487   auto module = CreateNewVerifiedModule();
3488   HloComputation* computation = module->AddEntryComputation(builder.Build());
3489 
3490   HloSchedule schedule(module.get());
3491   schedule.set_sequence(
3492       computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
3493                     negate6, p1, add, negate7, tuple});
3494   TF_CHECK_OK(module->set_schedule(schedule));
3495 
3496   // Make input {0} alias with output {0} and input {1} alias with output {1}.
3497   TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
3498   TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
3499 
3500   AssignMemorySpace(module.get());
3501 
3502   // Make sure the input is in the default memory space.
3503   EXPECT_EQ(p->shape().tuple_shapes(0).layout().memory_space(),
3504             kDefaultMemorySpace);
3505   EXPECT_EQ(p->shape().tuple_shapes(1).layout().memory_space(),
3506             kDefaultMemorySpace);
3507 }
3508 
3509 TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
3510   // This is mostly a smoke test since it's difficult and brittle to work out
3511   // the cost of the HLO instructions.
3512   HloComputation::Builder builder(TestName());
3513   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3514   HloInstruction* p0 =
3515       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3516   HloInstruction* p1 =
3517       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
3518   HloInstruction* negate0 = builder.AddInstruction(
3519       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3520   HloInstruction* negate1 = builder.AddInstruction(
3521       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3522   HloInstruction* negate2 = builder.AddInstruction(
3523       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3524   HloInstruction* negate3 = builder.AddInstruction(
3525       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3526   HloInstruction* negate4 = builder.AddInstruction(
3527       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3528   HloInstruction* negate5 = builder.AddInstruction(
3529       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3530   HloInstruction* negate6 = builder.AddInstruction(
3531       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3532   HloInstruction* add = builder.AddInstruction(
3533       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
3534 
3535   auto module = CreateNewVerifiedModule();
3536   HloComputation* computation = module->AddEntryComputation(builder.Build());
3537 
3538   HloSchedule schedule(module.get());
3539   schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
3540                                       negate3, negate4, negate5, negate6, add});
3541   TF_CHECK_OK(module->set_schedule(schedule));
3542 
3543   AssignMemorySpaceUsingCostAnalysis(module.get());
3544   // Parameters are in the default memory space.
3545   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
3546   EXPECT_THAT(p1, op::ShapeWithLayout(shape));
3547   // Negate instructions are in the alternate memory space (1).
3548   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
3549       F32, {2, 3},
3550       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
3551       /*element_size_in_bits=*/0, kAlternateMemorySpace);
3552   EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
3553   EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
3554   EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
3555   EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
3556   EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
3557   EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
3558   EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
3559 }
3560 
3561 TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
3562   // This test is carefully crafted to force only negates to be allocated to the
3563   // alternate memory. The graph consists of interleaving negate and tanh
3564   // operations:
3565   //
3566   //        +------+      +-------+      +-----
3567   //       /        \    /         \    /
3568   //  negate  tanh  negate  tanh   negate  tanh
3569   //             \          /  \           /
3570   //              +--------+    +---------+
3571   //
3572   // The alternate memory is sized to fit only two f32[4,3] tensors at a time.
3573   // Also, transcendentals are made to have much lower throughput than FLOPs, so
3574   // MemoryBoundednessBufferIntervalCompare should prioritize the negates, which
3575   // are more memory bound.
3576   HloComputation::Builder builder(TestName());
3577   Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
3578   HloInstruction* p0 =
3579       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3580   HloInstruction* p1 =
3581       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
3582   HloInstruction* tanh0 = builder.AddInstruction(
3583       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3584   HloInstruction* negate0 = builder.AddInstruction(
3585       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
3586   HloInstruction* tanh1 = builder.AddInstruction(
3587       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh0));
3588   HloInstruction* negate1 = builder.AddInstruction(
3589       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3590   HloInstruction* tanh2 = builder.AddInstruction(
3591       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
3592   HloInstruction* negate2 = builder.AddInstruction(
3593       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3594   HloInstruction* tanh3 = builder.AddInstruction(
3595       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
3596   HloInstruction* negate3 = builder.AddInstruction(
3597       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3598   HloInstruction* tanh4 = builder.AddInstruction(
3599       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh3));
3600   HloInstruction* negate4 = builder.AddInstruction(
3601       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3602   HloInstruction* tuple =
3603       builder.AddInstruction(HloInstruction::CreateTuple({tanh4, negate4}));
3604 
3605   auto module = CreateNewVerifiedModule();
3606   HloComputation* computation = module->AddEntryComputation(builder.Build());
3607 
3608   HloSchedule schedule(module.get());
3609   schedule.set_sequence(computation,
3610                         {p0, p1, tanh0, negate0, tanh1, negate1, tanh2, negate2,
3611                          tanh3, negate3, tanh4, negate4, tuple});
3612   TF_CHECK_OK(module->set_schedule(schedule));
3613 
3614   AssignMemorySpaceUsingCostAnalysis(module.get());
3615   // Parameters are in the default memory space.
3616   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
3617   EXPECT_THAT(p1, op::ShapeWithLayout(shape));
3618   Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3619       F32, {4, 3},
3620       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
3621       /*element_size_in_bits=*/0, kDefaultMemorySpace);
3622   // Expect only the negates to be in the alternate memory space. Not all of them
3623   // might fit, but make sure at least one does.
3624   std::vector<HloInstruction*> negate_instructions = {negate0, negate1, negate2,
3625                                                       negate3, negate4};
3626   int64_t num_negates_in_alternate_mem = absl::c_count_if(
3627       negate_instructions, [&](const HloInstruction* instruction) {
3628         return instruction->shape().layout().memory_space() ==
3629                kAlternateMemorySpace;
3630       });
3631   EXPECT_GE(num_negates_in_alternate_mem, 1);
3632   EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
3633   EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
3634   EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
3635   EXPECT_THAT(tanh3, op::ShapeWithLayout(shape_in_default_mem));
3636   EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
3637 }
3638 
3639 TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
3640   Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
3641   Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
3642   Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1});
3643   auto module = CreateNewVerifiedModule("SimpleWhile");
3644   HloSchedule schedule(module.get());
3645 
3646   // A simple compare-to-limit (x < 4) computation for a While.
3647   //
3648   // condition:
3649   //   const4[s32] -----------------------------------\
3650   //                                                   \
3651   //   param[(s32,f32[1])] --- get-tuple-element[0] --- less-than
3652   //
3653   HloComputation* cond_computation;
3654   {
3655     auto builder = HloComputation::Builder("WhileCond");
3656     auto const4 = builder.AddInstruction(
3657         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
3658     auto param = builder.AddInstruction(
3659         HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
3660     auto index = builder.AddInstruction(
3661         HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
3662     auto compare = builder.AddInstruction(
3663         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
3664                                       const4, ComparisonDirection::kLt));
3665     cond_computation = module->AddEmbeddedComputation(builder.Build());
3666     schedule.set_sequence(cond_computation, {const4, param, index, compare});
3667   }
3668 
3669   // Builds a simple body computation for a While.
3670   //
3671   // body:
3672   //   constv[f32[1]] --------------------------------------\
3673   //                                                         \
3674   //                           /--- get-tuple-elementv[1] --- addv ---\
3675   //   param[(s32,f32[1])] ---|                                    tuple
3676   //                           \--- get-tuple-elementc[0] --- addc ---/
3677   //                                                         /
3678   //   const1[s32] -----------------------------------------/
3679   //
3680   HloComputation* body_computation;
3681   {
3682     auto builder = HloComputation::Builder("WhileBody");
3683     auto const1 = builder.AddInstruction(
3684         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
3685     auto constv = builder.AddInstruction(
3686         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.1f})));
3687     auto param = builder.AddInstruction(
3688         HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
3689     auto indexc = builder.AddInstruction(
3690         HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
3691     auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
3692         indexc->shape(), HloOpcode::kAdd, indexc, const1));
3693     auto indexv = builder.AddInstruction(
3694         HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
3695     auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
3696         constv->shape(), HloOpcode::kAdd, indexv, constv));
3697     auto tuple =
3698         builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
3699     body_computation = module->AddEmbeddedComputation(builder.Build());
3700     schedule.set_sequence(body_computation, {const1, constv, param, indexc,
3701                                              addc, indexv, addv, tuple});
3702   }
3703 
3704   // This tests a simple while loop where the parameters are aliased with the
3705   // output buffers.
3706   auto builder = HloComputation::Builder("SimpleWhile");
3707   auto param = builder.AddInstruction(
3708       HloInstruction::CreateParameter(0, t_s32_f32v1, "param"));
3709   auto gte0 = builder.AddInstruction(
3710       HloInstruction::CreateGetTupleElement(s32, param, 0));
3711   auto gte1 = builder.AddInstruction(
3712       HloInstruction::CreateGetTupleElement(f32v1, param, 1));
3713   auto tuple =
3714       builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
3715   auto while0 = builder.AddInstruction(HloInstruction::CreateWhile(
3716       t_s32_f32v1, cond_computation, body_computation, tuple));
3717 
3718   HloComputation* computation = module->AddEntryComputation(builder.Build());
3719   schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0});
3720   TF_CHECK_OK(module->set_schedule(schedule));
3721 
3722   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
3723                     /*max_prefetch_interval=*/50);
3724 
3725   // Ensure that all parameters and the while op are placed in the default memory.
3726   Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3727       F32, {4, 6},
3728       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
3729       /*element_size_in_bits=*/0, kDefaultMemorySpace);
3730   Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3731       xla::S32, {},
3732       /*minor_to_major=*/{}, /*dim_level_types=*/{}, /*tiles=*/{},
3733       /*element_size_in_bits=*/0, kDefaultMemorySpace);
3734   Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3735       F32, {1},
3736       /*minor_to_major=*/{0}, /*dim_level_types=*/{}, /*tiles=*/{},
3737       /*element_size_in_bits=*/0, kDefaultMemorySpace);
3738   Shape t_s32_f32v1_in_default_mem =
3739       ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem});
3740   EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
3741   EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
3742 }
3743 
3744 TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) {
3745   // This test reproduces an eviction scheduling bug where evictions to default
3746   // memory can happen later than intended, causing memory corruption. It is a
3747   // variant of MemoryBoundednessBufferIntervalCompare but uses f32[4,3]
3748   // tensors instead, so at most two tensors should fit in the alternate memory
3749   // space at a given time. We have a number of redundant operations
3750   // (tanh_redundant ops) that do not have users. The bug was due to
3751   // SimplifyGraph removing dead instructions, and removing them from the
3752   // schedule. However, the CopyStart/CopyDone insertion relies on the schedule
3753   // indexes, so they could be inserted too late.
3754   HloComputation::Builder builder(TestName());
3755   Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
3756   HloInstruction* p0 =
3757       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3758   HloInstruction* tanh0 = builder.AddInstruction(
3759       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3760   HloInstruction* tanh_redundant0 = builder.AddInstruction(
3761       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3762   HloInstruction* tanh_redundant1 = builder.AddInstruction(
3763       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3764   HloInstruction* tanh_redundant2 = builder.AddInstruction(
3765       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3766   HloInstruction* tanh_redundant3 = builder.AddInstruction(
3767       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3768   HloInstruction* tanh_redundant4 = builder.AddInstruction(
3769       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3770   HloInstruction* tanh_redundant5 = builder.AddInstruction(
3771       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3772   HloInstruction* tanh_redundant6 = builder.AddInstruction(
3773       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3774   HloInstruction* negate0 = builder.AddInstruction(
3775       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0));
3776   HloInstruction* tanh1 = builder.AddInstruction(
3777       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0));
3778   HloInstruction* negate1 = builder.AddInstruction(
3779       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3780   HloInstruction* tanh2 = builder.AddInstruction(
3781       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
3782   HloInstruction* negate2 = builder.AddInstruction(
3783       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3784   HloInstruction* tanh3 = builder.AddInstruction(
3785       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
3786   HloInstruction* negate3 = builder.AddInstruction(
3787       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3788   HloInstruction* tuple = builder.AddInstruction(
3789       HloInstruction::CreateTuple({tanh3, negate3, tanh0}));
3790 
3791   auto module = CreateNewVerifiedModule();
3792   HloComputation* computation = module->AddEntryComputation(builder.Build());
3793 
3794   HloSchedule schedule(module.get());
3795   schedule.set_sequence(
3796       computation,
3797       {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2,
3798        tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6,
3799        negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple});
3800   TF_CHECK_OK(module->set_schedule(schedule));
3801 
3802   AssignMemorySpaceUsingCostAnalysis(module.get());
3803 
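  // Re-run alias analysis and live-range analysis on the annotated module so we
  // can count how many alternate-memory buffers are live at each point in the
  // schedule.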
3804   TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
3805                           HloAliasAnalysis::Run(module.get()));
3806   TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range,
3807                           HloLiveRange::Run(module->schedule(), *alias_analysis,
3808                                             module->entry_computation()));
3809 
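  // One counter per position in the flattened schedule (plus one).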
3810   std::vector<int> num_live_buffers_in_alternate_mem(
3811       hlo_live_range->flattened_instruction_sequence().size() + 1, 0);
3812 
3813   // Go through each value and, for those that are allocated in the alternate
3814   // memory space, increment num_live_buffers_in_alternate_mem for every time
3815   // step (inclusive of the end) at which they are live.
3816   for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
3817     const Shape& shape = value->shape();
3818     if (!shape.has_layout() ||
3819         shape.layout().memory_space() == kDefaultMemorySpace) {
3820       continue;
3821     }
3822 
3823     HloLiveRange::TimeBound time_bound =
3824         hlo_live_range->buffer_live_ranges().at(value);
3825     for (int i = time_bound.start; i <= time_bound.end; ++i) {
3826       ++num_live_buffers_in_alternate_mem[i];
3827     }
3828   }
3829 
3830   // The test memory can hold at most two f32[4,3] buffers at a time. If more are
3831   // live simultaneously, it means we have memory corruption.
3832   for (int i = 0; i < num_live_buffers_in_alternate_mem.size(); ++i) {
3833     EXPECT_LE(num_live_buffers_in_alternate_mem[i], 2);
3834   }
3835 }
3836 
3837 TEST_P(MemorySpaceAssignmentTest,
3838        InputOutputsInAlternateMemShouldntBeAssigned) {
3839   // When inputs/outputs are marked to be in the alternate memory (e.g.
3840   // go/tpu-fast-mem-inference), do not allocate those and assume they will live
3841   // in the alternate memory for the entire computation. The BufferAssignment
3842   // pass, which is run after this, will allocate those buffers.
3843   HloComputation::Builder builder(TestName());
3844   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3845   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
3846       F32, {2, 3},
3847       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
3848       /*element_size_in_bits=*/0, kAlternateMemorySpace);
3849   // p0 is in the default memory space.
3850   HloInstruction* p0 =
3851       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3852   // p1 is in the alternate memory space.
3853   HloInstruction* p1 = builder.AddInstruction(
3854       HloInstruction::CreateParameter(1, shape_in_alternate_mem, "p1"));
3855   HloInstruction* negate0 = builder.AddInstruction(
3856       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3857   HloInstruction* negate1 = builder.AddInstruction(
3858       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3859   HloInstruction* negate2 = builder.AddInstruction(
3860       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3861   HloInstruction* negate3 = builder.AddInstruction(
3862       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3863   HloInstruction* negate4 = builder.AddInstruction(
3864       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3865   HloInstruction* negate5 = builder.AddInstruction(
3866       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3867   HloInstruction* negate6 = builder.AddInstruction(
3868       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3869   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
3870       shape_in_alternate_mem, HloOpcode::kAdd, negate6, p1));
3871   // Index {0} of the root instruction is in the alternate memory space, index
3872   // {1} is in the default memory space.
3873   HloInstruction* tuple =
3874       builder.AddInstruction(HloInstruction::CreateTuple({add, negate5}));
3875 
3876   auto module = CreateNewVerifiedModule();
3877   HloComputation* computation = module->AddEntryComputation(builder.Build());
3878 
3879   HloSchedule schedule(module.get());
3880   schedule.set_sequence(computation,
3881                         {p0, p1, negate0, negate1, negate2, negate3, negate4,
3882                          negate5, negate6, add, tuple});
3883   TF_CHECK_OK(module->set_schedule(schedule));
3884 
3885   Options options;
3886   options.max_size_in_bytes = 128;
3887   options.alignment_in_bytes = 8;
3888   options.verify = true;
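  // Allow every value to be considered for placement in the alternate memory.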
3889   options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
3890     return true;
3891   };
3892   std::unique_ptr<PresetAssignments> preset_assignments = AssignMemorySpace(
3893       module.get(),
3894       /*max_outstanding_async_copies=*/-1,
3895       /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2, options);
3896 
3897   // Ensure that p1 is in the alternate memory and add, which has p1 as an
3898   // operand, has a direct dependency on p1 (no CopyStart/CopyDone).
3899   EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem));
3900   EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1)));
3901   // Make sure add is still in the alternate memory space.
3902   EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
3903 
3904   // Check the preset assignments and ensure the inputs/outputs in the alternate
3905   // memory space aren't in the preset assignments. Inputs/outputs in the
3906   // alternate memory space are left to BufferAssignment to be allocated.
3907   for (const auto& position_and_chunk : preset_assignments->chunks()) {
3908     const HloPosition& position = position_and_chunk.first;
3909     EXPECT_NE(position.instruction, p1);
3910     EXPECT_NE(position.instruction, add);
3911   }
3912 }
3913 
3914 TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
3915   // Tests a memory corruption bug where the allocated chunk overlaps with a
3916   // pending chunk. To test this, we provide a new buffer interval compare where
3917   // we prioritize the allocation of sine, cosine, and tanh to create the
3918   // situation:
3919   //
3920   //    Max memory
3921   //  -------------------------------------------
3922   //      +------------+
3923   //      |     b      |
3924   //      +------------+
3925   //  +-------+
3926   //  |       |
3927   //  |       |
3928   //  |   a   |
3929   //  |       |                 +------------+
3930   //  |       |                 |     n      |
3931   //  +-------+                 +------------+
3932   //  -------------------------------------------
3933   //    Min memory          time ->
3934   //
3935   //
3936   // Then allocating for buffer d, we have these two prefetch buffers
3937   // overlapping:
3938   //
3939   //    Max memory
3940   //  -------------------------------------------
3941   //      +------------+ +----------+
3942   //      |     b      | | prefetch |
3943   //      +------------+ | for o    |
3944   //  +-------+     +---------+     |
3945   //  |       |     |    |    |     |
3946   //  |       |     |    |    |     |
3947   //  |   a   |     |    +----|-----+
3948   //  |       |     | prefetch| +------------+
3949   //  |       |     | for m   | |     n      |
3950   //  +-------+     +---------+ +------------+
3951   //  -------------------------------------------
3952   //    Min memory          time ->
3953   //
3954   absl::string_view hlo_string = R"(
3955   HloModule bug, is_scheduled=true
3956 
3957   ENTRY %Entry {
3958     %param0 = f32[8,3] parameter(0)
3959     %param1 = f32[2,4] parameter(1)
3960     %a = f32[8,3] sine(%param0)
3961     %b = f32[2,4] cosine(%param1)
3962     %d = f32[8,3] tanh(%a)
3963     %c = f32[8,3] negate(%a)
3964     %e = f32[2,4] negate(%b)
3965     %f = f32[2,4] negate(%e)
3966     %g = f32[2,4] negate(%f)
3967     %h = f32[2,4] negate(%g)
3968     %i = f32[2,4] negate(%h)
3969     %j = f32[2,4] negate(%i)
3970     %k = f32[2,4] negate(%j)
3971     %l = f32[2,4] negate(%k)
3972     %m = f32[8,3] negate(%d)
3973     %n = f32[2,4] sine(%l)
3974     %o = f32[8,3] negate(%d)
3975     %p = f32[2,4] negate(%n)
3976     %q = f32[8,3] negate(%m)
3977     ROOT %tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(%p, %q, %o)
3978   }
3979   )";
3980 
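  // Order buffers by the opcode of their defining instruction: sine first, then
  // cosine, then tanh, then everything else.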
3981   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
3982       [](const MemorySpaceAssignment::BufferInterval& a,
3983          const MemorySpaceAssignment::BufferInterval& b) {
3984         auto get_opcode_priority = [](const HloOpcode& opcode) {
3985           switch (opcode) {
3986             case HloOpcode::kSin:
3987               return 0;
3988             case HloOpcode::kCos:
3989               return 1;
3990             case HloOpcode::kTanh:
3991               return 2;
3992             default:
3993               return 3;
3994           }
3995         };
3996 
3997         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
3998                get_opcode_priority(b.buffer->defining_instruction()->opcode());
3999       };
4000   TF_ASSERT_OK_AND_ASSIGN(auto module,
4001                           ParseAndReturnVerifiedModule(hlo_string));
4002 
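  // Use an instruction-count based prefetch interval picker; the two arguments
  // are assumed to be the minimum and maximum allowed prefetch intervals, in
  // number of instructions.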
4003   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
4004   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4005                     buffer_interval_compare, &prefetch_interval_picker);
4006 }
4007 
4008 TEST_P(MemorySpaceAssignmentTest, WhileAliasedArgumentRequiredAssignmentBug) {
4009   // Tests an overly pessimistic assertion when the same HloValue is passed
4010   // multiple times to a while HLO. We already handle the case where the two
4011   // arguments must alias and get the same allocation in AllocateSegment, so the
4012   // assertion isn't necessary.
4013   absl::string_view hlo_string = R"(
4014   HloModule bug, is_scheduled=true
4015 
4016   while_condition {
4017     param1 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
4018     ROOT cond = pred[] constant(true)
4019   }
4020 
4021   while_body {
4022     param2 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
4023     gte2 = f32[2,4] get-tuple-element(param2), index=0
4024     gte3 = f32[2,4] get-tuple-element(param2), index=1
4025     gte4 = f32[2,4] get-tuple-element(param2), index=2
4026     add = f32[2,4] add(gte2, gte3)
4027     ROOT tuple2 = (f32[2,4], f32[2,4], f32[2,4]) tuple(add, gte3, gte4)
4028   }
4029 
4030   ENTRY Entry {
4031     param0 = f32[2,4] parameter(0)
4032     a = f32[2,4] negate(param0)
4033     b = f32[2,4] negate(param0)
4034     tuple = (f32[2,4], f32[2,4], f32[2,4]) tuple(a, b, b)
4035     while = (f32[2,4], f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
4036     gte1 = f32[2,4] get-tuple-element(while), index=0
4037     gte2 = f32[2,4] get-tuple-element(while), index=1
4038     ROOT root = f32[2,4] add(gte1, gte2)
4039   }
4040   )";
4041   TF_ASSERT_OK_AND_ASSIGN(auto module,
4042                           ParseAndReturnVerifiedModule(hlo_string));
4043   AssignMemorySpace(module.get());
4044 }
4045 
4046 TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) {
4047   // When we have a disallowed use (in this case tanh), we aren't allowed to
4048   // allocate this use in alternate memory. However, if we have another use
4049   // after this on the same buffer (o), this use may refer to "a" instead of the
4050   // evicted value, which is illegal because "a" will be allocated in the
4051   // alternate memory space.
4052   absl::string_view hlo_string = R"(
4053   HloModule bug, is_scheduled=true
4054 
4055   ENTRY Entry {
4056     param0 = f32[8,3] parameter(0)
4057     param1 = f32[2,4] parameter(1)
4058     a = f32[8,3] cosine(param0)
4059     b = f32[2,4] negate(param1)
4060     d = f32[8,3] negate(a)
4061     c = f32[2,4] negate(b)
4062     e = f32[2,4] negate(c)
4063     f = f32[8,3] tanh(a)
4064     g = f32[2,4] negate(e)
4065     h = f32[2,4] negate(g)
4066     i = f32[2,4] negate(h)
4067     j = f32[2,4] negate(i)
4068     k = f32[2,4] negate(j)
4069     l = f32[2,4] negate(k)
4070     m = f32[2,4] negate(l)
4071     n = f32[2,4] sine(m)
4072     o = f32[8,3] negate(a)
4073     p = f32[2,4] negate(n)
4074     q = f32[8,3] add(o, f)
4075     r = f32[8,3] add(q, d)
4076     ROOT tuple = (f32[2,4], f32[8,3]) tuple(p, r)
4077   }
4078   )";
4079 
4080   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
4081       [](const MemorySpaceAssignment::BufferInterval& a,
4082          const MemorySpaceAssignment::BufferInterval& b) {
4083         auto get_opcode_priority = [](const HloOpcode& opcode) {
4084           switch (opcode) {
4085             case HloOpcode::kSin:
4086               return 0;
4087             case HloOpcode::kCos:
4088               return 1;
4089             case HloOpcode::kTanh:
4090               return 2;
4091             default:
4092               return 3;
4093           }
4094         };
4095 
4096         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
4097                get_opcode_priority(b.buffer->defining_instruction()->opcode());
4098       };
4099   TF_ASSERT_OK_AND_ASSIGN(auto module,
4100                           ParseAndReturnVerifiedModule(hlo_string));
4101 
4102   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
4103   Options options;
4104   options.max_size_in_bytes = 128;
4105   options.alignment_in_bytes = 8;
4106   options.verify = true;
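  // Disallow tanh uses in the alternate memory, as described above.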
4107   options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
4108     return use.instruction->opcode() != HloOpcode::kTanh;
4109   };
4110   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4111                     buffer_interval_compare, &prefetch_interval_picker,
4112                     options);
4113 }
4114 
4115 TEST_P(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) {
4116   // Test for situations where we disallow a use (tanh in this case) in the
4117   // alternate memory space and there is a subsequent use that also requires the
4118   // buffer to be in the default memory space. In this case, the allocation in
4119   // the default memory space might not be the very last one, so we need to
4120   // search the allocation sequence and find the one in the default memory
4121   // space.
4122   absl::string_view hlo_string = R"(
4123   HloModule module, is_scheduled=true
4124 
4125   while_cond {
4126     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4127     ROOT gte = pred[] get-tuple-element(p0), index=3
4128   }
4129 
4130   while_body {
4131     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4132     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4133     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4134     gte2 = f32[3]{0} get-tuple-element(p0), index=2
4135     gte3 = pred[] get-tuple-element(p0), index=3
4136     add = f32[3]{0} add(gte0, gte0)
4137     negate0 = f32[3]{0} negate(add)
4138     negate1 = f32[3]{0} negate(negate0)
4139     negate2 = f32[3]{0} negate(negate1)
4140     negate3 = f32[3]{0} negate(negate2)
4141     negate4 = f32[3]{0} negate(negate3)
4142     negate5 = f32[3]{0} negate(negate4)
4143     negate6 = f32[3]{0} negate(negate5)
4144     negate7 = f32[3]{0} negate(negate6)
4145     negate8 = f32[3]{0} negate(negate7)
4146     negate9 = f32[3]{0} negate(negate8)
4147     negate10 = f32[3]{0} negate(negate9)
4148     negate11 = f32[3]{0} negate(negate10)
4149     negate12 = f32[3]{0} negate(negate11)
4150     negate13 = f32[3]{0} negate(negate12)
4151     negate14 = f32[3]{0} negate(negate13)
4152     negate15 = f32[3]{0} negate(gte2)
4153     tanh = f32[3]{0} tanh(gte2)
4154     ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(negate14, tanh, gte2, gte3)
4155   }
4156 
4157   ENTRY entry {
4158     p0 = f32[3]{0} parameter(0)
4159     p1 = pred[] parameter(1)
4160     copy0 = f32[3]{0} copy(p0)
4161     copy1 = f32[3]{0} copy(p0)
4162     tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
4163     while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
4164     ROOT gte = f32[3]{0} get-tuple-element(while), index=2
4165   }
4166   )";
4167 
4168   TF_ASSERT_OK_AND_ASSIGN(auto module,
4169                           ParseAndReturnVerifiedModule(hlo_string));
4170   Options options;
4171   options.max_size_in_bytes = 128;
4172   options.alignment_in_bytes = 8;
4173   options.verify = true;
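  // As in DisallowedUseBug, tanh uses must be served from the default memory.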
4174   options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
4175     return use.instruction->opcode() != HloOpcode::kTanh;
4176   };
4177   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4178                     /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
4179                     options);
4180 }
4181 
4182 TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionInWhile) {
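  // Checks that a buffer which already has a copy in the default memory space
  // is not redundantly evicted inside the while body; the expectations below
  // (guarded by GetParam()) verify that it stays in alternate memory and is
  // re-prefetched at the end of the body instead.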
4183   absl::string_view hlo_string = R"(
4184   HloModule module, is_scheduled=true
4185 
4186   while_cond {
4187     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4188     ROOT gte = pred[] get-tuple-element(p0), index=2
4189   }
4190 
4191   while_body {
4192     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4193     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4194     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4195     tanh = f32[3]{0} tanh(gte1)
4196     gte2 = pred[] get-tuple-element(p0), index=2
4197     negate0 = f32[3]{0} negate(gte0)
4198     negate1 = f32[3]{0} negate(negate0)
4199     negate2 = f32[3]{0} negate(negate1)
4200     negate3 = f32[3]{0} negate(negate2)
4201     negate4 = f32[3]{0} negate(negate3)
4202     negate5 = f32[3]{0} negate(negate4)
4203     negate6 = f32[3]{0} negate(negate5)
4204     negate7 = f32[3]{0} negate(negate6)
4205     negate8 = f32[3]{0} negate(negate7)
4206     negate9 = f32[3]{0} negate(negate8)
4207     negate10 = f32[3]{0} negate(negate9)
4208     negate11 = f32[3]{0} negate(negate10)
4209     negate12 = f32[3]{0} negate(negate11)
4210     negate13 = f32[3]{0} negate(negate12)
4211     negate14 = f32[3]{0} negate(negate13)
4212     add = f32[3]{0} add(negate14, tanh)
4213     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
4214   }
4215 
4216   ENTRY entry {
4217     p0 = f32[3]{0} parameter(0)
4218     p1 = pred[] parameter(1)
4219     copy = f32[3]{0} copy(p0)
4220     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
4221     while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
4222     gte = f32[3]{0} get-tuple-element(while), index=1
4223     ROOT negate = f32[3]{0} negate(gte)
4224   }
4225   )";
4226 
4227   TF_ASSERT_OK_AND_ASSIGN(auto module,
4228                           ParseAndReturnVerifiedModule(hlo_string));
4229   AssignMemorySpace(module.get());
4230 
4231   if (GetParam()) {
4232     // Expect that while{1} is allocated to alternate memory space. Also expect
4233     // that this buffer is prefetched at the end of the while loop body but is
4234     // never evicted (since it has a copy in the default memory space).
4235     const HloInstruction* while_instr = FindInstruction(module.get(), "while");
4236     EXPECT_EQ(while_instr->shape().tuple_shapes(1).layout().memory_space(),
4237               kAlternateMemorySpace);
4238     const HloInstruction* gte1 = FindInstruction(module.get(), "gte1");
4239     EXPECT_EQ(gte1->user_count(), 1);
4240     EXPECT_EQ(gte1->users()[0]->opcode(), HloOpcode::kTanh);
4241     const HloInstruction* while_root =
4242         while_instr->while_body()->root_instruction();
4243     EXPECT_THAT(while_root->operand(1),
4244                 op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
4245                               op::GetTupleElement(op::Parameter(0))));
4246   }
4247 }
4248 
4249 TEST_P(MemorySpaceAssignmentTest,
4250        RedundantEvictionEliminationShouldntAddRedundantParam) {
4251   // Check that if there wasn't an eviction in the while loop, we don't add the
4252   // buffer in default memory as an additional parameter to the while loop.
4253   absl::string_view hlo_string = R"(
4254   HloModule module, is_scheduled=true
4255 
4256   while_cond {
4257     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4258     ROOT gte = pred[] get-tuple-element(p0), index=2
4259   }
4260 
4261   while_body {
4262     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4263     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4264     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4265     tanh = f32[3]{0} tanh(gte1)
4266     gte2 = pred[] get-tuple-element(p0), index=2
4267     negate0 = f32[3]{0} negate(gte0)
4268     negate1 = f32[3]{0} negate(negate0)
4269     add = f32[3]{0} add(negate1, tanh)
4270     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
4271   }
4272 
4273   ENTRY entry {
4274     p0 = f32[3]{0} parameter(0)
4275     p1 = pred[] parameter(1)
4276     copy = f32[3]{0} copy(p0)
4277     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
4278     while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
4279     gte = f32[3]{0} get-tuple-element(while), index=1
4280     ROOT negate = f32[3]{0} negate(gte)
4281   }
4282   )";
4283 
4284   TF_ASSERT_OK_AND_ASSIGN(auto module,
4285                           ParseAndReturnVerifiedModule(hlo_string));
4286   AssignMemorySpace(module.get());
4287 
4288   // Expect that while tuple shape contains 3 elements like the original.
4289   const HloInstruction* while_instr = FindInstruction(module.get(), "while");
4290   EXPECT_EQ(while_instr->shape().tuple_shapes_size(), 3);
4291 }
4292 
4293 TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionInNestedWhile) {
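  // Same scenario as AvoidRedundantEvictionInWhile, but with the loop nested
  // inside an outer while; both while1{1} and while2{1} are expected to be
  // placed in alternate memory without a redundant eviction in the inner body.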
4294   absl::string_view hlo_string = R"(
4295   HloModule module, is_scheduled=true
4296 
4297   while_cond2 {
4298     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4299     ROOT gte = pred[] get-tuple-element(p0), index=2
4300   }
4301 
4302   while_body2 {
4303     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4304     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4305     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4306     tanh = f32[3]{0} tanh(gte1)
4307     gte2 = pred[] get-tuple-element(p0), index=2
4308     negate0 = f32[3]{0} negate(gte0)
4309     negate1 = f32[3]{0} negate(negate0)
4310     negate2 = f32[3]{0} negate(negate1)
4311     negate3 = f32[3]{0} negate(negate2)
4312     negate4 = f32[3]{0} negate(negate3)
4313     negate5 = f32[3]{0} negate(negate4)
4314     negate6 = f32[3]{0} negate(negate5)
4315     negate7 = f32[3]{0} negate(negate6)
4316     negate8 = f32[3]{0} negate(negate7)
4317     negate9 = f32[3]{0} negate(negate8)
4318     negate10 = f32[3]{0} negate(negate9)
4319     negate11 = f32[3]{0} negate(negate10)
4320     negate12 = f32[3]{0} negate(negate11)
4321     negate13 = f32[3]{0} negate(negate12)
4322     negate14 = f32[3]{0} negate(negate13)
4323     add = f32[3]{0} add(negate14, tanh)
4324     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
4325   }
4326 
4327   while_cond1 {
4328     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4329     ROOT gte = pred[] get-tuple-element(p0), index=2
4330   }
4331 
4332   while_body1 {
4333     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4334     ROOT while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(p0), condition=while_cond2, body=while_body2
4335   }
4336 
4337   ENTRY entry {
4338     p0 = f32[3]{0} parameter(0)
4339     p1 = pred[] parameter(1)
4340     copy = f32[3]{0} copy(p0)
4341     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
4342     while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond1, body=while_body1
4343     gte = f32[3]{0} get-tuple-element(while1), index=1
4344     ROOT negate = f32[3]{0} negate(gte)
4345   }
4346   )";
4347 
4348   TF_ASSERT_OK_AND_ASSIGN(auto module,
4349                           ParseAndReturnVerifiedModule(hlo_string));
4350   AssignMemorySpace(module.get());
4351 
4352   if (GetParam()) {
4353     // Expect that while1{1} and while2{1} are allocated to alternate memory
4354     // space. Also expect that this buffer is prefetched at the end of the while
4355     // loop body but is never evicted (since it has a copy in the default memory
4356     // space).
4357     const HloInstruction* while1_instr =
4358         FindInstruction(module.get(), "while1");
4359     EXPECT_EQ(while1_instr->shape().tuple_shapes(1).layout().memory_space(),
4360               kAlternateMemorySpace);
4361     const HloInstruction* while2_instr =
4362         FindInstruction(module.get(), "while2");
4363     EXPECT_EQ(while2_instr->shape().tuple_shapes(1).layout().memory_space(),
4364               kAlternateMemorySpace);
4365     const HloInstruction* gte1 = FindInstruction(module.get(), "gte1");
4366     EXPECT_EQ(gte1->user_count(), 1);
4367     EXPECT_EQ(gte1->users()[0]->opcode(), HloOpcode::kTanh);
4368     const HloInstruction* while_root =
4369         while2_instr->while_body()->root_instruction();
4370     EXPECT_THAT(while_root->operand(1),
4371                 op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
4372                               op::GetTupleElement(op::Parameter(0))));
4373   }
4374 }
4375 
4376 TEST_P(MemorySpaceAssignmentTest, RedundantEvictionEliminationBug) {
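  // Regression test: tuple element 1 of the while is overwritten inside the
  // body, so redundant eviction elimination must not fire (see the checks on
  // gte1's copy-start user below).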
4377   absl::string_view hlo_string = R"(
4378   HloModule module, is_scheduled=true
4379 
4380   while_cond {
4381     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4382     ROOT gte = pred[] get-tuple-element(p0), index=2
4383   }
4384 
4385   while_body {
4386     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4387     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4388     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4389     tanh = f32[3]{0} tanh(gte1)
4390     gte2 = pred[] get-tuple-element(p0), index=2
4391     negate0 = f32[3]{0} negate(gte0)
4392     negate1 = f32[3]{0} negate(negate0)
4393     negate2 = f32[3]{0} negate(negate1)
4394     negate3 = f32[3]{0} negate(negate2)
4395     negate4 = f32[3]{0} negate(negate3)
4396     negate5 = f32[3]{0} negate(negate4)
4397     negate6 = f32[3]{0} negate(negate5)
4398     negate7 = f32[3]{0} negate(negate6)
4399     negate8 = f32[3]{0} negate(negate7)
4400     negate9 = f32[3]{0} negate(negate8)
4401     negate10 = f32[3]{0} negate(negate9)
4402     negate11 = f32[3]{0} negate(negate10)
4403     negate12 = f32[3]{0} negate(negate11)
4404     negate13 = f32[3]{0} negate(negate12)
4405     negate14 = f32[3]{0} negate(negate13)
4406     add0 = f32[3]{0} add(negate14, tanh)
4407     add1 = f32[3]{0} add(add0, gte1)
4408     negate = f32[3]{0} negate(add1)
4409     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add1, negate, gte2)
4410   }
4411 
4412   ENTRY entry {
4413     p0 = f32[3]{0} parameter(0)
4414     p1 = pred[] parameter(1)
4415     copy = f32[3]{0} copy(p0)
4416     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
4417     while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
4418     gte = f32[3]{0} get-tuple-element(while), index=1
4419     ROOT negate = f32[3]{0} negate(gte)
4420   }
4421   )";
4422 
4423   TF_ASSERT_OK_AND_ASSIGN(auto module,
4424                           ParseAndReturnVerifiedModule(hlo_string));
4425   AssignMemorySpace(module.get());
4426 
4427   // Expect that redundant eviction elimination doesn't kick in because
4428   // while{1} is updated within the body.
4429   const HloInstruction* while_instr = FindInstruction(module.get(), "while");
4430   EXPECT_EQ(while_instr->shape().tuple_shapes_size(), 3);
4431   if (GetParam()) {
4432     EXPECT_EQ(while_instr->shape().tuple_shapes(1).layout().memory_space(),
4433               kAlternateMemorySpace);
4434     const HloInstruction* gte1 = FindInstruction(module.get(), "gte1");
4435     EXPECT_EQ(gte1->user_count(), 2);
4436     EXPECT_NE(absl::c_find_if(gte1->users(),
4437                               [](const HloInstruction* use) {
4438                                 return use->opcode() == HloOpcode::kCopyStart;
4439                               }),
4440               gte1->users().end());
4441   }
4442 }
4443 
4444 TEST_P(MemorySpaceAssignmentTest, RedundantEvictionEliminationInChainedWhile) {
4445   // Check against a bug where a while HLO feeding into another while HLO can
4446   // cause a crash if we perform redundant eviction elimination on the first
4447   // while but not the other (while operand/parameter shapes would mismatch).
4448   absl::string_view hlo_string = R"(
4449   HloModule module, is_scheduled=true
4450 
4451   while_cond1 {
4452     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4453     ROOT gte = pred[] get-tuple-element(p0), index=2
4454   }
4455 
4456   while_body1 {
4457     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4458     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4459     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4460     tanh = f32[3]{0} tanh(gte1)
4461     gte2 = pred[] get-tuple-element(p0), index=2
4462     negate0 = f32[3]{0} negate(gte0)
4463     negate1 = f32[3]{0} negate(negate0)
4464     negate2 = f32[3]{0} negate(negate1)
4465     negate3 = f32[3]{0} negate(negate2)
4466     negate4 = f32[3]{0} negate(negate3)
4467     negate5 = f32[3]{0} negate(negate4)
4468     negate6 = f32[3]{0} negate(negate5)
4469     negate7 = f32[3]{0} negate(negate6)
4470     negate8 = f32[3]{0} negate(negate7)
4471     negate9 = f32[3]{0} negate(negate8)
4472     negate10 = f32[3]{0} negate(negate9)
4473     negate11 = f32[3]{0} negate(negate10)
4474     negate12 = f32[3]{0} negate(negate11)
4475     negate13 = f32[3]{0} negate(negate12)
4476     negate14 = f32[3]{0} negate(negate13)
4477     add = f32[3]{0} add(negate14, tanh)
4478     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
4479   }
4480 
4481   while_cond2 {
4482     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4483     ROOT gte = pred[] get-tuple-element(p0), index=2
4484   }
4485 
4486   while_body2 {
4487     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4488     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4489     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4490     tanh = f32[3]{0} tanh(gte1)
4491     gte2 = pred[] get-tuple-element(p0), index=2
4492     negate0 = f32[3]{0} negate(gte0)
4493     add = f32[3]{0} add(negate0, tanh)
4494     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(add, gte1, gte2)
4495   }
4496 
4497   ENTRY entry {
4498     p0 = f32[3]{0} parameter(0)
4499     p1 = pred[] parameter(1)
4500     copy = f32[3]{0} copy(p0)
4501     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
4502     while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond1, body=while_body1
4503     while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(while1), condition=while_cond2, body=while_body2
4504     gte = f32[3]{0} get-tuple-element(while2), index=1
4505     ROOT negate = f32[3]{0} negate(gte)
4506   }
4507   )";
4508 
4509   TF_ASSERT_OK_AND_ASSIGN(auto module,
4510                           ParseAndReturnVerifiedModule(hlo_string));
4511   AssignMemorySpace(module.get());
4512 
4513   if (GetParam()) {
4514     // Expect that while1 has one more value than while2 in its shape.
4515     EXPECT_EQ(
4516         FindInstruction(module.get(), "while1")->shape().tuple_shapes_size(),
4517         FindInstruction(module.get(), "while2")->shape().tuple_shapes_size() +
4518             1);
4519   }
4520 }
4521 
4522 TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile) {
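  // Tuple element 0 passes through the while body unmodified, so the use after
  // the loop is expected to be prefetched directly from the original copy in
  // default memory rather than from a fresh eviction of the while output.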
4523   absl::string_view hlo_string = R"(
4524   HloModule module, is_scheduled=true
4525 
4526   while_cond {
4527     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4528     ROOT gte = pred[] get-tuple-element(p0), index=2
4529   }
4530 
4531   while_body {
4532     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4533     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4534     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4535     gte2 = pred[] get-tuple-element(p0), index=2
4536     add = f32[3]{0} add(gte0, gte1)
4537     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
4538   }
4539 
4540   ENTRY entry {
4541     p0 = f32[3]{0} parameter(0)
4542     p1 = pred[] parameter(1)
4543     copy = f32[3]{0} copy(p0)
4544     negate0 = f32[3]{0} negate(p0)
4545     negate1 = f32[3]{0} negate(negate0)
4546     negate2 = f32[3]{0} negate(negate1)
4547     negate3 = f32[3]{0} negate(negate2)
4548     negate4 = f32[3]{0} negate(negate3)
4549     negate5 = f32[3]{0} negate(negate4)
4550     negate6 = f32[3]{0} negate(negate5)
4551     negate7 = f32[3]{0} negate(negate6)
4552     negate8 = f32[3]{0} negate(negate7)
4553     negate9 = f32[3]{0} negate(negate8)
4554     negate10 = f32[3]{0} negate(negate9)
4555     negate11 = f32[3]{0} negate(negate10)
4556     negate12 = f32[3]{0} negate(negate11)
4557     negate13 = f32[3]{0} negate(negate12)
4558     negate14 = f32[3]{0} negate(negate13)
4559     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
4560     while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
4561     gte0 = f32[3]{0} get-tuple-element(while), index=0
4562     gte1 = f32[3]{0} get-tuple-element(while), index=1
4563     negate20 = f32[3]{0} negate(gte1)
4564     negate21 = f32[3]{0} negate(negate20)
4565     negate22 = f32[3]{0} negate(negate21)
4566     negate23 = f32[3]{0} negate(negate22)
4567     negate24 = f32[3]{0} negate(negate23)
4568     negate25 = f32[3]{0} negate(negate24)
4569     negate26 = f32[3]{0} negate(negate25)
4570     negate27 = f32[3]{0} negate(negate26)
4571     negate28 = f32[3]{0} negate(negate27)
4572     negate29 = f32[3]{0} negate(negate28)
4573     negate30 = f32[3]{0} negate(negate29)
4574     negate31 = f32[3]{0} negate(negate30)
4575     negate32 = f32[3]{0} negate(negate31)
4576     negate33 = f32[3]{0} negate(negate32)
4577     negate34 = f32[3]{0} negate(negate33)
4578     ROOT add = f32[3]{0} add(negate34, gte0)
4579   }
4580   )";
4581 
4582   TF_ASSERT_OK_AND_ASSIGN(auto module,
4583                           ParseAndReturnVerifiedModule(hlo_string));
4584   AssignMemorySpace(module.get());
4585 
4586   if (GetParam()) {
4587     EXPECT_THAT(
4588         module->entry_computation()->root_instruction()->operand(1),
4589         op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, op::Copy()));
4590   }
4591 }
4592 
4593 TEST_P(MemorySpaceAssignmentTest, AvoidRedundantEvictionAfterWhile2) {
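  // Unlike AvoidRedundantEvictionAfterWhile, the buffer is threaded through
  // two back-to-back while loops, so the final use is expected to be fed by a
  // prefetch of an eviction of the second while's output.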
4594   absl::string_view hlo_string = R"(
4595   HloModule module, is_scheduled=true
4596 
4597   while_cond1 {
4598     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4599     ROOT gte = pred[] get-tuple-element(p0), index=2
4600   }
4601 
4602   while_body1 {
4603     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4604     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4605     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4606     gte2 = pred[] get-tuple-element(p0), index=2
4607     add = f32[3]{0} add(gte0, gte1)
4608     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
4609   }
4610 
4611   while_cond2 {
4612     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4613     ROOT gte = pred[] get-tuple-element(p0), index=2
4614   }
4615 
4616   while_body2 {
4617     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4618     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4619     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4620     gte2 = pred[] get-tuple-element(p0), index=2
4621     add = f32[3]{0} add(gte0, gte1)
4622     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, gte2)
4623   }
4624 
4625   ENTRY entry {
4626     p0 = f32[3]{0} parameter(0)
4627     p1 = pred[] parameter(1)
4628     copy = f32[3]{0} copy(p0)
4629     tuple1 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, p0, p1)
4630     while1 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple1), condition=while_cond1, body=while_body1
4631     gte0 = f32[3]{0} get-tuple-element(while1), index=0
4632     gte1 = f32[3]{0} get-tuple-element(while1), index=1
4633     negate0 = f32[3]{0} negate(gte1)
4634     negate1 = f32[3]{0} negate(negate0)
4635     negate2 = f32[3]{0} negate(negate1)
4636     negate3 = f32[3]{0} negate(negate2)
4637     negate4 = f32[3]{0} negate(negate3)
4638     negate5 = f32[3]{0} negate(negate4)
4639     negate6 = f32[3]{0} negate(negate5)
4640     negate7 = f32[3]{0} negate(negate6)
4641     negate8 = f32[3]{0} negate(negate7)
4642     negate9 = f32[3]{0} negate(negate8)
4643     negate10 = f32[3]{0} negate(negate9)
4644     negate11 = f32[3]{0} negate(negate10)
4645     negate12 = f32[3]{0} negate(negate11)
4646     negate13 = f32[3]{0} negate(negate12)
4647     negate14 = f32[3]{0} negate(negate13)
4648     tuple2 = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, negate14, p1)
4649     while2 = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple2), condition=while_cond2, body=while_body2
4650     gte2 = f32[3]{0} get-tuple-element(while2), index=0
4651     gte3 = f32[3]{0} get-tuple-element(while2), index=1
4652     negate20 = f32[3]{0} negate(gte3)
4653     negate21 = f32[3]{0} negate(negate20)
4654     negate22 = f32[3]{0} negate(negate21)
4655     negate23 = f32[3]{0} negate(negate22)
4656     negate24 = f32[3]{0} negate(negate23)
4657     negate25 = f32[3]{0} negate(negate24)
4658     negate26 = f32[3]{0} negate(negate25)
4659     negate27 = f32[3]{0} negate(negate26)
4660     negate28 = f32[3]{0} negate(negate27)
4661     negate29 = f32[3]{0} negate(negate28)
4662     negate30 = f32[3]{0} negate(negate29)
4663     negate31 = f32[3]{0} negate(negate30)
4664     negate32 = f32[3]{0} negate(negate31)
4665     negate33 = f32[3]{0} negate(negate32)
4666     negate34 = f32[3]{0} negate(negate33)
4667     ROOT add = f32[3]{0} add(negate34, gte2)
4668   }
4669   )";
4670 
4671   TF_ASSERT_OK_AND_ASSIGN(auto module,
4672                           ParseAndReturnVerifiedModule(hlo_string));
4673   AssignMemorySpace(module.get());
4674 
4675   if (GetParam()) {
4676     EXPECT_THAT(
4677         module->entry_computation()->root_instruction()->operand(1),
4678         op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
4679                       op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
4680                                     op::GetTupleElement(op::While()))));
4681   }
4682 }
4683 
4684 TEST_P(MemorySpaceAssignmentTest,
4685        AfterWhileRedundantEarlierEvictionModifiedBuffer) {
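  // Tuple element 0 is modified inside the while body (negate), so the earlier
  // eviction is not redundant; the use after the loop is expected to go through
  // an eviction of the while output followed by a prefetch.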
4686   absl::string_view hlo_string = R"(
4687   HloModule module, is_scheduled=true
4688 
4689   while_cond {
4690     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4691     ROOT gte = pred[] get-tuple-element(p0), index=2
4692   }
4693 
4694   while_body {
4695     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4696     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4697     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4698     gte2 = pred[] get-tuple-element(p0), index=2
4699     add = f32[3]{0} add(gte0, gte1)
4700     negate = f32[3]{0} negate(gte0)
4701     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(negate, add, gte2)
4702   }
4703 
4704   ENTRY entry {
4705     p0 = f32[3]{0} parameter(0)
4706     p1 = pred[] parameter(1)
4707     copy = f32[3]{0} copy(p0)
4708     negate0 = f32[3]{0} negate(p0)
4709     negate1 = f32[3]{0} negate(negate0)
4710     negate2 = f32[3]{0} negate(negate1)
4711     negate3 = f32[3]{0} negate(negate2)
4712     negate4 = f32[3]{0} negate(negate3)
4713     negate5 = f32[3]{0} negate(negate4)
4714     negate6 = f32[3]{0} negate(negate5)
4715     negate7 = f32[3]{0} negate(negate6)
4716     negate8 = f32[3]{0} negate(negate7)
4717     negate9 = f32[3]{0} negate(negate8)
4718     negate10 = f32[3]{0} negate(negate9)
4719     negate11 = f32[3]{0} negate(negate10)
4720     negate12 = f32[3]{0} negate(negate11)
4721     negate13 = f32[3]{0} negate(negate12)
4722     negate14 = f32[3]{0} negate(negate13)
4723     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy, negate14, p1)
4724     while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
4725     gte0 = f32[3]{0} get-tuple-element(while), index=0
4726     gte1 = f32[3]{0} get-tuple-element(while), index=1
4727     negate20 = f32[3]{0} negate(gte1)
4728     negate21 = f32[3]{0} negate(negate20)
4729     negate22 = f32[3]{0} negate(negate21)
4730     negate23 = f32[3]{0} negate(negate22)
4731     negate24 = f32[3]{0} negate(negate23)
4732     negate25 = f32[3]{0} negate(negate24)
4733     negate26 = f32[3]{0} negate(negate25)
4734     negate27 = f32[3]{0} negate(negate26)
4735     negate28 = f32[3]{0} negate(negate27)
4736     negate29 = f32[3]{0} negate(negate28)
4737     negate30 = f32[3]{0} negate(negate29)
4738     negate31 = f32[3]{0} negate(negate30)
4739     negate32 = f32[3]{0} negate(negate31)
4740     negate33 = f32[3]{0} negate(negate32)
4741     negate34 = f32[3]{0} negate(negate33)
4742     ROOT add = f32[3]{0} add(negate34, gte0)
4743   }
4744   )";
4745 
4746   TF_ASSERT_OK_AND_ASSIGN(auto module,
4747                           ParseAndReturnVerifiedModule(hlo_string));
4748   AssignMemorySpace(module.get());
4749 
4750   if (GetParam()) {
4751     EXPECT_THAT(
4752         module->entry_computation()->root_instruction()->operand(1),
4753         op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
4754                       op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
4755                                     op::GetTupleElement(op::While()))));
4756   }
4757 }
4758 
4759 TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
4760   // Tests against a bug where the root of the entry computation is a bitcast
4761   // instruction and it ends up getting an allocation in the alternate memory.
4762   absl::string_view hlo_string = R"(
4763 HloModule primitive_computation_gather.4, is_scheduled=true
4764 
4765 %while_body {
4766   %param.1 = (s32[], f32[3,3,3]) parameter(0)
4767   %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0
4768   %copy.6 = s32[] copy(s32[] %get-tuple-element.32)
4769   %constant.8 = s32[] constant(1)
4770   %add = s32[] add(s32[] %copy.6, s32[] %constant.8)
4771   %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1
4772   negate = f32[3,3,3] negate(get-tuple-element.35)
4773   ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate)
4774 }
4775 
4776 %while_cond {
4777   %param.0 = (s32[], f32[3,3,3]) parameter(0)
4778   %get-tuple-element = s32[] get-tuple-element(%param.0), index=0
4779   %constant.3 = s32[] constant(3)
4780   ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT
4781 }
4782 
4783 ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] {
4784   %constant.1 = s32[] constant(0)
4785   %copy.11 = s32[] copy(s32[] %constant.1)
4786   %constant = f32[] constant(0)
4787   %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={}
4788   %tuple.8 = (s32[], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast)
4789   %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body
4790   %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1
4791   ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7)
4792 }
4793   )";
4794 
4795   TF_ASSERT_OK_AND_ASSIGN(auto module,
4796                           ParseAndReturnVerifiedModule(hlo_string));
4797   AssignMemorySpace(module.get());
4798 
4799   const HloInstruction* root = module->entry_computation()->root_instruction();
4800   EXPECT_TRUE(!root->shape().has_layout() ||
4801               root->shape().layout().memory_space() == kDefaultMemorySpace);
4802 }
4803 
4804 TEST_P(MemorySpaceAssignmentTest, AsyncOpShortLiveRange) {
4805   absl::string_view hlo_string = R"(
4806 HloModule module, is_scheduled=true
4807 
4808 ENTRY entry {
4809   param = bf16[4]{0} parameter(0)
4810   negate0 = bf16[4]{0} negate(param)
4811   collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
4812   negate1 = bf16[4]{0} negate(param)
4813   negate2 = bf16[4]{0} negate(negate1)
4814   negate3 = bf16[4]{0} negate(negate2)
4815   collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
4816   ROOT add = add(collective-permute-done, negate3)
4817 }
4818   )";
4819 
4820   TF_ASSERT_OK_AND_ASSIGN(auto module,
4821                           ParseAndReturnVerifiedModule(hlo_string));
4822   AssignMemorySpace(module.get());
4823 
4824   // Expect both the source and destination buffers to get alternate memory
4825   // allocations.
4826   HloInstruction* collective_permute_start =
4827       module->entry_computation()->GetInstructionWithName(
4828           "collective-permute-start");
4829   EXPECT_TRUE(collective_permute_start->shape()
4830                   .tuple_shapes(0)
4831                   .layout()
4832                   .memory_space() == kAlternateMemorySpace);
4833   EXPECT_TRUE(collective_permute_start->shape()
4834                   .tuple_shapes(1)
4835                   .layout()
4836                   .memory_space() == kAlternateMemorySpace);
4837 }
4838 
4839 TEST_P(MemorySpaceAssignmentTest, AsyncOpShortLiveRangeInputBufferConsumer) {
4840   absl::string_view hlo_string = R"(
4841 HloModule module, is_scheduled=true
4842 
4843 ENTRY entry {
4844   param = bf16[4]{0} parameter(0)
4845   negate0 = bf16[4]{0} negate(param)
4846   collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
4847   negate1 = bf16[4]{0} negate(negate0)
4848   negate2 = bf16[4]{0} negate(negate1)
4849   negate3 = bf16[4]{0} negate(negate2)
4850   collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
4851   ROOT add = add(collective-permute-done, negate3)
4852 }
4853   )";
4854 
4855   TF_ASSERT_OK_AND_ASSIGN(auto module,
4856                           ParseAndReturnVerifiedModule(hlo_string));
4857   AssignMemorySpace(module.get());
4858 
4859   // Expect only the destination buffer to get an alternate memory allocation
4860   // because negate0 is also used by negate1.
4861   HloInstruction* collective_permute_start =
4862       module->entry_computation()->GetInstructionWithName(
4863           "collective-permute-start");
4864   EXPECT_TRUE(collective_permute_start->shape()
4865                   .tuple_shapes(0)
4866                   .layout()
4867                   .memory_space() == kDefaultMemorySpace);
4868   EXPECT_TRUE(collective_permute_start->shape()
4869                   .tuple_shapes(1)
4870                   .layout()
4871                   .memory_space() == kAlternateMemorySpace);
4872 }
4873 
4874 TEST_P(MemorySpaceAssignmentTest, AsyncOpLongLiveRange) {
4875   absl::string_view hlo_string = R"(
4876 HloModule module, is_scheduled=true
4877 
4878 ENTRY entry {
4879   param = bf16[4]{0} parameter(0)
4880   negate0 = bf16[4]{0} negate(param)
4881   collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
4882   negate1 = bf16[4]{0} negate(param)
4883   negate2 = bf16[4]{0} negate(negate1)
4884   negate3 = bf16[4]{0} negate(negate2)
4885   negate4 = bf16[4]{0} negate(negate3)
4886   negate5 = bf16[4]{0} negate(negate4)
4887   negate6 = bf16[4]{0} negate(negate5)
4888   negate7 = bf16[4]{0} negate(negate6)
4889   negate8 = bf16[4]{0} negate(negate7)
4890   negate9 = bf16[4]{0} negate(negate8)
4891   negate10 = bf16[4]{0} negate(negate9)
4892   negate11 = bf16[4]{0} negate(negate10)
4893   negate12 = bf16[4]{0} negate(negate11)
4894   negate13 = bf16[4]{0} negate(negate12)
4895   collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
4896   ROOT add = add(collective-permute-done, negate13)
4897 }
4898   )";
4899 
4900   TF_ASSERT_OK_AND_ASSIGN(auto module,
4901                           ParseAndReturnVerifiedModule(hlo_string));
4902   AssignMemorySpace(module.get());
4903 
4904   // Expect none of the buffers to get alternate memory allocations because of
4905   // the long live range.
4906   HloInstruction* collective_permute_start =
4907       module->entry_computation()->GetInstructionWithName(
4908           "collective-permute-start");
4909   EXPECT_TRUE(collective_permute_start->shape()
4910                   .tuple_shapes(0)
4911                   .layout()
4912                   .memory_space() == kDefaultMemorySpace);
4913   EXPECT_TRUE(collective_permute_start->shape()
4914                   .tuple_shapes(1)
4915                   .layout()
4916                   .memory_space() == kDefaultMemorySpace);
4917 }
4918 
4919 TEST_P(MemorySpaceAssignmentTest, AsyncOpLongLiveRangeInputBufferConsumer) {
4920   absl::string_view hlo_string = R"(
4921 HloModule module, is_scheduled=true
4922 
4923 ENTRY entry {
4924   param = bf16[4]{0} parameter(0)
4925   negate0 = bf16[4]{0} negate(param)
4926   collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0), source_target_pairs={{0,1},{1,2},{2,3}}
4927   negate1 = bf16[4]{0} negate(negate0)
4928   negate2 = bf16[4]{0} negate(negate1)
4929   negate3 = bf16[4]{0} negate(negate2)
4930   negate4 = bf16[4]{0} negate(negate3)
4931   negate5 = bf16[4]{0} negate(negate4)
4932   negate6 = bf16[4]{0} negate(negate5)
4933   negate7 = bf16[4]{0} negate(negate6)
4934   negate8 = bf16[4]{0} negate(negate7)
4935   negate9 = bf16[4]{0} negate(negate8)
4936   negate10 = bf16[4]{0} negate(negate9)
4937   negate11 = bf16[4]{0} negate(negate10)
4938   negate12 = bf16[4]{0} negate(negate11)
4939   negate13 = bf16[4]{0} negate(negate12)
4940   collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
4941   ROOT add = add(collective-permute-done, negate13)
4942 }
4943   )";
4944 
4945   TF_ASSERT_OK_AND_ASSIGN(auto module,
4946                           ParseAndReturnVerifiedModule(hlo_string));
4947   AssignMemorySpace(module.get());
4948 
4949   // Expect none of the buffers to get alternate memory allocations because of
4950   // the long live range and because negate0 is also used by negate1.
4951   HloInstruction* collective_permute_start =
4952       module->entry_computation()->GetInstructionWithName(
4953           "collective-permute-start");
4954   EXPECT_TRUE(collective_permute_start->shape()
4955                   .tuple_shapes(0)
4956                   .layout()
4957                   .memory_space() == kDefaultMemorySpace);
4958   EXPECT_TRUE(collective_permute_start->shape()
4959                   .tuple_shapes(1)
4960                   .layout()
4961                   .memory_space() == kDefaultMemorySpace);
4962 }
4963 
4964 TEST_P(MemorySpaceAssignmentTest, InPlaceAsyncCollectivePermute) {
4965   absl::string_view hlo_string = R"(
4966 HloModule module, is_scheduled=true
4967 
4968 ENTRY entry {
4969   param = bf16[4]{0} parameter(0)
4970   negate0 = bf16[4]{0} negate(param)
4971   negate1 = bf16[4]{0} negate(param)
4972   const0 = s32[] constant(0)
4973   const1 = s32[] constant(1)
4974   tuple0 = (s32[]) tuple(const0)
4975   tuple1 = (s32[]) tuple(const1)
4976   collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0, negate1, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
4977   negate2 = bf16[4]{0} negate(param)
4978   negate3 = bf16[4]{0} negate(negate2)
4979   negate4 = bf16[4]{0} negate(negate3)
4980   collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
4981   ROOT add = add(collective-permute-done, negate4)
4982 }
4983   )";
4984 
4985   TF_ASSERT_OK_AND_ASSIGN(auto module,
4986                           ParseAndReturnVerifiedModule(hlo_string));
4987   AssignMemorySpace(module.get());
4988 
4989   // Expect both the source and destination buffers to get alternate memory
4990   // allocations.
4991   if (GetParam()) {
4992     HloInstruction* collective_permute_start =
4993         module->entry_computation()->GetInstructionWithName(
4994             "collective-permute-start");
4995     EXPECT_TRUE(collective_permute_start->shape()
4996                     .tuple_shapes(0)
4997                     .layout()
4998                     .memory_space() == kAlternateMemorySpace);
4999     EXPECT_TRUE(collective_permute_start->shape()
5000                     .tuple_shapes(1)
5001                     .layout()
5002                     .memory_space() == kAlternateMemorySpace);
5003   }
5004 }
5005 
5006 TEST_P(MemorySpaceAssignmentTest, InPlaceAsyncCollectivePermuteSameBuffer) {
5007   absl::string_view hlo_string = R"(
5008 HloModule module, is_scheduled=true
5009 
5010 ENTRY entry {
5011   param = bf16[4]{0} parameter(0)
5012   negate0 = bf16[4]{0} negate(param)
5013   const0 = s32[] constant(0)
5014   const1 = s32[] constant(1)
5015   tuple0 = (s32[]) tuple(const0)
5016   tuple1 = (s32[]) tuple(const1)
5017   collective-permute-start = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0, negate0, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
5018   negate2 = bf16[4]{0} negate(param)
5019   negate3 = bf16[4]{0} negate(negate2)
5020   negate4 = bf16[4]{0} negate(negate3)
5021   collective-permute-done = bf16[4]{0} collective-permute-done(collective-permute-start)
5022   ROOT add = add(collective-permute-done, negate4)
5023 }
5024   )";
5025 
5026   TF_ASSERT_OK_AND_ASSIGN(auto module,
5027                           ParseAndReturnVerifiedModule(hlo_string));
5028   AssignMemorySpace(module.get());
5029 
5030   // Expect both the source and destination buffers to get alternate memory
5031   // allocations.
5032   if (GetParam()) {
5033     HloInstruction* collective_permute_start =
5034         module->entry_computation()->GetInstructionWithName(
5035             "collective-permute-start");
5036     EXPECT_TRUE(collective_permute_start->shape()
5037                     .tuple_shapes(0)
5038                     .layout()
5039                     .memory_space() == kAlternateMemorySpace);
5040     EXPECT_TRUE(collective_permute_start->shape()
5041                     .tuple_shapes(1)
5042                     .layout()
5043                     .memory_space() == kAlternateMemorySpace);
5044   }
5045 }
5046 
5047 TEST_P(MemorySpaceAssignmentTest,
5048        InPlaceAsyncCollectivePermuteSameBufferChained) {
5049   absl::string_view hlo_string = R"(
5050 HloModule module, is_scheduled=true
5051 
5052 ENTRY entry {
5053   param = bf16[4]{0} parameter(0)
5054   negate0 = bf16[4]{0} negate(param)
5055   const0 = s32[] constant(0)
5056   const1 = s32[] constant(1)
5057   tuple0 = (s32[]) tuple(const0)
5058   tuple1 = (s32[]) tuple(const1)
5059   collective-permute-start.1 = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(negate0, negate0, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
5060   negate2 = bf16[4]{0} negate(param)
5061   negate3 = bf16[4]{0} negate(negate2)
5062   negate4 = bf16[4]{0} negate(negate3)
5063   collective-permute-done.1 = bf16[4]{0} collective-permute-done(collective-permute-start.1)
5064   collective-permute-start.2 = (bf16[4]{0}, bf16[4]{0}, u32[], u32[]) collective-permute-start(collective-permute-done.1, collective-permute-done.1, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
5065   negate5 = bf16[4]{0} negate(negate4)
5066   negate6 = bf16[4]{0} negate(negate5)
5067   negate7 = bf16[4]{0} negate(negate6)
5068   collective-permute-done.2 = bf16[4]{0} collective-permute-done(collective-permute-start.2)
5069   ROOT add = add(collective-permute-done.2, negate7)
5070 }
5071   )";
5072 
5073   TF_ASSERT_OK_AND_ASSIGN(auto module,
5074                           ParseAndReturnVerifiedModule(hlo_string));
5075   AssignMemorySpace(module.get());
5076 
5077   // Expect both the source and destination buffers to get alternate memory
5078   // allocations.
5079   if (GetParam()) {
5080     HloInstruction* collective_permute_start_1 =
5081         module->entry_computation()->GetInstructionWithName(
5082             "collective-permute-start.1");
5083     EXPECT_TRUE(collective_permute_start_1->shape()
5084                     .tuple_shapes(0)
5085                     .layout()
5086                     .memory_space() == kAlternateMemorySpace);
5087     EXPECT_TRUE(collective_permute_start_1->shape()
5088                     .tuple_shapes(1)
5089                     .layout()
5090                     .memory_space() == kAlternateMemorySpace);
5091     HloInstruction* collective_permute_start_2 =
5092         module->entry_computation()->GetInstructionWithName(
5093             "collective-permute-start.2");
5094     EXPECT_TRUE(collective_permute_start_2->shape()
5095                     .tuple_shapes(0)
5096                     .layout()
5097                     .memory_space() == kAlternateMemorySpace);
5098     EXPECT_TRUE(collective_permute_start_2->shape()
5099                     .tuple_shapes(1)
5100                     .layout()
5101                     .memory_space() == kAlternateMemorySpace);
5102   }
5103 }
5104 
5105 TEST_P(MemorySpaceAssignmentTest,
5106        TupleInPlaceAsyncCollectivePermuteSameBufferChained) {
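  // Same chained in-place collective-permute pattern as above, but with
  // tuple-shaped buffers; the checks below only verify that each
  // collective-permute-done still consumes its collective-permute-start
  // directly.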
5107   absl::string_view hlo_string = R"(
5108 HloModule module, is_scheduled=true
5109 
5110 ENTRY entry {
5111   param = bf16[4]{0} parameter(0)
5112   param2 = bf16[48]{0} parameter(1)
5113   negate0.1 = bf16[48]{0} negate(param2)
5114   negate0.2 = bf16[48]{0} negate(param2)
5115   const0 = s32[] constant(0)
5116   const1 = s32[] constant(1)
5117   tuple0.0 = (s32[]) tuple(const0)
5118   tuple0 = ((s32[]), (s32[])) tuple(tuple0.0, tuple0.0)
5119   tuple1.0 = (s32[]) tuple(const1)
5120   tuple1 = ((s32[]), (s32[])) tuple(tuple1.0, tuple1.0)
5121   tuple2 = (bf16[48]{0}, bf16[48]{0}) tuple(negate0.1, negate0.2)
5122   collective-permute-start.1 = ((bf16[48]{0}, bf16[48]{0}), (bf16[48]{0}, bf16[48]{0}), u32[], u32[]) collective-permute-start(tuple2, tuple2, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
5123   negate2 = bf16[4]{0} negate(param)
5124   negate3 = bf16[4]{0} negate(negate2)
5125   negate4 = bf16[4]{0} negate(negate3)
5126   collective-permute-done.1 = (bf16[48]{0}, bf16[48]{0}) collective-permute-done(collective-permute-start.1)
5127   collective-permute-start.2 = ((bf16[48]{0}, bf16[48]{0}), (bf16[48]{0}, bf16[48]{0}), u32[], u32[]) collective-permute-start(collective-permute-done.1, collective-permute-done.1, tuple0, tuple1), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{1}}
5128   negate5 = bf16[4]{0} negate(negate4)
5129   negate6 = bf16[4]{0} negate(negate5)
5130   negate7 = bf16[4]{0} negate(negate6)
5131   collective-permute-done.2 = (bf16[48]{0}, bf16[48]{0}) collective-permute-done(collective-permute-start.2)
5132   gte = bf16[48]{0} get-tuple-element(collective-permute-done.2), index=0
5133   ROOT root = (bf16[48]{0}, bf16[4]{0}) tuple(gte, negate7)
5134 }
5135   )";
5136 
5137   TF_ASSERT_OK_AND_ASSIGN(auto module,
5138                           ParseAndReturnVerifiedModule(hlo_string));
5139   AssignMemorySpace(module.get());
5140 
5141   const HloInstruction* cp_done1 =
5142       FindInstruction(module.get(), "collective-permute-done.1");
5143   EXPECT_EQ(cp_done1->operand(0)->opcode(), HloOpcode::kCollectivePermuteStart);
5144   const HloInstruction* cp_done2 =
5145       FindInstruction(module.get(), "collective-permute-done.2");
5146   EXPECT_EQ(cp_done2->operand(0)->opcode(), HloOpcode::kCollectivePermuteStart);
5147 }
5148 
5149 TEST_P(MemorySpaceAssignmentTest, ReservedScopedMemory) {
5150   absl::string_view hlo_string = R"(
5151 HloModule module, is_scheduled=true
5152 
5153 ENTRY entry {
5154   param0 = f32[2,4] parameter(0)
5155   a = f32[2,4] negate(param0)
5156   b = f32[2,4] negate(a)
5157   c = f32[2,4] negate(b)
5158   d = f32[2,4] negate(c)
5159   e = f32[2,4] negate(d)
5160   ROOT f = f32[2,4] add(e, b)
5161 }
5162   )";
5163 
5164   TF_ASSERT_OK_AND_ASSIGN(auto module,
5165                           ParseAndReturnVerifiedModule(hlo_string));
5166   Options options;
5167   options.max_size_in_bytes = 128;
5168   options.alignment_in_bytes = 8;
5169   options.verify = true;
5170   // Make instruction c reserve 100 bytes in the alternate memory. This should
5171   // prevent both b and c from placing their outputs in the alternate memory.
5172   options.reserved_scoped_memory_fn = [&](const HloInstruction* instruction) {
5173     if (instruction->name() == "c") {
5174       return 100;
5175     }
5176     return 0;
5177   };
5178   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
5179                     /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
5180                     options);
5181   auto get_memory_space = [&](absl::string_view instruction_name) {
5182     return module->entry_computation()
5183         ->GetInstructionWithName(instruction_name)
5184         ->shape()
5185         .layout()
5186         .memory_space();
5187   };
5188   EXPECT_TRUE(get_memory_space("a") == kAlternateMemorySpace);
5189   EXPECT_TRUE(get_memory_space("b") == kDefaultMemorySpace);
5190   EXPECT_TRUE(get_memory_space("c") == kDefaultMemorySpace);
5191   EXPECT_TRUE(get_memory_space("d") == kAlternateMemorySpace);
5192   EXPECT_TRUE(get_memory_space("e") == kAlternateMemorySpace);
5193 }
5194 
5195 TEST_P(MemorySpaceAssignmentTest, ConstantAllocationFar) {
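  // The constant is defined far from its use; expect the constant itself to be
  // placed in default memory while the operand of its use ends up in alternate
  // memory.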
5196   absl::string_view hlo_string = R"(
5197 HloModule module, is_scheduled=true
5198 
5199 ENTRY entry {
5200   param0 = f32[2,4] parameter(0)
5201   const = f32[2,4] constant({...})
5202   a = f32[2,4] negate(param0)
5203   b = f32[2,4] negate(a)
5204   c = f32[2,4] negate(b)
5205   d = f32[2,4] negate(c)
5206   e = f32[2,4] negate(d)
5207   ROOT negate = f32[2,4] add(const, e)
5208 }
5209   )";
5210 
5211   TF_ASSERT_OK_AND_ASSIGN(auto module,
5212                           ParseAndReturnVerifiedModule(hlo_string));
5213   AssignMemorySpace(module.get());
5214   EXPECT_TRUE(module->entry_computation()
5215                   ->GetInstructionWithName("const")
5216                   ->shape()
5217                   .layout()
5218                   .memory_space() == kDefaultMemorySpace);
5219   EXPECT_TRUE(module->entry_computation()
5220                   ->GetInstructionWithName("negate")
5221                   ->operand(0)
5222                   ->shape()
5223                   .layout()
5224                   .memory_space() == kAlternateMemorySpace);
5225 }
5226 
5227 TEST_P(MemorySpaceAssignmentTest, ConstantAllocationNear) {
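  // Same as ConstantAllocationFar, but the constant is defined right before
  // its use; the expected memory spaces are the same.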
5228   absl::string_view hlo_string = R"(
5229 HloModule module, is_scheduled=true
5230 
5231 ENTRY entry {
5232   param0 = f32[2,4] parameter(0)
5233   a = f32[2,4] negate(param0)
5234   b = f32[2,4] negate(a)
5235   c = f32[2,4] negate(b)
5236   d = f32[2,4] negate(c)
5237   e = f32[2,4] negate(d)
5238   const = f32[2,4] constant({...})
5239   ROOT negate = f32[2,4] add(const, e)
5240 }
5241   )";
5242 
5243   TF_ASSERT_OK_AND_ASSIGN(auto module,
5244                           ParseAndReturnVerifiedModule(hlo_string));
5245   AssignMemorySpace(module.get());
5246   EXPECT_TRUE(module->entry_computation()
5247                   ->GetInstructionWithName("const")
5248                   ->shape()
5249                   .layout()
5250                   .memory_space() == kDefaultMemorySpace);
5251   EXPECT_TRUE(module->entry_computation()
5252                   ->GetInstructionWithName("negate")
5253                   ->operand(0)
5254                   ->shape()
5255                   .layout()
5256                   .memory_space() == kAlternateMemorySpace);
5257 }
5258 
5259 // A mock MemorySpaceAssignmentRepacker class that accepts a map of
5260 // (start_time, offset) -> new_offset values. Using this map, the repacker
5261 // moves each allocation to its new_offset.
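// For example (see the Repack test below), repack_map[{2, 0}] = 32 moves the
// allocation that starts at time 2 with initial offset 0 to offset 32;
// allocations whose (start_time, initial_offset) key is not in the map keep
// their initial offsets.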
5262 class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker {
5263  public:
5264   explicit FakeMemorySpaceAssignmentRepacker(
5265       absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t>& repack_map,
5266       std::function<void(absl::Span<AllocationBlock*>)> check_fun = nullptr,
5267       bool always_return_modified = false)
5268       : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8),
5269         repack_map_(repack_map),
5270         check_fun_(check_fun),
5271         always_return_modified_(always_return_modified) {}
5272 
5273   StatusOr<bool> Repack(absl::Span<AllocationBlock*> allocations) override {
5274     bool modified = false;
5275     for (AllocationBlock* block : allocations) {
5276       absl::flat_hash_set<int64_t> colocations;
5277       std::string colocations_str;
5278       for (const AllocationBlock* colocation : block->colocations) {
5279         absl::StrAppend(&colocations_str, colocation->id, ", ");
5280         colocations.insert(colocation->id);
5281       }
5282       VLOG(1) << "Alloc id: " << block->id << " time: [" << block->start_time
5283               << ", " << block->end_time << "] size: " << block->size
5284               << " init offset: " << block->initial_offset << " colocations: {"
5285               << colocations_str << "}";
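      // Remap this block if its (start_time, initial_offset) key is in the
      // map; otherwise keep the initial offset. Colocated blocks follow the
      // same decision.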
5286       auto it = repack_map_.find({block->start_time, block->initial_offset});
5287       if (it != repack_map_.end()) {
5288         modified = true;
5289         block->offset = it->second;
5290       } else {
5291         block->offset = block->initial_offset;
5292       }
5293       for (AllocationBlock* colocation : block->colocations) {
5294         if (it != repack_map_.end()) {
5295           colocation->offset = it->second;
5296         } else {
5297           colocation->offset = colocation->initial_offset;
5298         }
5299       }
5300     }
5301     if (check_fun_) {
5302       check_fun_(allocations);
5303     }
5304 
5305     return always_return_modified_ || modified;
5306   }
5307 
5308  private:
5309   // A map from (start_time, offset) to new_offset.
5310   absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map_;
5311   std::function<void(absl::Span<AllocationBlock*>)> check_fun_;
5312   bool always_return_modified_;
5313 };
5314 
5315 TEST_P(MemorySpaceAssignmentTest, Repack) {
5316   // We initially perform the following allocations at these offsets.
5317   //
5318   //    Max memory
5319   //  -------------------------------------------
5320   //
5321   //
5322   //
5323   //
5324   //      +------------+
5325   //      |     b      |
5326   //      +------------+
5327   //  +-------+                 +------------+
5328   //  |   a   |                 |     n      |
5329   //  +-------+                 +------------+
5330   //  -------------------------------------------
5331   //    Min memory          time ->
5332   //
5333   // Next up, we try to allocate the prefetch for m. However, due to
5334   // fragmentation, this won't be possible:
5335   //
5336   //    Max memory
5337   //  -------------------------------------------
5338   //
5339   //
5340   //
5341   //                +---------+
5342   //      +------------+      |
5343   //      |     b   |  |      |
5344   //      +------------+      |
5345   //  +-------+     |         | +------------+
5346   //  |   a   |     |    d    | |     n      |
5347   //  +-------+     +---------+ +------------+
5348   //  -------------------------------------------
5349   //    Min memory          time ->
5350   //
5351   // We then call repack to repack the existing allocations, which allows us to
5352   // allocate the prefetch for m:
5353   //
5354   //    Max memory
5355   //  -------------------------------------------
5356   //                +---------+
5357   //                |         |
5358   //                |         |
5359   //                |         |
5360   //  +-------+     |         |
5361   //  |   a   |     |    d    |
5362   //  +-------+     +---------+
5363   //      +------------+        +------------+
5364   //      |      b     |        |     n      |
5365   //      +------------+        +------------+
5366   //  -------------------------------------------
5367   //    Min memory          time ->
5368   absl::string_view hlo_string = R"(
5369   HloModule bug, is_scheduled=true
5370 
5371   ENTRY Entry {
5372     param0 = f32[8,3] parameter(0)
5373     param1 = f32[2,4] parameter(1)
5374     a = f32[2,4] sine(param1)
5375     b = f32[2,4] cosine(param1)
5376     c = f32[8,3] negate(param0)
5377     j = f32[2,4] negate(a)
5378     d = f32[8,3] tanh(param0)
5379     k = f32[2,4] negate(j)
5380     l = f32[2,4] add(b, k)
5381     m = f32[8,3] negate(d)
5382     n = f32[2,4] sine(l)
5383     o = f32[8,3] negate(m)
5384     p = f32[2,4] negate(n)
5385     q = f32[8,3] negate(m)
5386     ROOT tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(p, q, o)
5387   }
5388   )";
5389 
5390   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
5391       [](const MemorySpaceAssignment::BufferInterval& a,
5392          const MemorySpaceAssignment::BufferInterval& b) {
5393         auto get_opcode_priority = [](const HloOpcode& opcode) {
5394           switch (opcode) {
5395             case HloOpcode::kSin:
5396               return 0;
5397             case HloOpcode::kCos:
5398               return 1;
5399             case HloOpcode::kTanh:
5400               return 2;
5401             default:
5402               return 3;
5403           }
5404         };
5405 
5406         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
5407                get_opcode_priority(b.buffer->defining_instruction()->opcode());
5408       };
5409   TF_ASSERT_OK_AND_ASSIGN(auto module,
5410                           ParseAndReturnVerifiedModule(hlo_string));
5411 
5412   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
5413   absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;
5414   // Move "a" from offset 0 to 32.
5415   repack_map[{2, 0}] = 32;
5416   // Move "b" from offset 32 to 0.
5417   repack_map[{3, 32}] = 0;
5418   FakeMemorySpaceAssignmentRepacker repacker =
5419       FakeMemorySpaceAssignmentRepacker(repack_map);
5420   Options options;
5421   options.max_size_in_bytes = 128;
5422   options.alignment_in_bytes = 8;
5423   options.verify = true;
5424   options.max_repacks = 1;
5425   options.repacker = &repacker;
5426   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
5427                     buffer_interval_compare, &prefetch_interval_picker,
5428                     options);
5429 
5430   // If repacking succeeds, we should find the buffer for d in alternate memory.
5431   const HloInstruction* d =
5432       module->entry_computation()->GetInstructionWithName("d");
5433   EXPECT_EQ(d->shape().layout().memory_space(), kAlternateMemorySpace);
5434 }
5435 
5436 TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) {
5437   // This test checks that we correctly export aliased offsets for repacking.
5438   // In this example, the buffer produced at HLO "a" will be allocated first,
5439   // and will consist of four allocations:
5440   //    1) a produced in the alternate memory (and then evicted to the default
5441   //    memory). 2) a prefetched to the alternate memory to be used by q and
5442   //    while HLOs. 3) a used within the while loop body. 4) the output of while
5443   //    HLO, used by u.
5444   //
5445   // Since a will be allocated first (the test is crafted to prioritize sine
5446   // HLO), all four allocations should get the same (zero) offsets. However,
5447   // while allocations 2, 3, and 4 need to be colocated with each other,
5448   // allocation 1 doesn't need to be colocated with the other three.
5449   absl::string_view hlo_string = R"(
5450   HloModule bug, is_scheduled=true
5451 
5452   while_condition {
5453     param1 = (f32[2,4], f32[2,4]) parameter(0)
5454     ROOT cond = pred[] constant(true)
5455   }
5456 
5457   while_body {
5458     param2 = (f32[2,4], f32[2,4]) parameter(0)
5459     gte2 = f32[2,4] get-tuple-element(param2), index=0
5460     gte3 = f32[2,4] get-tuple-element(param2), index=1
5461     add = f32[2,4] add(gte2, gte3)
5462     ROOT tuple2 = (f32[2,4], f32[2,4]) tuple(add, gte3)
5463   }
5464 
5465   ENTRY Entry {
5466     param0 = f32[2,4] parameter(0)
5467     a = f32[2,4] sine(param0)
5468     b = f32[2,4] negate(a)
5469     c = f32[2,4] negate(b)
5470     d = f32[2,4] negate(c)
5471     e = f32[2,4] negate(d)
5472     f = f32[2,4] negate(e)
5473     g = f32[2,4] negate(f)
5474     h = f32[2,4] negate(g)
5475     i = f32[2,4] negate(h)
5476     j = f32[2,4] negate(i)
5477     k = f32[2,4] negate(j)
5478     l = f32[2,4] negate(k)
5479     m = f32[2,4] negate(l)
5480     n = f32[2,4] negate(m)
5481     o = f32[2,4] negate(n)
5482     p = f32[2,4] negate(o)
5483     q = f32[2,4] add(p, a)
5484     tuple = (f32[2,4], f32[2,4]) tuple(q, a)
5485     while = (f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
5486     gte0 = f32[2,4] get-tuple-element(while), index=0
5487     gte1 = f32[2,4] get-tuple-element(while), index=1
5488     r = f32[2,4] negate(gte0)
5489     s = f32[2,4] negate(r)
5490     t = f32[2,4] negate(s)
5491     constant = f32[] constant(0)
5492     broadcast = f32[8,4] broadcast(constant), dimensions={}
5493     cos = f32[8,4] cosine(broadcast)
5494     u = f32[2,4] add(t, gte1)
5495     v = f32[2,4] add(u, param0)
5496     w = f32[8,4] negate(cos)
5497     ROOT tuple3 = (f32[2,4], f32[8,4]) tuple(v, w)
5498   }
5499   )";
5500 
5501   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
5502       [](const MemorySpaceAssignment::BufferInterval& a,
5503          const MemorySpaceAssignment::BufferInterval& b) {
5504         auto get_opcode_priority = [](const HloOpcode& opcode) {
5505           switch (opcode) {
5506             case HloOpcode::kSin:
5507               return 0;
5508             case HloOpcode::kCos:
5509               return 1;
5510             case HloOpcode::kTanh:
5511               return 2;
5512             default:
5513               return 3;
5514           }
5515         };
5516 
5517         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
5518                get_opcode_priority(b.buffer->defining_instruction()->opcode());
5519       };
5520   TF_ASSERT_OK_AND_ASSIGN(auto module,
5521                           ParseAndReturnVerifiedModule(hlo_string));
5522 
5523   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
5524   absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;
5525 
5526   // Expect that of the four separate allocations for the "a" buffer, the
5527   // first one and the remaining three form separate colocation groups.
5528   auto check_fun =
5529       [](absl::Span<MemorySpaceAssignmentRepacker::AllocationBlock*>
5530              allocations) {
5531         EXPECT_TRUE(allocations.at(0)->colocations.size() == 1 ||
5532                     allocations.at(0)->colocations.size() == 3);
5533         EXPECT_EQ(allocations.at(1)->colocations.size(), 3);
5534         EXPECT_EQ(allocations.at(2)->colocations.size(), 3);
5535         EXPECT_TRUE(allocations.at(3)->colocations.size() == 1 ||
5536                     allocations.at(3)->colocations.size() == 3);
5537       };
5538   FakeMemorySpaceAssignmentRepacker repacker =
5539       FakeMemorySpaceAssignmentRepacker(repack_map, check_fun);
5540   Options options;
5541   options.max_size_in_bytes = 128;
5542   options.alignment_in_bytes = 8;
5543   options.verify = true;
5544   options.max_repacks = 1;
5545   options.repacker = &repacker;
5546   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
5547                     buffer_interval_compare, &prefetch_interval_picker,
5548                     options);
5549 }
5550 
5551 TEST_P(MemorySpaceAssignmentTest,
5552        RepackExportsAliasedOffsetsForReservedScopedMemory) {
5553   absl::string_view hlo_string = R"(
5554 HloModule module, is_scheduled=true
5555 
5556 ENTRY entry {
5557   param0 = f32[2,4] parameter(0)
5558   a = f32[2,4] negate(param0)
5559   b = f32[2,4] negate(a)
5560   c = f32[2,4] negate(b)
5561   d = f32[2,4] negate(c)
5562   e = f32[2,4] negate(d)
5563   ROOT f = f32[2,4] add(e, b)
5564 }
5565   )";
5566   TF_ASSERT_OK_AND_ASSIGN(auto module,
5567                           ParseAndReturnVerifiedModule(hlo_string));
5568   Options options;
5569   options.max_size_in_bytes = 128;
5570   options.alignment_in_bytes = 8;
5571   options.verify = true;
5572   options.max_repacks = 1;
5573   // Make two instructions reserve scoped memory.
5574   options.reserved_scoped_memory_fn = [&](const HloInstruction* instruction) {
5575     if (instruction->name() == "c" || instruction->name() == "d") {
5576       return 100;
5577     }
5578     return 0;
5579   };
5580 
5581   absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;
5582   bool repacker_ran = false;
5583 
5584   // Expect that the first two values to repack each have a colocation size
5585   // of 2, corresponding to the scoped allocations.
5586   auto check_fun =
5587       [&](absl::Span<MemorySpaceAssignmentRepacker::AllocationBlock*>
5588               allocations) {
5589         EXPECT_EQ(allocations.at(0)->colocations.size(), 2);
5590         EXPECT_EQ(allocations.at(1)->colocations.size(), 2);
5591         repacker_ran = true;
5592       };
5593   FakeMemorySpaceAssignmentRepacker repacker =
5594       FakeMemorySpaceAssignmentRepacker(repack_map, check_fun);
5595   options.repacker = &repacker;
5596   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
5597                     /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
5598                     options);
5599   EXPECT_TRUE(repacker_ran);
5600 }
5601 
5602 TEST_P(MemorySpaceAssignmentTest,
5603        RepackShouldntEraseRequiredAssignmentForConditionalOutput) {
5604   // This is a test case for b/171040271. Repacking erases the required
5605   // assignments (since some required assignments are inserted conditionally
5606   // based on allocation decisions), including the requirement that conditional
5607   // outputs always get assignments in the default memory. After repacking,
5608   // this required assignment was never added back, causing conditionals to get
5609   // alternate-memory allocations.
5610   absl::string_view hlo_string = R"(
5611   HloModule CondAllocation, is_scheduled=true
5612 
5613   true_computation {
5614     p0 = (f32[3]) parameter(0)
5615     gte = f32[3] get-tuple-element(p0), index=0
5616     neg1 = f32[3] negate(gte)
5617     ROOT tuple1 = (f32[3]) tuple(neg1)
5618   }
5619 
5620   false_computation {
5621     p0 = (f32[3]) parameter(0)
5622     gte = f32[3] get-tuple-element(p0), index=0
5623     neg2 = f32[3] negate(gte)
5624     ROOT tuple2 = (f32[3]) tuple(neg2)
5625   }
5626 
5627   ENTRY entry {
5628     p0 = f32[3] parameter(0)
5629     p1 = pred[] parameter(1)
5630     copy = f32[3] copy(p0)
5631     tuple = (f32[3]) tuple(copy)
5632     conditional = (f32[3]) conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
5633     ROOT gte = f32[3] get-tuple-element(conditional), index=0
5634   }
5635   )";
5636   TF_ASSERT_OK_AND_ASSIGN(auto module,
5637                           ParseAndReturnVerifiedModule(hlo_string));
5638   absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> repack_map;
5639   FakeMemorySpaceAssignmentRepacker repacker =
5640       FakeMemorySpaceAssignmentRepacker(repack_map, nullptr,
5641                                         /*always_return_modified=*/true);
5642   Options options;
5643   options.max_size_in_bytes = 128;
5644   options.alignment_in_bytes = 8;
5645   options.verify = true;
5646   options.max_repacks = 10;
5647   options.repacker = &repacker;
5648   options.repack_after_every_allocation = true;
5649   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
5650   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
5651                     /*buffer_interval_compare=*/{}, &prefetch_interval_picker,
5652                     options);
5653 }
5654 
5655 TEST_P(MemorySpaceAssignmentTest, Determinism) {
5656   // Run memory space assignment a few times to make sure it produces the
5657   // same result every time.
5658   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
5659 
5660   AssignMemorySpace(module.get());
5661   std::string module_str = module->ToString();
5662 
5663   for (int i = 0; i < 10; ++i) {
5664     std::unique_ptr<HloModule> other_module = CreateEvictAndPrefetchModule();
5665     AssignMemorySpace(other_module.get());
5666     EXPECT_EQ(module_str, other_module->ToString());
5667   }
5668 }
5669 
5670 TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
5671   // Tests that in-place ops like DynamicUpdateSlice get the same allocation
5672   // as their input.
5673   absl::string_view hlo_string = R"(
5674 HloModule Module, is_scheduled=true
5675 
5676 fused_computation {
5677   param0 = f32[2,3] parameter(0)
5678   constant.1 = f32[] constant(0)
5679   broadcast = f32[2,1] broadcast(constant.1), dimensions={}
5680   constant.3 = s32[] constant(0)
5681   ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
5682 }
5683 
5684 ENTRY main {
5685   param = f32[2,3] parameter(0)
5686   negate = f32[2,3] negate(param)
5687   fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
5688   ROOT add = f32[2,3] add(fusion, fusion)
5689 }
5690   )";
5691 
5692   TF_ASSERT_OK_AND_ASSIGN(auto module,
5693                           ParseAndReturnVerifiedModule(hlo_string));
5694   auto preset_assignments = AssignMemorySpace(module.get());
5695   HloInstruction* negate_instruction =
5696       module->entry_computation()->GetInstructionWithName("negate");
5697   int64_t negate_offset =
5698       GetAlternateMemoryOffset(*preset_assignments, negate_instruction);
5699   HloInstruction* fusion_instruction =
5700       module->entry_computation()->GetInstructionWithName("fusion");
5701   int64_t fusion_offset =
5702       GetAlternateMemoryOffset(*preset_assignments, fusion_instruction);
5703   // We expect negate and fusion to get the same offsets.
5704   EXPECT_EQ(negate_offset, fusion_offset);
5705   const bool allocate_across_sequential_calls = GetParam();
5706   if (allocate_across_sequential_calls) {
5707     EXPECT_NE(negate_offset, -1);
5708   }
5709 }
5710 
5711 TEST_P(MemorySpaceAssignmentTest, ConditionalInPlaceOp) {
5712   absl::string_view hlo_string = R"(
5713 HloModule Module, is_scheduled=true
5714 
5715 fused_computation {
5716   param0 = f32[2,3] parameter(0)
5717   constant.1 = f32[] constant(0)
5718   broadcast = f32[2,1] broadcast(constant.1), dimensions={}
5719   constant.3 = s32[] constant(0)
5720   ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
5721 }
5722 
5723 true_computation {
5724   p0 = (f32[2,3]) parameter(0)
5725   gte = f32[2,3] get-tuple-element(p0), index=0
5726   ROOT neg1 = f32[2,3] negate(gte)
5727 }
5728 
5729 false_computation {
5730   p0 = (f32[2,3]) parameter(0)
5731   gte = f32[2,3] get-tuple-element(p0), index=0
5732   neg2 = f32[2,3] negate(gte)
5733   ROOT fusion = f32[2,3] fusion(neg2), kind=kLoop, calls=fused_computation
5734 }
5735 
5736 ENTRY entry {
5737   p0 = f32[2,3] parameter(0)
5738   p1 = pred[] parameter(1)
5739   copy = f32[2,3] copy(p0)
5740   tuple = (f32[2,3]) tuple(copy)
5741   ROOT conditional = f32[2,3] conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
5742 }
5743   )";
5744 
5745   TF_ASSERT_OK_AND_ASSIGN(auto module,
5746                           ParseAndReturnVerifiedModule(hlo_string));
5747   AssignMemorySpace(module.get());
5748 }
5749 
5750 INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
5751                          MemorySpaceAssignmentTest,
5752                          ::testing::Values(false, true));
5753 
5754 using AsynchronousCopyResourceTest = ::testing::Test;
5755 
5756 TEST_F(AsynchronousCopyResourceTest, Simple) {
5757   // time:      0 1 2 3 4 5 6 7 8 9
5758   // resource:  2 3 1 6 7 1 7 2 2 4
5759   // -1,3,5    +-----+                OK
5760   // resource:  0 0 1 6 7 1 7 2 2 4
5761   //  1,4,4        +---+              OK
5762   // resource:  0 0 0 3 7 1 7 2 2 4
5763   //  5,9,10               +-----+
5764   // resource:  0 0 0 3 7 1 0 0 1 4
5765   //  4,9,3              +-------+    Violate
5766   //  4,8,2              +-----+      OK; the 5,9 copy shifts resource to the right.
5767   // resource:  0 0 0 3 7 0 0 0 0 4
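  //
  // The (start, end) times of a copy are exclusive limits: a copy may only use
  // resource at times strictly between them. Each copy claims resource from
  // the earliest slot still available in its window, and a copy with an
  // earlier start time has priority, so the demand of later-starting copies
  // can be pushed further right as long as it still fits before their end
  // time.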
5768   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
5769   AsynchronousCopyResource resource(
5770       {2.0, 3.0, 1.0, 6.0, 7.0, 1.0, 7.0, 2.0, 2.0, 4.0});
5771   EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 5.0));
5772   resource.AddCopy({-1, 3, 5.0, alternate_mem_space, 0});
5773   EXPECT_TRUE(resource.HasEnoughResource(1, 4, 4.0));
5774   resource.AddCopy({1, 4, 4.0, alternate_mem_space, 1});
5775   EXPECT_TRUE(resource.HasEnoughResource(5, 9, 10.0));
5776   resource.AddCopy({5, 9, 10.0, alternate_mem_space, 2});
5777   EXPECT_FALSE(resource.HasEnoughResource(4, 9, 3.0));
5778   EXPECT_TRUE(resource.HasEnoughResource(4, 8, 2.0));
5779   resource.AddCopy({4, 8, 2.0, alternate_mem_space, 3});
5780 }
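
// A minimal usage sketch of the AsynchronousCopyResource interface exercised
// above (illustrative only; this test name is not from the original suite):
// once a single copy consumes the entire budget of its window, no further copy
// over that window can be admitted.
TEST_F(AsynchronousCopyResourceTest, ExhaustedWindowSketch) {
  // time:      0 1 2
  // resource:  1 1 1
  // -1,3,3     +---+   OK (consumes all three slots)
  // resource:  0 0 0
  // -1,3,1     +---+   Violate
  auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyResource resource({1.0, 1.0, 1.0});
  EXPECT_TRUE(resource.HasEnoughResource(-1, 3, 3.0));
  resource.AddCopy({-1, 3, 3.0, alternate_mem_space, 0});
  EXPECT_EQ(resource.GetCurrentResources(),
            std::vector<float>({0.0, 0.0, 0.0}));
  EXPECT_FALSE(resource.HasEnoughResource(-1, 3, 1.0));
}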
5781 
5782 TEST_F(AsynchronousCopyResourceTest, Propagate) {
5783   // time:      0 1 2 3 4 5 6 7 8 9
5784   // resource:  2 2 2 2 2 2 2 2 2 2
5785   // 6,10,2                  +-----+   OK
5786   // resource:  2 2 2 2 2 2 2 0 2 2
5787   // 5,9,2                 +-----+     OK
5788   // resource:  2 2 2 2 2 2 0 0 2 2
5789   // 4,8,2               +-----+       OK
5790   // resource:  2 2 2 2 2 0 0 0 2 2
5791   // 3,7,2             +-----+         OK
5792   // resource:  2 2 2 2 0 0 0 0 2 2
5793   // 2,6,2           +-----+           OK
5794   // resource:  2 2 2 0 0 0 0 0 2 2
5795   // 1,5,2         +-----+             OK
5796   // resource:  2 2 0 0 0 0 0 0 2 2
5797   // 0,4,3       +-----+               OK
5798   // resource:  2 0 0 0 0 0 0 0 1 2
5799   // 0,4,3       +-----+               OK
5800   // resource:  2 0 0 0 0 0 0 0 0 0
5801   // 0,4,1       +-----+               Violate
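  //
  // The first 0,4,3 copy finds only 2.0 free inside its own window, so the
  // chained copies above each get pushed one slot to the right, ultimately
  // spilling 1.0 of demand into time 8. The second 0,4,3 copy repeats this and
  // drains times 8 and 9 completely, after which even a 1.0 copy over (0,4)
  // cannot fit.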
5802   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
5803   AsynchronousCopyResource resource(
5804       {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
5805   EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0));
5806   resource.AddCopy({6, 10, 2.0, alternate_mem_space, 0});
5807   EXPECT_EQ(
5808       resource.GetCurrentResources(),
5809       std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0, 2.0}));
5810   EXPECT_TRUE(resource.HasEnoughResource(5, 9, 2.0));
5811   resource.AddCopy({5, 9, 2.0, alternate_mem_space, 1});
5812   EXPECT_TRUE(resource.HasEnoughResource(4, 8, 2.0));
5813   resource.AddCopy({4, 8, 2.0, alternate_mem_space, 2});
5814   EXPECT_TRUE(resource.HasEnoughResource(3, 7, 2.0));
5815   resource.AddCopy({3, 7, 2.0, alternate_mem_space, 3});
5816   EXPECT_TRUE(resource.HasEnoughResource(2, 6, 2.0));
5817   resource.AddCopy({2, 6, 2.0, alternate_mem_space, 4});
5818   EXPECT_TRUE(resource.HasEnoughResource(1, 5, 2.0));
5819   resource.AddCopy({1, 5, 2.0, alternate_mem_space, 5});
5820   EXPECT_EQ(
5821       resource.GetCurrentResources(),
5822       std::vector<float>({2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0}));
5823   EXPECT_TRUE(resource.HasEnoughResource(0, 4, 3.0));
5824   resource.AddCopy({0, 4, 3.0, alternate_mem_space, 6});
5825   EXPECT_EQ(
5826       resource.GetCurrentResources(),
5827       std::vector<float>({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0}));
5828   EXPECT_TRUE(resource.HasEnoughResource(0, 4, 3.0));
5829   resource.AddCopy({0, 4, 3.0, alternate_mem_space, 7});
5830   EXPECT_EQ(
5831       resource.GetCurrentResources(),
5832       std::vector<float>({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
5833   EXPECT_FALSE(resource.HasEnoughResource(0, 4, 1.0));
5834 }
5835 
5836 TEST_F(AsynchronousCopyResourceTest, CantPropagate) {
5837   // time:      0 1 2 3 4 5 6 7 8 9
5838   // resource:  2 2 2 2 2 2 2 2 2 2
5839   // 5,10,2                +-------+   OK
5840   // resource:  2 2 2 2 2 2 0 2 2 2
5841   // 4,7,2               +---+         OK
5842   // resource:  2 2 2 2 2 0 0 2 2 2
5843   // 4,8,4               +-----+       OK
5844   // resource:  2 2 2 2 2 0 0 0 0 2
5845   // 3,6,4             +---+           Violate
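  //
  // The 3,6 request would claim the 2.0 available at both times 4 and 5, but
  // then the 4,8 copy could no longer fit its 4.0 within times 5-7, so the
  // request is rejected.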
5846   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
5847   AsynchronousCopyResource resource(
5848       {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
5849   EXPECT_TRUE(resource.HasEnoughResource(5, 10, 2.0));
5850   resource.AddCopy({5, 10, 2.0, alternate_mem_space, 0});
5851   EXPECT_EQ(
5852       resource.GetCurrentResources(),
5853       std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0, 2.0, 2.0}));
5854   EXPECT_TRUE(resource.HasEnoughResource(4, 7, 2.0));
5855   resource.AddCopy({4, 7, 2.0, alternate_mem_space, 1});
5856   EXPECT_EQ(
5857       resource.GetCurrentResources(),
5858       std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 2.0, 2.0, 2.0}));
5859   EXPECT_TRUE(resource.HasEnoughResource(4, 8, 4.0));
5860   resource.AddCopy({4, 8, 4.0, alternate_mem_space, 2});
5861   EXPECT_EQ(
5862       resource.GetCurrentResources(),
5863       std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 2.0}));
5864   EXPECT_FALSE(resource.HasEnoughResource(3, 6, 4.0));
5865 }
5866 
5867 TEST_F(AsynchronousCopyResourceTest, Nested) {
5868   // time:      0 1 2 3 4
5869   // resource:  2 2 2 2 2
5870   // 1,3,2         +-+       OK
5871   // resource:  2 2 0 2 2
5872   // 0,4,4       +-----+     Violate
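  //
  // Although times 1-3 still hold 4.0 in total, the 0,4 copy would have to
  // consume times 1 and 2, and time 2 is the only slot the nested 1,3 copy can
  // occupy, so the request is rejected.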
5873   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
5874   AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0});
5875   EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0));
5876   resource.AddCopy({1, 3, 2.0, alternate_mem_space, 0});
5877   EXPECT_EQ(resource.GetCurrentResources(),
5878             std::vector<float>({2.0, 2.0, 0.0, 2.0, 2.0}));
5879   EXPECT_FALSE(resource.HasEnoughResource(0, 4, 4.0));
5880 }
5881 
5882 TEST_F(AsynchronousCopyResourceTest, Remove) {
5883   // time:      0 1 2 3 4
5884   // resource:  2 2 2 2 2
5885   // add:2,5,2       +---+   OK
5886   // resource:  2 2 2 0 2
5887   // add:-1,2,3+---+         OK
5888   // resource:  0 1 2 0 2
5889   // add:0,4,4   +-----+     OK
5890   // resource:  0 0 0 0 1
5891   // rem:0,4,4   +-----+
5892   // resource:  0 1 2 0 2
5893   // rem:2,5,2       +---+
5894   // resource:  0 1 2 2 2
5895   // rem:-1,2,3+---+
5896   // resource:  2 2 2 2 2
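  //
  // Copies can be removed in a different order than they were added; after
  // each removal the remaining copies account for exactly the remaining
  // resource usage, and removing all of them restores the initial resources.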
5897   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
5898   AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0});
5899   AsynchronousCopy copy1{2, 5, 2.0, alternate_mem_space, 0};
5900   AsynchronousCopy copy2{-1, 2, 3.0, alternate_mem_space, 1};
5901   AsynchronousCopy copy3{0, 4, 4.0, alternate_mem_space, 2};
5902   EXPECT_TRUE(resource.HasEnoughResource(2, 5, 2.0));
5903   resource.AddCopy(copy1);
5904   EXPECT_EQ(resource.GetCurrentResources(),
5905             std::vector<float>({2.0, 2.0, 2.0, 0.0, 2.0}));
5906   EXPECT_TRUE(resource.HasEnoughResource(-1, 2, 3.0));
5907   resource.AddCopy(copy2);
5908   EXPECT_EQ(resource.GetCurrentResources(),
5909             std::vector<float>({0.0, 1.0, 2.0, 0.0, 2.0}));
5910   EXPECT_TRUE(resource.HasEnoughResource(0, 4, 4.0));
5911   resource.AddCopy(copy3);
5912   EXPECT_EQ(resource.GetCurrentResources(),
5913             std::vector<float>({0.0, 0.0, 0.0, 0.0, 1.0}));
5914   resource.RemoveCopy(copy3);
5915   EXPECT_EQ(resource.GetCurrentResources(),
5916             std::vector<float>({0.0, 1.0, 2.0, 0.0, 2.0}));
5917   resource.RemoveCopy(copy1);
5918   EXPECT_EQ(resource.GetCurrentResources(),
5919             std::vector<float>({0.0, 1.0, 2.0, 2.0, 2.0}));
5920   resource.RemoveCopy(copy2);
5921   EXPECT_EQ(resource.GetCurrentResources(),
5922             std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0}));
5923 }
5924 
5925 TEST_F(AsynchronousCopyResourceTest, NestedRemove) {
5926   // time:      0 1 2 3 4
5927   // resource:  2 2 2 2 2
5928   // add:1,3,2     +-+       OK
5929   // resource:  2 2 0 2 2
5930   // add:0,4,4   +-----+     Violate
5931   // rem:1,3,2     +-+
5932   // resource:  2 2 2 2 2
5933   // add:0,4,4   +-----+     OK
5934   // resource:  2 0 0 2 2
5935   // add:1,3,2     +-+       Violate
5936   // rem:0,4,4   +-----+
5937   // resource:  2 2 2 2 2
5938   // add:1,3,2     +-+       OK
5939   // resource:  2 2 0 2 2
5940   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
5941   AsynchronousCopyResource resource({2.0, 2.0, 2.0, 2.0, 2.0});
5942   AsynchronousCopy copy1{1, 3, 2.0, alternate_mem_space, 0};
5943   AsynchronousCopy copy2{0, 4, 4.0, alternate_mem_space, 1};
5944   EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0));
5945   resource.AddCopy(copy1);
5946   EXPECT_EQ(resource.GetCurrentResources(),
5947             std::vector<float>({2.0, 2.0, 0.0, 2.0, 2.0}));
5948   EXPECT_FALSE(resource.HasEnoughResource(0, 4, 4.0));
5949   resource.RemoveCopy(copy1);
5950   auto current_resources = resource.GetCurrentResources();
5951   EXPECT_EQ(resource.GetCurrentResources(),
5952             std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0}));
5953   EXPECT_TRUE(resource.HasEnoughResource(0, 4, 4.0));
5954   resource.AddCopy(copy2);
5955   EXPECT_EQ(resource.GetCurrentResources(),
5956             std::vector<float>({2.0, 0.0, 0.0, 2.0, 2.0}));
5957   EXPECT_FALSE(resource.HasEnoughResource(1, 3, 2.0));
5958   resource.RemoveCopy(copy2);
5959   EXPECT_EQ(resource.GetCurrentResources(),
5960             std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0}));
5961   EXPECT_TRUE(resource.HasEnoughResource(1, 3, 2.0));
5962 }
5963 
5964 TEST_F(AsynchronousCopyResourceTest, PropagateRemove) {
5965   // time:      0 1 2 3 4 5 6 7 8 9
5966   // resource:  2 2 2 2 2 2 2 2 2 2
5967   // add:6,10,2              +-----+   OK
5968   // resource:  2 2 2 2 2 2 2 0 2 2
5969   // add:5,9,2             +-----+     OK
5970   // resource:  2 2 2 2 2 2 0 0 2 2
5971   // add:4,8,2           +-----+       OK
5972   // resource:  2 2 2 2 2 0 0 0 2 2
5973   // add:3,7,2         +-----+         OK
5974   // resource:  2 2 2 2 0 0 0 0 2 2
5975   // add:2,6,2       +-----+           OK
5976   // resource:  2 2 2 0 0 0 0 0 2 2
5977   // add:1,5,2     +-----+             OK
5978   // resource:  2 2 0 0 0 0 0 0 2 2
5979   // add:0,4,3   +-----+               OK
5980   // resource:  2 0 0 0 0 0 0 0 1 2
5981   // add:0,5,3   +-------+             OK
5982   // resource:  2 0 0 0 0 0 0 0 0 0
5983   // rem:0,5,3   +-------+
5984   // resource:  2 0 0 0 0 0 0 0 1 2
5985   // rem:0,4,3   +-----+
5986   // resource:  2 2 0 0 0 0 0 0 2 2
5987   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
5988   AsynchronousCopyResource resource(
5989       {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
5990   EXPECT_TRUE(resource.HasEnoughResource(6, 10, 2.0));
5991   resource.AddCopy({6, 10, 2.0, alternate_mem_space, 0});
5992   EXPECT_EQ(
5993       resource.GetCurrentResources(),
5994       std::vector<float>({2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 0.0, 2.0, 2.0}));
5995   EXPECT_TRUE(resource.HasEnoughResource(5, 9, 2.0));
5996   resource.AddCopy({5, 9, 2.0, alternate_mem_space, 1});
5997   EXPECT_TRUE(resource.HasEnoughResource(4, 8, 2.0));
5998   resource.AddCopy({4, 8, 2.0, alternate_mem_space, 2});
5999   EXPECT_TRUE(resource.HasEnoughResource(3, 7, 2.0));
6000   resource.AddCopy({3, 7, 2.0, alternate_mem_space, 3});
6001   EXPECT_TRUE(resource.HasEnoughResource(2, 6, 2.0));
6002   resource.AddCopy({2, 6, 2.0, alternate_mem_space, 4});
6003   EXPECT_TRUE(resource.HasEnoughResource(1, 5, 2.0));
6004   resource.AddCopy({1, 5, 2.0, alternate_mem_space, 5});
6005   EXPECT_EQ(
6006       resource.GetCurrentResources(),
6007       std::vector<float>({2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0}));
6008   AsynchronousCopy copy1{0, 4, 3.0, alternate_mem_space, 6};
6009   EXPECT_TRUE(resource.HasEnoughResource(0, 4, 3.0));
6010   resource.AddCopy(copy1);
6011   EXPECT_EQ(
6012       resource.GetCurrentResources(),
6013       std::vector<float>({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0}));
6014   EXPECT_TRUE(resource.HasEnoughResource(0, 5, 3.0));
6015   AsynchronousCopy copy2{0, 5, 3.0, alternate_mem_space, 7};
6016   resource.AddCopy(copy2);
6017   EXPECT_EQ(
6018       resource.GetCurrentResources(),
6019       std::vector<float>({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
6020   resource.RemoveCopy(copy2);
6021   EXPECT_EQ(
6022       resource.GetCurrentResources(),
6023       std::vector<float>({2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0}));
6024   resource.RemoveCopy(copy1);
6025   EXPECT_EQ(
6026       resource.GetCurrentResources(),
6027       std::vector<float>({2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0}));
6028 }
6029 
6030 TEST_F(AsynchronousCopyResourceTest, StartAtZeroAndRemove) {
6031   // time:      0 1 2 3 4
6032   // resource:  0 0 1 1 2
6033   // add:0,4,2   +-----+     OK
6034   // resource:  0 0 0 0 2
6035   // rem:0,4,2   +-----+
6036   // resource:  0 0 1 1 2
6037   // add:0,4,2   +-----+     OK
6038   // resource:  0 0 0 0 2
6039   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
6040   AsynchronousCopyResource resource({0.0, 0.0, 1.0, 1.0, 2.0});
6041   AsynchronousCopy copy1{0, 4, 2.0, alternate_mem_space, 0};
6042   EXPECT_TRUE(resource.HasEnoughResource(0, 4, 2.0));
6043   resource.AddCopy(copy1);
6044   EXPECT_EQ(resource.GetCurrentResources(),
6045             std::vector<float>({0.0, 0.0, 0.0, 0.0, 2.0}));
6046   resource.RemoveCopy(copy1);
6047   EXPECT_EQ(resource.GetCurrentResources(),
6048             std::vector<float>({0.0, 0.0, 1.0, 1.0, 2.0}));
6049   resource.AddCopy(copy1);
6050   EXPECT_EQ(resource.GetCurrentResources(),
6051             std::vector<float>({0.0, 0.0, 0.0, 0.0, 2.0}));
6052 }
6053 
6054 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) {
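  // The dot rhs (parameter 1) is a small, read-only weight-like operand, so we
  // expect it to be cross-program prefetched into the alternate memory.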
6055   HloComputation::Builder builder(TestName());
6056 
6057   constexpr int kBatch = 8;
6058   constexpr int kFeature = 8;
6059   constexpr int kOutput = 2;
6060 
6061   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6062   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6063   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6064   HloInstruction* lhs = builder.AddInstruction(
6065       HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
6066   HloInstruction* rhs = builder.AddInstruction(
6067       HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
6068 
6069   DotDimensionNumbers dot_dnums;
6070   dot_dnums.add_lhs_contracting_dimensions(1);
6071   dot_dnums.add_rhs_contracting_dimensions(0);
6072   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6073       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6074 
6075   auto module = CreateNewVerifiedModule();
6076   HloComputation* computation = module->AddEntryComputation(builder.Build());
6077 
6078   HloSchedule schedule(module.get());
6079   schedule.set_sequence(computation, {lhs, rhs, dot});
6080   TF_CHECK_OK(module->set_schedule(schedule));
6081 
6082   AssignMemorySpace(module.get());
6083 
6084   auto cross_program_prefetches = module->CrossProgramPrefetches();
6085   EXPECT_EQ(cross_program_prefetches.size(), 1);
6086   if (!cross_program_prefetches.empty()) {
6087     EXPECT_EQ(cross_program_prefetches[0].first, 1);
6088     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
6089   }
6090 }
6091 
6092 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleTest) {
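  // Same as above, but the operands come from a tuple parameter; the
  // cross-program prefetch should target the rhs at tuple index {1}.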
6093   HloComputation::Builder builder(TestName());
6094 
6095   constexpr int kBatch = 8;
6096   constexpr int kFeature = 8;
6097   constexpr int kOutput = 2;
6098 
6099   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6100   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6101   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6102   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
6103   HloInstruction* param = builder.AddInstruction(
6104       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
6105 
6106   auto lhs = builder.AddInstruction(
6107       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
6108   auto rhs = builder.AddInstruction(
6109       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
6110 
6111   DotDimensionNumbers dot_dnums;
6112   dot_dnums.add_lhs_contracting_dimensions(1);
6113   dot_dnums.add_rhs_contracting_dimensions(0);
6114   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6115       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6116 
6117   auto module = CreateNewVerifiedModule();
6118   HloComputation* computation = module->AddEntryComputation(builder.Build());
6119 
6120   HloSchedule schedule(module.get());
6121   schedule.set_sequence(computation, {param, lhs, rhs, dot});
6122   TF_CHECK_OK(module->set_schedule(schedule));
6123 
6124   AssignMemorySpace(module.get());
6125 
6126   auto cross_program_prefetches = module->CrossProgramPrefetches();
6127   EXPECT_EQ(cross_program_prefetches.size(), 1);
6128   if (!cross_program_prefetches.empty()) {
6129     EXPECT_EQ(cross_program_prefetches[0].first, 0);
6130     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
6131   }
6132 }
6133 
6134 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) {
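  // A bitcast between the parameter and the dot does not prevent the parameter
  // itself (parameter 1) from being cross-program prefetched.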
6135   HloComputation::Builder builder(TestName());
6136 
6137   constexpr int kBatch = 8;
6138   constexpr int kFeature = 8;
6139   constexpr int kOutput = 2;
6140 
6141   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6142   auto rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature});
6143   auto bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6144   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6145   HloInstruction* lhs = builder.AddInstruction(
6146       HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
6147   HloInstruction* rhs = builder.AddInstruction(
6148       HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
6149 
6150   auto bitcast =
6151       builder.AddInstruction(HloInstruction::CreateBitcast(bitcast_shape, rhs));
6152 
6153   DotDimensionNumbers dot_dnums;
6154   dot_dnums.add_lhs_contracting_dimensions(1);
6155   dot_dnums.add_rhs_contracting_dimensions(0);
6156   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6157       result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2)));
6158 
6159   auto module = CreateNewVerifiedModule();
6160   HloComputation* computation = module->AddEntryComputation(builder.Build());
6161 
6162   HloSchedule schedule(module.get());
6163   schedule.set_sequence(computation, {lhs, rhs, bitcast, dot});
6164   TF_CHECK_OK(module->set_schedule(schedule));
6165 
6166   AssignMemorySpace(module.get());
6167 
6168   auto cross_program_prefetches = module->CrossProgramPrefetches();
6169   EXPECT_EQ(cross_program_prefetches.size(), 1);
6170   if (!cross_program_prefetches.empty()) {
6171     EXPECT_EQ(cross_program_prefetches[0].first, 1);
6172     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
6173   }
6174 }
6175 
6176 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTupleTest) {
6177   HloComputation::Builder builder(TestName());
6178 
6179   constexpr int kBatch = 8;
6180   constexpr int kFeature = 8;
6181   constexpr int kOutput = 2;
6182 
6183   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6184   auto rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature});
6185   auto bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6186   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6187   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
6188   HloInstruction* param = builder.AddInstruction(
6189       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
6190 
6191   auto lhs = builder.AddInstruction(
6192       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
6193   auto rhs = builder.AddInstruction(
6194       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
6195 
6196   auto bitcast =
6197       builder.AddInstruction(HloInstruction::CreateBitcast(bitcast_shape, rhs));
6198 
6199   DotDimensionNumbers dot_dnums;
6200   dot_dnums.add_lhs_contracting_dimensions(1);
6201   dot_dnums.add_rhs_contracting_dimensions(0);
6202   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6203       result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2)));
6204 
6205   auto module = CreateNewVerifiedModule();
6206   HloComputation* computation = module->AddEntryComputation(builder.Build());
6207 
6208   HloSchedule schedule(module.get());
6209   schedule.set_sequence(computation, {param, lhs, rhs, bitcast, dot});
6210   TF_CHECK_OK(module->set_schedule(schedule));
6211 
6212   AssignMemorySpace(module.get());
6213 
6214   auto cross_program_prefetches = module->CrossProgramPrefetches();
6215   EXPECT_EQ(cross_program_prefetches.size(), 1);
6216   if (!cross_program_prefetches.empty()) {
6217     EXPECT_EQ(cross_program_prefetches[0].first, 0);
6218     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
6219   }
6220 }
6221 
6222 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) {
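  // Operands that are nested more than one tuple level deep in the parameter
  // are not cross-program prefetched.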
6223   HloComputation::Builder builder(TestName());
6224 
6225   constexpr int kBatch = 8;
6226   constexpr int kFeature = 8;
6227   constexpr int kOutput = 2;
6228 
6229   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6230   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6231   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6232   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
6233   auto tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape});
6234   HloInstruction* param = builder.AddInstruction(
6235       HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0"));
6236 
6237   auto gte = builder.AddInstruction(
6238       HloInstruction::CreateGetTupleElement(tuple_shape, param, 0));
6239 
6240   auto lhs = builder.AddInstruction(
6241       HloInstruction::CreateGetTupleElement(lhs_shape, gte, 0));
6242   auto rhs = builder.AddInstruction(
6243       HloInstruction::CreateGetTupleElement(rhs_shape, gte, 1));
6244 
6245   DotDimensionNumbers dot_dnums;
6246   dot_dnums.add_lhs_contracting_dimensions(1);
6247   dot_dnums.add_rhs_contracting_dimensions(0);
6248   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6249       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6250 
6251   auto module = CreateNewVerifiedModule();
6252   HloComputation* computation = module->AddEntryComputation(builder.Build());
6253 
6254   HloSchedule schedule(module.get());
6255   schedule.set_sequence(computation, {param, gte, lhs, rhs, dot});
6256   TF_CHECK_OK(module->set_schedule(schedule));
6257 
6258   AssignMemorySpace(module.get());
6259 
6260   auto cross_program_prefetches = module->CrossProgramPrefetches();
6261   EXPECT_EQ(cross_program_prefetches.size(), 0);
6262 }
6263 
6264 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) {
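  // A parameter with no uses has nothing to prefetch for, so no cross-program
  // prefetch is made.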
6265   HloComputation::Builder builder(TestName());
6266 
6267   constexpr int kFeature = 8;
6268   constexpr int kOutput = 2;
6269 
6270   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6271   HloInstruction* param = builder.AddInstruction(
6272       HloInstruction::CreateParameter(0, rhs_shape, "p0"));
6273 
6274   auto module = CreateNewVerifiedModule();
6275   HloComputation* computation = module->AddEntryComputation(builder.Build());
6276 
6277   HloSchedule schedule(module.get());
6278   schedule.set_sequence(computation, {param});
6279   TF_CHECK_OK(module->set_schedule(schedule));
6280 
6281   AssignMemorySpace(module.get());
6282 
6283   auto cross_program_prefetches = module->CrossProgramPrefetches();
6284   EXPECT_EQ(cross_program_prefetches.size(), 0);
6285 }
6286 
6287 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) {
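  // With kOutput = 8, the dot operands are too large to fit in the alternate
  // memory, so no cross-program prefetch is made.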
6288   HloComputation::Builder builder(TestName());
6289 
6290   constexpr int kBatch = 8;
6291   constexpr int kFeature = 8;
6292   constexpr int kOutput = 8;
6293 
6294   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6295   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6296   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6297   HloInstruction* lhs = builder.AddInstruction(
6298       HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
6299   HloInstruction* rhs = builder.AddInstruction(
6300       HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
6301 
6302   DotDimensionNumbers dot_dnums;
6303   dot_dnums.add_lhs_contracting_dimensions(1);
6304   dot_dnums.add_rhs_contracting_dimensions(0);
6305   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6306       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6307 
6308   auto module = CreateNewVerifiedModule();
6309   HloComputation* computation = module->AddEntryComputation(builder.Build());
6310 
6311   HloSchedule schedule(module.get());
6312   schedule.set_sequence(computation, {lhs, rhs, dot});
6313   TF_CHECK_OK(module->set_schedule(schedule));
6314 
6315   AssignMemorySpace(module.get());
6316 
6317   auto cross_program_prefetches = module->CrossProgramPrefetches();
6318   EXPECT_EQ(cross_program_prefetches.size(), 0);
6319 }
6320 
6321 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTupleTest) {
6322   HloComputation::Builder builder(TestName());
6323 
6324   constexpr int kBatch = 8;
6325   constexpr int kFeature = 8;
6326   constexpr int kOutput = 8;
6327 
6328   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6329   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6330   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6331   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
6332   HloInstruction* param = builder.AddInstruction(
6333       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
6334 
6335   auto lhs = builder.AddInstruction(
6336       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
6337   auto rhs = builder.AddInstruction(
6338       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
6339 
6340   DotDimensionNumbers dot_dnums;
6341   dot_dnums.add_lhs_contracting_dimensions(1);
6342   dot_dnums.add_rhs_contracting_dimensions(0);
6343   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6344       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6345 
6346   auto module = CreateNewVerifiedModule();
6347   HloComputation* computation = module->AddEntryComputation(builder.Build());
6348 
6349   HloSchedule schedule(module.get());
6350   schedule.set_sequence(computation, {param, lhs, rhs, dot});
6351   TF_CHECK_OK(module->set_schedule(schedule));
6352 
6353   AssignMemorySpace(module.get());
6354 
6355   auto cross_program_prefetches = module->CrossProgramPrefetches();
6356   EXPECT_EQ(cross_program_prefetches.size(), 0);
6357 }
6358 
6359 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
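  // The fusion operands are constants rather than entry parameters, so there
  // is no cross-program prefetch candidate.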
6360   HloComputation::Builder builder(TestName());
6361 
6362   constexpr int kBatch = 2;
6363   constexpr int kFeature = 2;
6364   constexpr int kOutput = 2;
6365 
6366   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6367   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6368   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6369 
6370   auto module = CreateNewVerifiedModule();
6371   HloComputation::Builder fusion_builder("fusion");
6372   {
6373     HloInstruction* lhs = fusion_builder.AddInstruction(
6374         HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
6375     HloInstruction* rhs = fusion_builder.AddInstruction(
6376         HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
6377     DotDimensionNumbers dot_dnums;
6378     dot_dnums.add_lhs_contracting_dimensions(1);
6379     dot_dnums.add_rhs_contracting_dimensions(0);
6380     auto dot = fusion_builder.AddInstruction(HloInstruction::CreateDot(
6381         result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6382     (void)dot;
6383   }
6384   HloComputation* fusion_computation =
6385       module->AddEmbeddedComputation(fusion_builder.Build());
6386 
6387   auto activations = builder.AddInstruction(HloInstruction::CreateConstant(
6388       LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
6389   auto weights = builder.AddInstruction(HloInstruction::CreateConstant(
6390       LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
6391   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
6392       result_shape, HloInstruction::FusionKind::kCustom, {activations, weights},
6393       fusion_computation));
6394 
6395   HloComputation* computation = module->AddEntryComputation(builder.Build());
6396 
6397   HloSchedule schedule(module.get());
6398   schedule.set_sequence(computation, {activations, weights, fusion});
6399   TF_CHECK_OK(module->set_schedule(schedule));
6400 
6401   AssignMemorySpace(module.get());
6402 
6403   auto cross_program_prefetches = module->CrossProgramPrefetches();
6404   EXPECT_EQ(cross_program_prefetches.size(), 0);
6405 }
6406 
6407 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTupleTest) {
6408   HloComputation::Builder builder(TestName());
6409 
6410   constexpr int kBatch = 2;
6411   constexpr int kFeature = 2;
6412   constexpr int kOutput = 2;
6413 
6414   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6415   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
6416   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6417   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
6418 
6419   auto module = CreateNewVerifiedModule();
6420   HloComputation::Builder fusion_builder("fusion");
6421   {
6422     HloInstruction* param = fusion_builder.AddInstruction(
6423         HloInstruction::CreateParameter(0, tuple_shape, "p0"));
6424     auto lhs = fusion_builder.AddInstruction(
6425         HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
6426     auto rhs = fusion_builder.AddInstruction(
6427         HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
6428     DotDimensionNumbers dot_dnums;
6429     dot_dnums.add_lhs_contracting_dimensions(1);
6430     dot_dnums.add_rhs_contracting_dimensions(0);
6431     auto dot = fusion_builder.AddInstruction(HloInstruction::CreateDot(
6432         result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6433     (void)dot;
6434   }
6435   HloComputation* fusion_computation =
6436       module->AddEmbeddedComputation(fusion_builder.Build());
6437 
6438   auto activations = builder.AddInstruction(HloInstruction::CreateConstant(
6439       LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
6440   auto weights = builder.AddInstruction(HloInstruction::CreateConstant(
6441       LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
6442   HloInstruction* tuple = builder.AddInstruction(
6443       HloInstruction::CreateTuple({activations, weights}));
6444   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
6445       result_shape, HloInstruction::FusionKind::kCustom, {tuple},
6446       fusion_computation));
6447 
6448   HloComputation* computation = module->AddEntryComputation(builder.Build());
6449 
6450   HloSchedule schedule(module.get());
6451   schedule.set_sequence(computation, {activations, weights, tuple, fusion});
6452   TF_CHECK_OK(module->set_schedule(schedule));
6453 
6454   AssignMemorySpace(module.get());
6455 
6456   auto cross_program_prefetches = module->CrossProgramPrefetches();
6457   EXPECT_EQ(cross_program_prefetches.size(), 0);
6458 }
6459 
6460 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) {
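  // The rhs parameter's layout already places it in the alternate memory
  // space, so there is nothing to cross-program prefetch.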
6461   HloComputation::Builder builder(TestName());
6462 
6463   constexpr int kBatch = 8;
6464   constexpr int kFeature = 8;
6465   constexpr int kOutput = 2;
6466 
6467   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6468   auto rhs_shape = ShapeUtil::MakeShapeWithLayout(
6469       F32, {kFeature, kOutput},
6470       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
6471       /*element_size_in_bits=*/0, kAlternateMemorySpace);
6472   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6473   HloInstruction* lhs = builder.AddInstruction(
6474       HloInstruction::CreateParameter(0, lhs_shape, "lhs"));
6475   HloInstruction* rhs = builder.AddInstruction(
6476       HloInstruction::CreateParameter(1, rhs_shape, "rhs"));
6477 
6478   DotDimensionNumbers dot_dnums;
6479   dot_dnums.add_lhs_contracting_dimensions(1);
6480   dot_dnums.add_rhs_contracting_dimensions(0);
6481   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6482       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6483 
6484   auto module = CreateNewVerifiedModule();
6485   HloComputation* computation = module->AddEntryComputation(builder.Build());
6486 
6487   HloSchedule schedule(module.get());
6488   schedule.set_sequence(computation, {lhs, rhs, dot});
6489   TF_CHECK_OK(module->set_schedule(schedule));
6490 
6491   Options options;
6492   options.max_size_in_bytes = 128;
6493   options.alignment_in_bytes = 8;
6494   options.verify = true;
6495   options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
6496     return true;
6497   };
6498   std::unique_ptr<PresetAssignments> preset_assignments = AssignMemorySpace(
6499       module.get(),
6500       /*max_outstanding_async_copies=*/-1,
6501       /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2, options);
6502 
6503   auto cross_program_prefetches = module->CrossProgramPrefetches();
6504   EXPECT_EQ(cross_program_prefetches.size(), 0);
6505 }
6506 
6507 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTupleTest) {
6508   HloComputation::Builder builder(TestName());
6509 
6510   constexpr int kBatch = 8;
6511   constexpr int kFeature = 8;
6512   constexpr int kOutput = 2;
6513 
6514   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
6515   auto rhs_shape = ShapeUtil::MakeShapeWithLayout(
6516       F32, {kFeature, kOutput},
6517       /*minor_to_major=*/{1, 0}, /*dim_level_types=*/{}, /*tiles=*/{},
6518       /*element_size_in_bits=*/0, kAlternateMemorySpace);
6519   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
6520   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
6521   HloInstruction* param = builder.AddInstruction(
6522       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
6523 
6524   auto lhs = builder.AddInstruction(
6525       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
6526   auto rhs = builder.AddInstruction(
6527       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
6528 
6529   DotDimensionNumbers dot_dnums;
6530   dot_dnums.add_lhs_contracting_dimensions(1);
6531   dot_dnums.add_rhs_contracting_dimensions(0);
6532   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
6533       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
6534 
6535   auto module = CreateNewVerifiedModule();
6536   HloComputation* computation = module->AddEntryComputation(builder.Build());
6537 
6538   HloSchedule schedule(module.get());
6539   schedule.set_sequence(computation, {param, lhs, rhs, dot});
6540   TF_CHECK_OK(module->set_schedule(schedule));
6541 
6542   Options options;
6543   options.max_size_in_bytes = 128;
6544   options.alignment_in_bytes = 8;
6545   options.verify = true;
6546   options.is_allowed_in_alternate_mem_fn = [](const HloValue& value) {
6547     return true;
6548   };
6549   std::unique_ptr<PresetAssignments> preset_assignments = AssignMemorySpace(
6550       module.get(),
6551       /*max_outstanding_async_copies=*/-1,
6552       /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2, options);
6553 
6554   auto cross_program_prefetches = module->CrossProgramPrefetches();
6555   EXPECT_EQ(cross_program_prefetches.size(), 0);
6556 }
6557 
6558 TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDupMayAlias) {
6559   absl::string_view hlo_string = R"(
6560   HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {}: (0, {}, may-alias) }
6561     ENTRY CrossProgramPrefetch {
6562       c0 = s32[1,2] constant({{77, 77}})
6563       c1 = s32[] constant(0)
6564       p0 = s32[2,2] parameter(0)
6565       ROOT dup = s32[2,2] dynamic-update-slice(s32[2,2] p0, s32[1,2] c0, s32[] c1, s32[] c1)
6566     }
6567   )";
6568   TF_ASSERT_OK_AND_ASSIGN(auto module,
6569                           ParseAndReturnVerifiedModule(hlo_string));
6570   auto preset_assignments = AssignMemorySpace(
6571       module.get(), /*max_outstanding_async_copies=*/-1,
6572       /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6573 
6574   auto cross_program_prefetches = module->CrossProgramPrefetches();
6575   EXPECT_EQ(cross_program_prefetches.size(), 1);
6576 }
6577 
6578 TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDup) {
6579   absl::string_view hlo_string = R"(
6580   HloModule cross_program_prefetch, is_scheduled=true
6581     ENTRY CrossProgramPrefetch {
6582       c0 = s32[1,2] constant({{77, 77}})
6583       c1 = s32[] constant(0)
6584       p0 = s32[2,2] parameter(0)
6585       ROOT dup = s32[2,2] dynamic-update-slice(s32[2,2] p0, s32[1,2] c0, s32[] c1, s32[] c1)
6586     }
6587   )";
6588   TF_ASSERT_OK_AND_ASSIGN(auto module,
6589                           ParseAndReturnVerifiedModule(hlo_string));
6590   auto preset_assignments = AssignMemorySpace(
6591       module.get(), /*max_outstanding_async_copies=*/-1,
6592       /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6593 
6594   auto cross_program_prefetches = module->CrossProgramPrefetches();
6595   EXPECT_EQ(cross_program_prefetches.size(), 1);
6596 }
6597 
6598 TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDupDot) {
6599   // Cross program prefetch since the parameter and the root don't alias.
6600   absl::string_view hlo_string = R"(
6601   HloModule cross_program_prefetch, is_scheduled=true
6602     ENTRY CrossProgramPrefetch {
6603       c0 = s32[1,2] constant({{77, 77}})
6604       c1 = s32[] constant(0)
6605       p0 = s32[2,2] parameter(0)
6606       p1 = s32[2,2] parameter(1)
6607       dup = s32[2,2] dynamic-update-slice(s32[2,2] p0, s32[1,2] c0, s32[] c1, s32[] c1)
6608       ROOT dot = s32[2,2] dot(p1, dup), lhs_contracting_dims={0}, rhs_contracting_dims={0}
6609     }
6610   )";
6611   TF_ASSERT_OK_AND_ASSIGN(auto module,
6612                           ParseAndReturnVerifiedModule(hlo_string));
6613   auto preset_assignments = AssignMemorySpace(
6614       module.get(), /*max_outstanding_async_copies=*/-1,
6615       /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6616 
6617   auto cross_program_prefetches = module->CrossProgramPrefetches();
6618   EXPECT_EQ(cross_program_prefetches.size(), 1);
6619 }
6620 
6621 TEST_P(MemorySpaceAssignmentTest, CrossProgramRootDotMayAlias) {
6622   absl::string_view hlo_string = R"(
6623   HloModule cross_program_prefetch, is_scheduled=true, input_output_alias={ {}: (0, {}, may-alias) }
6624     ENTRY CrossProgramPrefetch {
6625       p0 = s32[2,2] parameter(0)
6626       p1 = s32[2,2] parameter(1)
6627       ROOT dot = s32[2,2] dot(p1, p0), lhs_contracting_dims={0}, rhs_contracting_dims={0}
6628     }
6629   )";
6630   TF_ASSERT_OK_AND_ASSIGN(auto module,
6631                           ParseAndReturnVerifiedModule(hlo_string));
6632   auto preset_assignments = AssignMemorySpace(
6633       module.get(), /*max_outstanding_async_copies=*/-1,
6634       /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6635 
6636   auto cross_program_prefetches = module->CrossProgramPrefetches();
6637   EXPECT_EQ(cross_program_prefetches.size(), 1);
6638 }
6639 
6640 TEST_P(MemorySpaceAssignmentTest, CrossProgramRootParameter) {
6641   absl::string_view hlo_string = R"(
6642   HloModule cross_program_prefetch, is_scheduled=true
6643     ENTRY CrossProgramPrefetch {
6644       p0 = s32[2,2] parameter(0)
6645       ROOT bitcast = u32[2,2] bitcast(p0)
6646     }
6647   )";
6648   TF_ASSERT_OK_AND_ASSIGN(auto module,
6649                           ParseAndReturnVerifiedModule(hlo_string));
6650   auto preset_assignments = AssignMemorySpace(
6651       module.get(), /*max_outstanding_async_copies=*/-1,
6652       /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6653 
6654   auto cross_program_prefetches = module->CrossProgramPrefetches();
6655   EXPECT_EQ(cross_program_prefetches.size(), 1);
6656 }
6657 
6658 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
6659   // This test checks that the cross-program-prefetched buffer is freed after
6660   // its last use and that there is an end-of-program prefetch.
6661   absl::string_view hlo_string = R"(
6662   HloModule cross_program_prefetch, is_scheduled=true
6663 
6664   ENTRY CrossProgramPrefetch {
6665     p0 = f32[8,8]{1,0} parameter(0)
6666     p1 = f32[8,2]{1,0} parameter(1)
6667     dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
6668     negate.1 = f32[8,2]{1,0} negate(dot)
6669     negate.2 = f32[8,2]{1,0} negate(negate.1)
6670     negate.3 = f32[8,2]{1,0} negate(negate.2)
6671     negate.4 = f32[8,2]{1,0} negate(negate.3)
6672     negate.5 = f32[8,2]{1,0} negate(negate.4)
6673     negate.6 = f32[8,2]{1,0} negate(negate.5)
6674     negate.7 = f32[8,2]{1,0} negate(negate.6)
6675     negate.8 = f32[8,2]{1,0} negate(negate.7)
6676     ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
6677   }
6678   )";
6679   TF_ASSERT_OK_AND_ASSIGN(auto module,
6680                           ParseAndReturnVerifiedModule(hlo_string));
6681   auto preset_assignments = AssignMemorySpace(
6682       module.get(), /*max_outstanding_async_copies=*/-1,
6683       /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6684 
6685   auto cross_program_prefetches = module->CrossProgramPrefetches();
6686   EXPECT_EQ(cross_program_prefetches.size(), 1);
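  // Each cross-program prefetch is recorded as a (parameter number, ShapeIndex)
  // pair; here the prefetched buffer is parameter 1 at the top-level index.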
6687   if (!cross_program_prefetches.empty()) {
6688     EXPECT_EQ(cross_program_prefetches[0].first, 1);
6689     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
6690   }
6691 
6692   TF_ASSERT_OK_AND_ASSIGN(
6693       std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
6694       HloDataflowAnalysis::Run(*module));
6695   const HloValue& cross_program_prefetched_value =
6696       dataflow_analysis->GetValueDefinedAt(
6697           module->entry_computation()->parameter_instruction(1), {});
6698   // Expect that there are two prefetches that use this value: one is the
6699   // cross-program prefetch and the other is the end-of-program prefetch.
6700   auto is_cross_program_prefetch = [](const HloUse& use) {
6701     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6702            use.instruction->is_cross_program_prefetch();
6703   };
6704   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6705                              is_cross_program_prefetch),
6706             1);
6707   auto is_end_of_program_prefetch = [](const HloUse& use) {
6708     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6709            !use.instruction->is_cross_program_prefetch();
6710   };
6711   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6712                              is_end_of_program_prefetch),
6713             1);
6714   // Also verify that the copy-done for the end-of-program prefetch is the last
6715   // instruction in the schedule.
6716   const HloInstruction* last_instruction =
6717       module->schedule()
6718           .sequence(module->entry_computation())
6719           .instructions()[module->entry_computation()->instruction_count() - 1];
6720   EXPECT_THAT(last_instruction, op::CopyDone());
6721   EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
6722   // The cross-program prefetch uses offset 0 because that's the first
6723   // assignment. Since the cross-program prefetch buffer is freed, we also
6724   // expect some of the intermediate computations (one of the negate ops) to
6725   // get allocations at offset 0.
6726   bool has_zero_offset_allocations = false;
6727   for (auto pos_and_chunk : preset_assignments->chunks()) {
6728     if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate &&
6729         pos_and_chunk.second.offset == 0) {
6730       has_zero_offset_allocations = true;
6731     }
6732   }
6733   EXPECT_TRUE(has_zero_offset_allocations);
6734 }
6735 
6736 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleNoReuse) {
6737   // This test checks that the cross-program-prefetched buffer is freed after
6738   // its last use and that there is an end-of-program prefetch.
6739   absl::string_view hlo_string = R"(
6740   HloModule cross_program_prefetch, is_scheduled=true
6741 
6742   ENTRY CrossProgramPrefetch {
6743     p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
6744     get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
6745     get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
6746     dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
6747     negate.1 = f32[8,2]{1,0} negate(dot)
6748     negate.2 = f32[8,2]{1,0} negate(negate.1)
6749     negate.3 = f32[8,2]{1,0} negate(negate.2)
6750     negate.4 = f32[8,2]{1,0} negate(negate.3)
6751     negate.5 = f32[8,2]{1,0} negate(negate.4)
6752     negate.6 = f32[8,2]{1,0} negate(negate.5)
6753     negate.7 = f32[8,2]{1,0} negate(negate.6)
6754     negate.8 = f32[8,2]{1,0} negate(negate.7)
6755     ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
6756   }
6757   )";
6758   TF_ASSERT_OK_AND_ASSIGN(auto module,
6759                           ParseAndReturnVerifiedModule(hlo_string));
6760   auto preset_assignments = AssignMemorySpace(
6761       module.get(), /*max_outstanding_async_copies=*/-1,
6762       /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6763 
6764   auto cross_program_prefetches = module->CrossProgramPrefetches();
6765   EXPECT_EQ(cross_program_prefetches.size(), 1);
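  // The prefetched buffer is expected to be the f32[8,2] element at tuple
  // index {1} of parameter 0.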
6766   if (!cross_program_prefetches.empty()) {
6767     EXPECT_EQ(cross_program_prefetches[0].first, 0);
6768     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
6769   }
6770 
6771   TF_ASSERT_OK_AND_ASSIGN(
6772       std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
6773       HloDataflowAnalysis::Run(*module));
6774   const HloValue& cross_program_prefetched_value =
6775       dataflow_analysis->GetValueDefinedAt(
6776           module->entry_computation()->parameter_instruction(0), {1});
6777   // Expect that there are two prefetches that use this value: one is the
6778   // cross-program prefetch and the other is the end-of-program prefetch.
6779   auto is_cross_program_prefetch = [](const HloUse& use) {
6780     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6781            use.instruction->is_cross_program_prefetch();
6782   };
6783   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6784                              is_cross_program_prefetch),
6785             1);
6786   auto is_end_of_program_prefetch = [](const HloUse& use) {
6787     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6788            !use.instruction->is_cross_program_prefetch();
6789   };
6790   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6791                              is_end_of_program_prefetch),
6792             1);
6793   // Also verify that the copy-done for the end-of-program prefetch is the last
6794   // instruction in the schedule.
6795   const HloInstruction* last_instruction =
6796       module->schedule()
6797           .sequence(module->entry_computation())
6798           .instructions()[module->entry_computation()->instruction_count() - 1];
6799   EXPECT_THAT(last_instruction, op::CopyDone());
6800   EXPECT_NE(last_instruction, module->entry_computation()->root_instruction());
6801   // The cross-program prefetch uses offset 0 because that's the first
6802   // assignment. Since the cross-program prefetch buffer is freed, we also
6803   // expect some of the intermediate computations (one of the negate ops) to
6804   // get allocations at offset 0.
6805   bool has_zero_offset_allocations = false;
6806   for (auto pos_and_chunk : preset_assignments->chunks()) {
6807     if (pos_and_chunk.first.instruction->opcode() == HloOpcode::kNegate &&
6808         pos_and_chunk.second.offset == 0) {
6809       has_zero_offset_allocations = true;
6810     }
6811   }
6812   EXPECT_TRUE(has_zero_offset_allocations);
6813 }
6814 
6815 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) {
6816   // This tests the scenario where the cross-program-prefetched buffer is used
6817   // again close to the end of the computation. In this case, it is better not
6818   // to free the buffer.
6819   absl::string_view hlo_string = R"(
6820   HloModule cross_program_prefetch, is_scheduled=true
6821 
6822   ENTRY CrossProgramPrefetch {
6823     p0 = f32[8,8]{1,0} parameter(0)
6824     p1 = f32[8,2]{1,0} parameter(1)
6825     dot = f32[8,2]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
6826     negate.1 = f32[8,2]{1,0} negate(dot)
6827     negate.2 = f32[8,2]{1,0} negate(negate.1)
6828     negate.3 = f32[8,2]{1,0} negate(negate.2)
6829     negate.4 = f32[8,2]{1,0} negate(negate.3)
6830     negate.5 = f32[8,2]{1,0} negate(negate.4)
6831     negate.6 = f32[8,2]{1,0} negate(negate.5)
6832     negate.7 = f32[8,2]{1,0} negate(negate.6)
6833     negate.8 = f32[8,2]{1,0} negate(negate.7)
6834     ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
6835   }
6836   )";
6837   TF_ASSERT_OK_AND_ASSIGN(auto module,
6838                           ParseAndReturnVerifiedModule(hlo_string));
6839 
6840   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
6841                     /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6842 
6843   auto cross_program_prefetches = module->CrossProgramPrefetches();
6844   EXPECT_EQ(cross_program_prefetches.size(), 1);
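  // The prefetched buffer is expected to be parameter 1 (p1), which is reused
  // by dot.2 near the end of the program.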
6845   if (!cross_program_prefetches.empty()) {
6846     EXPECT_EQ(cross_program_prefetches[0].first, 1);
6847     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({}));
6848   }
6849 
6850   TF_ASSERT_OK_AND_ASSIGN(
6851       std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
6852       HloDataflowAnalysis::Run(*module));
6853   const HloValue& cross_program_prefetched_value =
6854       dataflow_analysis->GetValueDefinedAt(
6855           module->entry_computation()->parameter_instruction(1), {});
6856   // Expect that there is one prefetch that uses this value: the cross-program
6857   // prefetch. There shouldn't be an end-of-program prefetch.
6858   auto is_cross_program_prefetch = [](const HloUse& use) {
6859     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6860            use.instruction->is_cross_program_prefetch();
6861   };
6862   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6863                              is_cross_program_prefetch),
6864             1);
6865   auto is_end_of_program_prefetch = [](const HloUse& use) {
6866     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6867            !use.instruction->is_cross_program_prefetch();
6868   };
6869   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6870                              is_end_of_program_prefetch),
6871             0);
6872 }
6873 
6874 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTupleReuse) {
6875   // This tests the scenario where the cross-program-prefetched buffer is used
6876   // again close to the end of the computation. In this case, it is better not
6877   // to free the buffer.
6878   absl::string_view hlo_string = R"(
6879   HloModule cross_program_prefetch, is_scheduled=true
6880 
6881   ENTRY CrossProgramPrefetch {
6882     p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
6883     get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
6884     get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
6885     dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
6886     negate.1 = f32[8,2]{1,0} negate(dot)
6887     negate.2 = f32[8,2]{1,0} negate(negate.1)
6888     negate.3 = f32[8,2]{1,0} negate(negate.2)
6889     negate.4 = f32[8,2]{1,0} negate(negate.3)
6890     negate.5 = f32[8,2]{1,0} negate(negate.4)
6891     negate.6 = f32[8,2]{1,0} negate(negate.5)
6892     negate.7 = f32[8,2]{1,0} negate(negate.6)
6893     negate.8 = f32[8,2]{1,0} negate(negate.7)
6894     ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
6895   }
6896   )";
6897   TF_ASSERT_OK_AND_ASSIGN(auto module,
6898                           ParseAndReturnVerifiedModule(hlo_string));
6899 
6900   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
6901                     /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
6902 
6903   auto cross_program_prefetches = module->CrossProgramPrefetches();
6904   EXPECT_EQ(cross_program_prefetches.size(), 1);
6905   if (!cross_program_prefetches.empty()) {
6906     EXPECT_EQ(cross_program_prefetches[0].first, 0);
6907     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
6908   }
6909 
6910   TF_ASSERT_OK_AND_ASSIGN(
6911       std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
6912       HloDataflowAnalysis::Run(*module));
6913   const HloValue& cross_program_prefetched_value =
6914       dataflow_analysis->GetValueDefinedAt(
6915           module->entry_computation()->parameter_instruction(0), {1});
6916   // Expect that there is one prefetch that uses this value: the cross-program
6917   // prefetch. There shouldn't be an end-of-program prefetch.
6918   auto is_cross_program_prefetch = [](const HloUse& use) {
6919     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6920            use.instruction->is_cross_program_prefetch();
6921   };
6922   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6923                              is_cross_program_prefetch),
6924             1);
6925   auto is_end_of_program_prefetch = [](const HloUse& use) {
6926     return use.instruction->opcode() == HloOpcode::kCopyStart &&
6927            !use.instruction->is_cross_program_prefetch();
6928   };
6929   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.GetUses(),
6930                              is_end_of_program_prefetch),
6931             0);
6932 }
6933 
6934 using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
6935 
6936 TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
6937   absl::string_view hlo_string = R"(
6938   HloModule bug, is_scheduled=true
6939 
6940   ENTRY Entry {
6941     param0 = f32[2,4] parameter(0)
6942     a = f32[2,4] negate(param0)
6943     b = f32[2,4] negate(a)
6944     c = f32[2,4] negate(b)
6945     d = f32[2,4] negate(c)
6946     e = f32[2,4] negate(d)
6947     f = f32[2,4] negate(e)
6948     g = f32[2,4] negate(f)
6949     h = f32[2,4] negate(g)
6950     i = f32[2,4] negate(h)
6951     j = f32[2,4] negate(i)
6952     k = f32[2,4] negate(j)
6953     l = f32[2,4] negate(k)
6954     m = f32[2,4] negate(l)
6955     n = f32[2,4] negate(m)
6956     o = f32[2,4] negate(n)
6957     p = f32[2,4] negate(o)
6958     q = f32[2,4] negate(p)
6959     r = f32[2,4] negate(q)
6960     s = f32[2,4] negate(r)
6961     t = f32[2,4] negate(s)
6962     u = f32[2,4] negate(t)
6963     ROOT v = f32[2,4] add(u, param0)
6964   }
6965   )";
6966   TF_ASSERT_OK_AND_ASSIGN(auto module,
6967                           ParseAndReturnVerifiedModule(hlo_string));
6968 
6969   HloCostAnalysis hlo_cost_analysis(ShapeSize);
6970   Options options;
6971   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
6972                           FakeMemorySpaceAssignmentCostAnalysis::Create(
6973                               hlo_cost_analysis, *module, options));
6974   CostAnalysisPrefetchIntervalPicker interval_picker(
6975       *cost_analysis,
6976       /*min_overlap_to_async_copy_ratio=*/1.0,
6977       /*preferred_overlap_to_async_copy_ratio=*/2.0,
6978       /*max_overlap_to_mem_size_async_copy_ratio=*/4.0,
6979       /*mem_size_bytes=*/32);
6980 
6981   HloInstruction* root = module->entry_computation()->root_instruction();
6982   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
6983   interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22);
6984 
6985   // Expect that the first interval is (15, 22), which has an elapsed time of
6986   // 6.0, twice the async copy elapsed time (3.0), matching the preferred
6987   // overlap ratio of 2.0. Then we expect intervals to be visited in alternating
6988   // increasing and decreasing order until the min and max async copy overlap
6989   // ratios are hit, which are the intervals (18, 22) and (9, 22), respectively.
6990   LOG(INFO) << interval_picker.ToDebugString();
6991   EXPECT_EQ(interval_picker.Next(), 15);
6992   LOG(INFO) << interval_picker.ToDebugString();
6993   EXPECT_EQ(interval_picker.Next(), 16);
6994   LOG(INFO) << interval_picker.ToDebugString();
6995   EXPECT_EQ(interval_picker.Next(), 14);
6996   LOG(INFO) << interval_picker.ToDebugString();
6997   EXPECT_EQ(interval_picker.Next(), 17);
6998   LOG(INFO) << interval_picker.ToDebugString();
6999   EXPECT_EQ(interval_picker.Next(), 13);
7000   LOG(INFO) << interval_picker.ToDebugString();
7001   EXPECT_EQ(interval_picker.Next(), 18);  // Min async overlap ratio reached.
7002   LOG(INFO) << interval_picker.ToDebugString();
7003   EXPECT_EQ(interval_picker.Next(), 12);
7004   LOG(INFO) << interval_picker.ToDebugString();
7005   EXPECT_EQ(interval_picker.Next(), 11);
7006   LOG(INFO) << interval_picker.ToDebugString();
7007   EXPECT_EQ(interval_picker.Next(), 10);
7008   LOG(INFO) << interval_picker.ToDebugString();
7009   EXPECT_EQ(interval_picker.Next(), 9);  // Max async overlap ratio reached.
7010   LOG(INFO) << interval_picker.ToDebugString();
7011   EXPECT_TRUE(interval_picker.Done());
7012 
7013   // Expect that if the time between start_time and end_time is too short, there
7014   // won't be any available intervals.
7015   interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22);
7016   LOG(INFO) << interval_picker.ToDebugString();
7017   EXPECT_TRUE(interval_picker.Done());
7018 }
7019 
7020 TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) {
7021   absl::string_view hlo_string = R"(
7022   HloModule bug, is_scheduled=true
7023 
7024   while_condition {
7025     param1 = (f32[2,4]) parameter(0)    // 19
7026     ROOT cond = pred[] constant(true)   // 20
7027   }
7028 
7029   while_body {
7030     param2 = (f32[2,4]) parameter(0)    // 21
7031     gte2 = f32[2,4] get-tuple-element(param2), index=0  // 22
7032     add = f32[2,4] add(gte2, gte2)      // 23
7033     ROOT tuple2 = (f32[2,4]) tuple(add) // 24
7034   }
7035 
7036   ENTRY Entry {
7037     param0 = f32[2,4] parameter(0)  // 0
7038     a = f32[2,4] negate(param0)     // 1
7039     b = f32[2,4] negate(a)          // 2
7040     c = f32[2,4] negate(b)          // 3
7041     d = f32[2,4] negate(c)          // 4
7042     e = f32[2,4] negate(d)          // 5
7043     f = f32[2,4] negate(e)          // 6
7044     g = f32[2,4] negate(f)          // 7
7045     h = f32[2,4] negate(g)          // 8
7046     i = f32[2,4] negate(h)          // 9
7047     j = f32[2,4] negate(i)          // 10
7048     k = f32[2,4] negate(j)          // 11
7049     l = f32[2,4] negate(k)          // 12
7050     m = f32[2,4] negate(l)          // 13
7051     n = f32[2,4] negate(m)          // 14
7052     o = f32[2,4] negate(n)          // 15
7053     p = f32[2,4] negate(o)          // 16
7054     q = f32[2,4] negate(p)          // 17
7055     tuple = (f32[2,4]) tuple(q)     // 18
7056     while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body  // 25
7057     gte1 = f32[2,4] get-tuple-element(while), index=0  // 26
7058     r = f32[2,4] negate(gte1)       // 27
7059     s = f32[2,4] negate(r)          // 28
7060     t = f32[2,4] negate(s)          // 29
7061     u = f32[2,4] negate(t)          // 30
7062     ROOT v = f32[2,4] add(u, param0)  // 31
7063   }
7064   )";
7065   TF_ASSERT_OK_AND_ASSIGN(auto module,
7066                           ParseAndReturnVerifiedModule(hlo_string));
7067 
7068   HloCostAnalysis hlo_cost_analysis(ShapeSize);
7069   Options options;
7070   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
7071                           FakeMemorySpaceAssignmentCostAnalysis::Create(
7072                               hlo_cost_analysis, *module, options));
7073   CostAnalysisPrefetchIntervalPicker interval_picker(
7074       *cost_analysis,
7075       /*min_overlap_to_async_copy_ratio=*/1.0,
7076       /*preferred_overlap_to_async_copy_ratio=*/2.0,
7077       /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
7078       /*mem_size_bytes=*/32);
7079 
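  // The cost analysis assumes a fixed execution count for while loop bodies;
  // verify its default value before checking the picked intervals.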
7080   EXPECT_EQ(cost_analysis->options()
7081                 .xla_tpu_memory_space_assignment_while_execution_count,
7082             5);
7083   HloInstruction* root = module->entry_computation()->root_instruction();
7084   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
7085   interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31);
7086 
7087   // Because the while loop computations occupy logical times [19, 24], we
7088   // ensure that the interval picker avoids this interval.
7089   LOG(INFO) << interval_picker.ToDebugString();
7090   EXPECT_EQ(interval_picker.Next(), 25);
7091   LOG(INFO) << interval_picker.ToDebugString();
7092   EXPECT_EQ(interval_picker.Next(), 26);
7093   LOG(INFO) << interval_picker.ToDebugString();
7094   EXPECT_EQ(interval_picker.Next(), 18);
7095   LOG(INFO) << interval_picker.ToDebugString();
7096   EXPECT_EQ(interval_picker.Next(), 27);  // Min async overlap ratio reached.
7097   LOG(INFO) << interval_picker.ToDebugString();
7098   EXPECT_EQ(interval_picker.Next(), 17);  // Max async overlap ratio reached.
7099   LOG(INFO) << interval_picker.ToDebugString();
7100   EXPECT_TRUE(interval_picker.Done());
7101 }
7102 
7103 TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) {
7104   // This test checks against a bug where we didn't assign while_nest_level_
7105   // for while instructions, so it defaulted to 0. This could cause the
7106   // prefetch interval logic to treat a nested while instruction as being at
7107   // the same level as the outermost computation.
7108   absl::string_view hlo_string = R"(
7109   HloModule bug, is_scheduled=true
7110 
7111   while_condition.2 {
7112     param1 = (f32[2,4]) parameter(0)    // 11
7113     ROOT cond = pred[] constant(true)   // 12
7114   }
7115 
7116   while_body.2 {
7117     param2 = (f32[2,4]) parameter(0)    // 13
7118     gte2 = f32[2,4] get-tuple-element(param2), index=0  // 14
7119     add = f32[2,4] add(gte2, gte2)      // 15
7120     ROOT tuple2 = (f32[2,4]) tuple(add) // 16
7121   }
7122 
7123   while_condition.1 {
7124     param3 = (f32[2,4]) parameter(0)    // 5
7125     ROOT cond = pred[] constant(true)   // 6
7126   }
7127 
7128   while_body.1 {
7129     param4 = (f32[2,4]) parameter(0)    // 7
7130     gte1 = f32[2,4] get-tuple-element(param4), index=0  // 8
7131     add1 = f32[2,4] add(gte1, gte1)     // 9
7132     tuple1 = (f32[2,4]) tuple(add1)     // 10
7133     while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2  // 17
7134     gte2 = f32[2,4] get-tuple-element(while), index=0  // 18
7135     add2 = f32[2,4] add(gte2, gte2)     // 19
7136     ROOT tuple2 = (f32[2,4]) tuple(add2)  // 20
7137   }
7138 
7139   ENTRY Entry {
7140     param0 = f32[2,4] parameter(0)  // 0
7141     a = f32[2,4] negate(param0)     // 1
7142     b = f32[2,4] negate(a)          // 2
7143     c = f32[2,4] negate(b)          // 3
7144     tuple = (f32[2,4]) tuple(c)     // 4
7145     while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1  // 21
7146     gte1 = f32[2,4] get-tuple-element(while), index=0  // 22
7147     ROOT root = f32[2,4] add(gte1, param0)  // 23
7148   }
7149   )";
7150   TF_ASSERT_OK_AND_ASSIGN(auto module,
7151                           ParseAndReturnVerifiedModule(hlo_string));
7152 
7153   HloCostAnalysis hlo_cost_analysis(ShapeSize);
7154   Options options;
7155   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
7156                           FakeMemorySpaceAssignmentCostAnalysis::Create(
7157                               hlo_cost_analysis, *module, options));
7158   CostAnalysisPrefetchIntervalPicker interval_picker(
7159       *cost_analysis,
7160       /*min_overlap_to_async_copy_ratio=*/1.0,
7161       /*preferred_overlap_to_async_copy_ratio=*/2.0,
7162       /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
7163       /*mem_size_bytes=*/32);
7164 
7165   HloInstruction* root = module->entry_computation()->root_instruction();
7166   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
7167   const Shape& shape = root->operand(1)->shape();
7168 
7169   // We expect the root's latest prefetch start time to be before the while loop
7170   // (logical time 4).
7171   EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0,
7172                                                     /*end_time=*/23, &use),
7173             4);
7174 }
7175 
7176 TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) {
7177   // This is a test for b/170668492, where prefetching for consecutive
7178   // conditionals can cause the prefetch to start in the conditional's
7179   // computation.
7180   absl::string_view hlo_string = R"(
7181   HloModule bug, is_scheduled=true
7182 
7183   true_computation.0 {
7184     p0 = (f32[3]{0}) parameter(0)                   // 5
7185     gte = f32[3]{0} get-tuple-element(p0), index=0  // 6
7186     ROOT neg1 = f32[3]{0} negate(gte)               // 7
7187   }
7188 
7189   false_computation.0 {
7190     p0 = (f32[3]{0}) parameter(0)                   // 8
7191     gte = f32[3]{0} get-tuple-element(p0), index=0  // 9
7192     ROOT neg2 = f32[3]{0} negate(gte)               // 10
7193   }
7194 
7195   true_computation.1 {
7196     p0 = (f32[3]{0}) parameter(0)                   // 12
7197     gte = f32[3]{0} get-tuple-element(p0), index=0  // 13
7198     ROOT neg1 = f32[3]{0} negate(gte)               // 14
7199   }
7200 
7201   false_computation.1 {
7202     p0 = (f32[3]{0}) parameter(0)                   // 15
7203     gte = f32[3]{0} get-tuple-element(p0), index=0  // 16
7204     ROOT neg2 = f32[3]{0} negate(gte)               // 17
7205   }
7206 
7207   ENTRY entry {
7208     p0 = f32[3]{0} parameter(0)       // 0
7209     p1 = f32[3]{0} parameter(1)       // 1
7210     p2 = pred[] parameter(2)          // 2
7211     tuple0 = (f32[3]{0}) tuple(p0)    // 3
7212     tuple1 = (f32[3]{0}) tuple(p1)    // 4
7213     conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0  // 11
7214     conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1  // 18
7215     ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1)  // 19
7216   }
7217   )";
7218   TF_ASSERT_OK_AND_ASSIGN(auto module,
7219                           ParseAndReturnVerifiedModule(hlo_string));
7220 
7221   HloCostAnalysis hlo_cost_analysis(ShapeSize);
7222   Options options;
7223   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
7224                           FakeMemorySpaceAssignmentCostAnalysis::Create(
7225                               hlo_cost_analysis, *module, options));
7226   CostAnalysisPrefetchIntervalPicker interval_picker(
7227       *cost_analysis,
7228       /*min_overlap_to_async_copy_ratio=*/1.0,
7229       /*preferred_overlap_to_async_copy_ratio=*/2.0,
7230       /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
7231       /*mem_size_bytes=*/32);
7232 
7233   LOG(INFO) << module->ToString();
7234 
7235   HloInstruction* conditional1 =
7236       module->entry_computation()->GetInstructionWithName("conditional1");
7237   const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}};
7238   const Shape& shape =
7239       module->entry_computation()->parameter_instruction(0)->shape();
7240 
7241   // Expect the prefetch to start before conditional0's called
7242   // computations.
7243   EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0,
7244                                                     /*end_time=*/11, &use),
7245             5);
7246 }
7247 
7248 TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) {
7249   // This tests the scenario where there is an op that takes a long time (tanh
7250   // in this example) and as a result the earliest and latest times both fall
7251   // inside this long-running op. In this case, we should still return a valid
7252   // prefetch interval just before the long-running op.
7253   absl::string_view hlo_string = R"(
7254   HloModule bug, is_scheduled=true
7255 
7256   ENTRY Entry {
7257     param0 = f32[2,4] parameter(0)
7258     negate = f32[2,4] negate(param0)
7259     tanh = f32[2,4] tanh(param0)
7260     ROOT add = f32[2,4] add(tanh, negate)
7261   }
7262   )";
7263   TF_ASSERT_OK_AND_ASSIGN(auto module,
7264                           ParseAndReturnVerifiedModule(hlo_string));
7265 
7266   HloCostAnalysis hlo_cost_analysis(ShapeSize);
7267   Options options;
7268   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
7269                           FakeMemorySpaceAssignmentCostAnalysis::Create(
7270                               hlo_cost_analysis, *module, options));
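  // Make tanh much more expensive than the other ops so that both the earliest
  // and latest prefetch start times fall inside it.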
7271   cost_analysis->SetOverrideForGetInstructionElapsed(
7272       [](const HloInstruction& hlo) {
7273         if (hlo.opcode() == HloOpcode::kTanh) {
7274           return 20.0;
7275         }
7276         return 1.0;
7277       });
7278   CostAnalysisPrefetchIntervalPicker interval_picker(
7279       *cost_analysis,
7280       /*min_overlap_to_async_copy_ratio=*/1.0,
7281       /*preferred_overlap_to_async_copy_ratio=*/2.0,
7282       /*max_overlap_to_mem_size_async_copy_ratio=*/12.0,
7283       /*mem_size_bytes=*/32);
7284 
7285   HloInstruction* root = module->entry_computation()->root_instruction();
7286   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
7287   interval_picker.Begin(use, /*start_time=*/1, /*end_time=*/3);
7288 
7289   LOG(INFO) << interval_picker.ToDebugString();
7290   EXPECT_FALSE(interval_picker.Done());
7291   EXPECT_EQ(interval_picker.Next(), 1);
7292   EXPECT_TRUE(interval_picker.Done());
7293 }
7294 
7295 }  // namespace
7296 }  // namespace xla
7297