xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "tensorflow/compiler/xla/service/heap_simulator.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_dce.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace xla {
37 namespace {
38 
// Fixture used by the TEST_F cases in this file. It declares no members of its
// own; everything it offers comes from HloTestBase.
39 class HloSchedulingTest : public HloTestBase {};
40 
PeakMemoryUseOfEntryComputation(HloModule * module,LogicalBuffer::SizeFunction size_function)41 int64_t PeakMemoryUseOfEntryComputation(
42     HloModule* module, LogicalBuffer::SizeFunction size_function) {
43   CHECK(module->has_entry_computation());
44   CHECK(module->has_schedule());
45 
46   std::unique_ptr<HloAliasAnalysis> alias_analysis =
47       HloAliasAnalysis::Run(module).value();
48 
49   const HloSchedule& schedule = module->schedule();
50 
51   HloComputation* computation = module->entry_computation();
52   const HloInstructionSequence& sequence = schedule.sequence(computation);
53   return HeapSimulator::Run(
54              std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
55              *computation, sequence, *alias_analysis, size_function)
56       .ValueOrDie()
57       .heap_size;
58 }
59 
TEST_F(HloSchedulingTest,LastUseScheduledFirst)60 TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
61   // Tests scheduling of the following HLO code:
62   //
63   //   %ab = abs(%param)
64   //   %exp = exp(%param)
65   //   %add = add(%ab, %exp)
66   //   %negate = negate(%exp)
67   //   %sub = subtract(%add, %negate)
68   //
69   // %add should be scheduled before %negate because %add is the last (and only)
70   // use of %ab. Scheduling %add first then frees up %ab's buffer.
71   const Shape vec = ShapeUtil::MakeShape(xla::F32, {42});
72   auto builder = HloComputation::Builder(TestName());
73   auto param =
74       builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param"));
75   auto ab = builder.AddInstruction(
76       HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param));
77   auto exp = builder.AddInstruction(
78       HloInstruction::CreateUnary(vec, HloOpcode::kExp, param));
79 
80   auto add = builder.AddInstruction(
81       HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp));
82   auto negate = builder.AddInstruction(
83       HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp));
84   auto sub = builder.AddInstruction(
85       HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate));
86 
87   auto module = CreateNewVerifiedModule();
88   module->AddEntryComputation(builder.Build());
89 
90   HloMemoryScheduler scheduler([](const BufferValue& buffer) {
91     return ShapeUtil::ByteSizeOf(buffer.shape());
92   });
93   ASSERT_FALSE(module->has_schedule());
94   TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
95   EXPECT_TRUE(changed);
96   ASSERT_TRUE(module->has_schedule());
97   TF_ASSERT_OK(module->schedule().Verify());
98 
99   // Verify that all instructions are in the sequence.
100   const std::vector<HloInstruction*>& sequence =
101       module->schedule().sequence(module->entry_computation()).instructions();
102   EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
103 
104   // The first instruction should be the parameter and the last the root "sub".
105   EXPECT_EQ(param, sequence.front());
106   EXPECT_EQ(sub, sequence.back());
107 
108   SequentialHloOrdering ordering(module->schedule());
109   EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
110 
111   // Clear the schedule using the descheduling pass.
112   HloDescheduler descheduler;
113   EXPECT_TRUE(module->has_schedule());
114   TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
115                           descheduler.Run(module.get()));
116   EXPECT_TRUE(descheduler_changed);
117   EXPECT_FALSE(module->has_schedule());
118 }
119 
TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
  // "a" and "b" are get-tuple-elements of a tuple built from the copies "p0"
  // and "p1", so the test exercises scheduling through tuple elements.
  const char* const hlo_text = R"(
HloModule test_aliasing_module

ENTRY root {
  param = s32[1000] parameter(0)
  p0 = s32[1000] copy(param)
  p1 = s32[1000] copy(param)
  t = (s32[1000], s32[1000]) tuple(p0, p1)
  a = s32[1000] get-tuple-element(t), index=0
  b = s32[1000] get-tuple-element(t), index=1
  c = s32[1000] add(a, b)
  d = s32[1000] add(c, b)
  e = s32[1000] add(c, c)
  f = s32[1000] add(e, e)
  ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  auto byte_size = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  // Schedule with the list scheduler and keep the peak memory it reports.
  int64_t peak_memory;
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), byte_size,
                     ComputationSchedulerToModuleScheduler(ListMemoryScheduler),
                     /*execution_threads=*/{}, &peak_memory));
  TF_ASSERT_OK(module->set_schedule(schedule));

  // Every instruction of the entry computation must appear in the schedule.
  HloComputation* entry = module->entry_computation();
  const std::vector<HloInstruction*>& scheduled =
      schedule.sequence(entry).instructions();
  EXPECT_EQ(entry->instruction_count(), scheduled.size());

  // Index the scheduled instructions by name for the checks below.
  absl::flat_hash_map<std::string, const HloInstruction*> by_name;
  for (const HloInstruction* hlo : scheduled) {
    by_name[hlo->name()] = hlo;
  }

  // The parameter is scheduled first and the root last.
  EXPECT_EQ(by_name.at("param"), scheduled.front());
  EXPECT_EQ(by_name.at("result"), scheduled.back());

  // "d" and "e" become schedulable at the same time. Only "d" lets the buffer
  // of "p1" be freed, so the list scheduler should pick it first.
  const HloInstruction* d = by_name.at("d");
  const HloInstruction* e = by_name.at("e");
  SequentialHloOrdering ordering(schedule);
  EXPECT_TRUE(ordering.ExecutesBefore(d, e));

  // The reported peak must match an independent heap simulation.
  EXPECT_EQ(PeakMemoryUseOfEntryComputation(module.get(), byte_size),
            peak_memory);
}
174 
TEST_F(HloSchedulingTest, HostSendDoneSchedule) {
  // A host-transfer send/send-done pair next to a chain of negates on the same
  // parameter.
  const char* const hlo_text = R"(
HloModule module

ENTRY entry {
  %p = f32[1000, 1000] parameter(0)
  %token.0 = token[] after-all()
  %send = (f32[1000, 1000], token[]) send(%p, %token.0),
    channel_id=1, is_host_transfer=true
  %n1 = f32[1000, 1000] negate(%p)
  %n2 = f32[1000, 1000] negate(%n1)
  %n3 = f32[1000, 1000] negate(%n2)
  %send-done = token[] send-done(%send), channel_id=1, is_host_transfer=true
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  auto byte_size = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), byte_size,
                                         ComputationSchedulerToModuleScheduler(
                                             ListMemoryScheduler)));

  // Every instruction of the entry computation must appear in the schedule.
  HloComputation* entry = module->entry_computation();
  const std::vector<HloInstruction*>& scheduled =
      schedule.sequence(entry).instructions();
  EXPECT_EQ(entry->instruction_count(), scheduled.size());

  // Index the scheduled instructions by name.
  absl::flat_hash_map<std::string, const HloInstruction*> by_name;
  for (const HloInstruction* hlo : scheduled) {
    by_name[hlo->name()] = hlo;
  }

  // "send-done" has to sit earlier in the sequence than "n1".
  const auto send_done_pos = absl::c_find(scheduled, by_name.at("send-done"));
  const auto n1_pos = absl::c_find(scheduled, by_name.at("n1"));
  EXPECT_LT(send_done_pos, n1_pos);
}
215 
TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
  HloComputation::Builder graph(TestName());
  const Shape f32_vec = ShapeUtil::MakeShape(xla::F32, {6});

  // Appends abs(operand) to the computation being built.
  const auto add_abs = [&](HloInstruction* operand) {
    return graph.AddInstruction(
        HloInstruction::CreateUnary(f32_vec, HloOpcode::kAbs, operand));
  };

  // The literal is wrapped in an abs because IgnoreInstruction treats
  // constants as free, which would skew the accounting.
  HloInstruction* lit = graph.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
  HloInstruction* abs_const = add_abs(lit);

  HloInstruction* abs_abs1 = add_abs(abs_const);
  HloInstruction* tuple = graph.AddInstruction(HloInstruction::CreateTuple(
      absl::Span<HloInstruction* const>({abs_abs1})));
  HloInstruction* tuple_elm = graph.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_vec, tuple, 0));

  HloInstruction* abs_abs2 = add_abs(abs_const);

  graph.AddInstruction(HloInstruction::CreateBinary(f32_vec, HloOpcode::kAdd,
                                                    tuple_elm, abs_abs2));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(graph.Build());

  auto byte_size = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), 1);
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(
          module.get(), byte_size,
          ComputationSchedulerToModuleScheduler(ListMemoryScheduler)));

  // Every instruction of the entry computation must appear in the schedule.
  HloComputation* entry = module->entry_computation();
  EXPECT_EQ(entry->instruction_count(), schedule.sequence(entry).size());

  // The tuple allocates the tuple buffer and frees nothing, whereas abs_abs2
  // reuses its input buffer for its output (bytes freed is 0). The list
  // scheduler should therefore place abs_abs2 ahead of the tuple.
  SequentialHloOrdering ordering(schedule);
  EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple));
}
260 
TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
  const Shape f32_vec = ShapeUtil::MakeShape(xla::F32, {5});
  HloComputation::Builder graph(TestName());

  // Appends a rank-1 float constant to the computation being built.
  const auto add_constant = [&](std::initializer_list<float> values) {
    return graph.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
  };

  HloInstruction* c1 = add_constant({1, 1, 1, 1, 1});
  HloInstruction* c2 = add_constant({1, 2, 3, 4, 5});
  HloInstruction* c3 = add_constant({0, 2, 4, 6, 8});

  HloInstruction* add = graph.AddInstruction(
      HloInstruction::CreateBinary(f32_vec, HloOpcode::kAdd, c1, c2));
  HloInstruction* mul = graph.AddInstruction(
      HloInstruction::CreateBinary(f32_vec, HloOpcode::kMultiply, add, c3));
  HloInstruction* tuple =
      graph.AddInstruction(HloInstruction::CreateTuple({add, mul}));

  HloInstruction* tuple_elm = graph.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_vec, tuple, 0));

  HloInstruction* exp = graph.AddInstruction(
      HloInstruction::CreateUnary(f32_vec, HloOpcode::kExp, c3));

  graph.AddInstruction(
      HloInstruction::CreateBinary(f32_vec, HloOpcode::kAdd, tuple_elm, exp));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(graph.Build());

  // Fuse tuple, mul and add into a single loop fusion instruction.
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {tuple, mul, add}, HloInstruction::FusionKind::kLoop);

  auto byte_size = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(
          module.get(), byte_size,
          ComputationSchedulerToModuleScheduler(ListMemoryScheduler)));

  // Every instruction of the entry computation must appear in the schedule.
  EXPECT_EQ(module->entry_computation()->instruction_count(),
            schedule.sequence(module->entry_computation()).size());

  // The fusion allocates memory for the tuple elements without freeing
  // anything, which makes it more expensive than exp; exp should go first.
  SequentialHloOrdering ordering(schedule);
  EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
}
310 
TEST_F(HloSchedulingTest, TrivialScheduler) {
  // A module whose entry computation is a single while loop, so that the
  // condition and body computations are part of the module as well.
  const char* const hlo_text = R"(
HloModule ModuleWithWhile

body {
  param.b = (s32[], s32[]) parameter(0)
  gte.0 = s32[] get-tuple-element(param.b), index=0
  gte.1 = s32[] get-tuple-element(param.b), index=1
  add = s32[] add(gte.0, gte.1)
  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
  param.c = (s32[], s32[]) parameter(0)
  ROOT constant = pred[] constant(true)
}

ENTRY main {
  init = (s32[], s32[]) parameter(0)
  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // A freshly parsed module has no schedule; the trivial scheduler adds one
  // that must pass verification.
  EXPECT_FALSE(module->has_schedule());
  HloTrivialScheduler trivial_scheduler;
  TF_ASSERT_OK(trivial_scheduler.Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK(module->schedule().Verify());

  // Cloning the module must carry a valid schedule over to the copy.
  std::unique_ptr<HloModule> cloned = module->Clone();
  ASSERT_TRUE(cloned->has_schedule());
  TF_ASSERT_OK(cloned->schedule().Verify());
}
345 
346 }  // namespace
347 }  // namespace xla
348