/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"

#include <memory>
#include <string>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

class HloSchedulingTest : public HloTestBase {};

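// Returns the peak memory usage (in bytes) of the module's entry computation,
// measured by running the heap simulator over the module's schedule.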
int64_t PeakMemoryUseOfEntryComputation(
    HloModule* module, LogicalBuffer::SizeFunction size_function) {
  CHECK(module->has_entry_computation());
  CHECK(module->has_schedule());

  std::unique_ptr<HloAliasAnalysis> alias_analysis =
      HloAliasAnalysis::Run(module).value();

  const HloSchedule& schedule = module->schedule();

  HloComputation* computation = module->entry_computation();
  const HloInstructionSequence& sequence = schedule.sequence(computation);
  return HeapSimulator::Run(
             std::make_unique<NoFragmentationStatsHeap<HloValue>>(),
             *computation, sequence, *alias_analysis, size_function)
      .value()
      .heap_size;
}

TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
  // Tests scheduling of the following HLO code:
  //
  // %ab = abs(%param)
  // %exp = exp(%param)
  // %add = add(%ab, %exp)
  // %negate = negate(%exp)
  // %sub = subtract(%add, %negate)
  //
  // %add should be scheduled before %negate because %add is the last (and
  // only) use of %ab. Scheduling %add first then frees up %ab's buffer.
  const Shape vec = ShapeUtil::MakeShape(xla::F32, {42});
  auto builder = HloComputation::Builder(TestName());
  auto param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param"));
  auto ab = builder.AddInstruction(
      HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec, HloOpcode::kExp, param));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

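  // Run the memory scheduler pass with a simple byte-size function; it should
  // attach a schedule to the module.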
  HloMemoryScheduler scheduler([](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  });
  ASSERT_FALSE(module->has_schedule());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
  EXPECT_TRUE(changed);
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK(module->schedule().Verify());

  // Verify that all instructions are in the sequence.
  const std::vector<HloInstruction*>& sequence =
      module->schedule().sequence(module->entry_computation()).instructions();
  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());

  // The first instruction should be the parameter and the last the root "sub".
  EXPECT_EQ(param, sequence.front());
  EXPECT_EQ(sub, sequence.back());

  SequentialHloOrdering ordering(module->schedule());
  EXPECT_TRUE(ordering.ExecutesBefore(add, negate));

  // Clear the schedule using the descheduling pass.
  HloDescheduler descheduler;
  EXPECT_TRUE(module->has_schedule());
  TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
                          descheduler.Run(module.get()));
  EXPECT_TRUE(descheduler_changed);
  EXPECT_FALSE(module->has_schedule());
}

TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
  const char* module_str = R"(
HloModule test_aliasing_module

ENTRY root {
  param = s32[1000] parameter(0)
  p0 = s32[1000] copy(param)
  p1 = s32[1000] copy(param)
  t = (s32[1000], s32[1000]) tuple(p0, p1)
  a = s32[1000] get-tuple-element(t), index=0
  b = s32[1000] get-tuple-element(t), index=1
  c = s32[1000] add(a, b)
  d = s32[1000] add(c, b)
  e = s32[1000] add(c, c)
  f = s32[1000] add(e, e)
  ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };
  int64_t peak_memory;
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), size_fn,
                     ComputationSchedulerToModuleScheduler(ListMemoryScheduler),
                     /*execution_threads=*/{}, &peak_memory));
  TF_ASSERT_OK(module->set_schedule(schedule));
  // Verify that all instructions are in the sequence.
  const std::vector<HloInstruction*>& sequence =
      schedule.sequence(module->entry_computation()).instructions();
  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());

  absl::flat_hash_map<std::string, const HloInstruction*> instructions_by_name;
  for (const HloInstruction* instruction : sequence) {
    instructions_by_name[instruction->name()] = instruction;
  }

  // The first instruction should be the parameter and the last the root.
  EXPECT_EQ(instructions_by_name.at("param"), sequence.front());
  EXPECT_EQ(instructions_by_name.at("result"), sequence.back());

  // Instructions "d" and "e" will both be schedulable at the same time, but
  // instruction "d" allows us to free the buffer of "p1", so the list
  // scheduler should prefer it.
  SequentialHloOrdering ordering(schedule);
  EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
                                      instructions_by_name.at("e")));
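  // The peak memory returned by ScheduleModule should match what the heap
  // simulator reports for the chosen schedule.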
  EXPECT_EQ(PeakMemoryUseOfEntryComputation(module.get(), size_fn),
            peak_memory);
}

TEST_F(HloSchedulingTest, HostSendDoneSchedule) {
  const char* const module_str = R"(
HloModule module

ENTRY entry {
  %p = f32[1000, 1000] parameter(0)
  %token.0 = token[] after-all()
  %send = (f32[1000, 1000], u32[], token[]) send(%p, %token.0),
    channel_id=1, is_host_transfer=true
  %n1 = f32[1000, 1000] negate(%p)
  %n2 = f32[1000, 1000] negate(%n1)
  %n3 = f32[1000, 1000] negate(%n2)
  %send-done = token[] send-done(%send), channel_id=1, is_host_transfer=true
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };

  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), size_fn,
                                         ComputationSchedulerToModuleScheduler(
                                             ListMemoryScheduler)));
  // Verify that all instructions are in the sequence.
  const std::vector<HloInstruction*>& sequence =
      schedule.sequence(module->entry_computation()).instructions();
  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());

  absl::flat_hash_map<std::string, const HloInstruction*> instructions_by_name;
  for (const HloInstruction* instruction : sequence) {
    instructions_by_name[instruction->name()] = instruction;
  }

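  // send-done should be scheduled before the chain of negates so that the
  // buffers tied up by the host send can be released as early as possible.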
  EXPECT_LT(absl::c_find(sequence, instructions_by_name.at("send-done")),
            absl::c_find(sequence, instructions_by_name.at("n1")));
}

TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
  auto builder = HloComputation::Builder(TestName());
  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6});

  // Wrap lit in abs because constants are considered free by
  // IgnoreInstruction, and it skews the accounting.
  auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
  auto abs_const = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));

  auto abs_abs1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
      absl::Span<HloInstruction* const>({abs_abs1})));
  auto tuple_elm = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));

  auto abs_abs2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));

  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd,
                                                      tuple_elm, abs_abs2));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(
          module.get(),
          [](const BufferValue& buffer) {
            return ShapeUtil::ByteSizeOf(buffer.shape(), 1);
          },
          ComputationSchedulerToModuleScheduler(ListMemoryScheduler)));

  // Verify that all instructions are in the sequence.
  EXPECT_EQ(module->entry_computation()->instruction_count(),
            schedule.sequence(module->entry_computation()).size());
  SequentialHloOrdering ordering(schedule);
  // tuple allocates the tuple buffer and doesn't free anything.
  // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
  // abs_abs2 should be scheduled before tuple by List.
  EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple));
}

TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5});
  HloComputation::Builder builder(TestName());

  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1})));
  auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5})));
  auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0, 2, 4, 6, 8})));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));

  auto tuple_elm = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3));

  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

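  // Fuse the tuple together with its producers into a single fusion, so the
  // fusion instruction defines both tuple elements (a multi-output fusion).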
  auto fusion = computation->CreateFusionInstruction(
      {tuple, mul, add}, HloInstruction::FusionKind::kLoop);

  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(
          module.get(),
          [](const BufferValue& buffer) {
            return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
          },
          ComputationSchedulerToModuleScheduler(ListMemoryScheduler)));

  // Verify that all instructions are in the sequence.
  EXPECT_EQ(module->entry_computation()->instruction_count(),
            schedule.sequence(module->entry_computation()).size());
  SequentialHloOrdering ordering(schedule);
  // fusion allocates memory for the tuple elements and doesn't free anything,
  // so it's more expensive than exp.
  EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
}

TEST_F(HloSchedulingTest, TrivialScheduler) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

body {
  param.b = (s32[], s32[]) parameter(0)
  gte.0 = s32[] get-tuple-element(param.b), index=0
  gte.1 = s32[] get-tuple-element(param.b), index=1
  add = s32[] add(gte.0, gte.1)
  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
  param.c = (s32[], s32[]) parameter(0)
  ROOT constant = pred[] constant(true)
}

ENTRY main {
  init = (s32[], s32[]) parameter(0)
  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  EXPECT_FALSE(module->has_schedule());
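  // The trivial scheduler produces a valid schedule without applying any
  // memory-minimizing heuristics.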
  TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK(module->schedule().Verify());

  // Verify that a clone of the module also has a schedule.
  std::unique_ptr<HloModule> clone = module->Clone();
  ASSERT_TRUE(clone->has_schedule());
  TF_ASSERT_OK(clone->schedule().Verify());
}

}  // namespace
}  // namespace xla