1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
17
18 #include <algorithm>
19 #include <limits>
20 #include <map>
21 #include <queue>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
28 #include "tensorflow/compiler/xla/service/heap_simulator.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
32 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/types.h"
37 #include "tensorflow/compiler/xla/util.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/gtl/map_util.h"
40 #include "tensorflow/core/platform/logging.h"
41
42 namespace xla {
43 namespace {
44
45 using ::tensorflow::strings::HumanReadableNumBytes;
46
47 // Class implementing a list scheduler of HLO instructions which produces a
48 // sequence which minimizes memory usage by preferring to schedule the node that
49 // frees bigger buffer and defines smaller outputs.
50 //
51 // Note that list scheduler is a greedy algorithm which cannot guarantee a
52 // global optimal solution. As a counterexample, considering the following
53 // graph:
54 //
55 // +--> B ===> C -------+
56 // A -> | |
57 // | v
58 // +--> D ---> F=======>G
59 // | ^
60 // | |
61 // +--> E -----+
62 //
63 // --> : Buffer with size 1
64 // ==> : Buffer with size 2
65 //
66 // The list scheduler will always try to defer scheduling B in a greedy way
67 // since its output buffer is bigger than input. The sequence it creates will
68 // be:
69 // A D E F B C G
70 // , which has a maximum memory usage of 6 (B is alive while F is executing).
71 //
72 // An optimal way to schedule the previous graph is:
73 // A B C D E F G
74 // , which has a maximum memory usage of 5 (when F is executing).
75 //
76 class ListScheduler {
77 public:
78 // Construct and return a memory-minimizing sequence of HLO instructions
79 // containing the given HLO computation.
Run(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const BufferValue::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64_t> & memory_by_computation)80 static StatusOr<HloInstructionSequence> Run(
81 HloComputation* computation,
82 const TuplePointsToAnalysis& points_to_analysis,
83 const BufferValue::SizeFunction& size_function,
84 const absl::flat_hash_map<const HloComputation*, int64_t>&
85 memory_by_computation) {
86 ListScheduler scheduler(computation, points_to_analysis, size_function,
87 memory_by_computation);
88 return scheduler.CreateSchedule();
89 }
90
91 // Returns whether the memory used by the given HLO should be ignored by the
92 // scheduling heuristic.
IgnoreInstruction(const HloInstruction & instruction)93 static bool IgnoreInstruction(const HloInstruction& instruction) {
94 return instruction.opcode() == HloOpcode::kParameter ||
95 instruction.opcode() == HloOpcode::kConstant;
96 }
97
98 private:
99 // The scheduling priority of an instruction is first the number of bytes
100 // freed by scheduling the instruction, and second (tie-breaker) by the number
101 // of users. This is represented as a std::pair containing these two values
102 // (first element is the bytes freed). std::pair provides the necessary
103 // comparison operators.
104 using Priority = std::pair<int64_t, int64_t>;
105
ListScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const BufferValue::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64_t> & memory_by_computation)106 ListScheduler(HloComputation* computation,
107 const TuplePointsToAnalysis& points_to_analysis,
108 const BufferValue::SizeFunction& size_function,
109 const absl::flat_hash_map<const HloComputation*, int64_t>&
110 memory_by_computation)
111 : computation_(computation),
112 points_to_analysis_(points_to_analysis),
113 size_function_(size_function),
114 memory_by_computation_(memory_by_computation) {
115 // Create a map containing the LogicalBuffer uses for each HLO
116 // instruction. An HLO instruction "uses" a LogicalBuffer if the
117 // LogicalBuffer is in an operand of the instruction as indicated by
118 // points-to analysis.
119 for (auto* instruction : computation->instructions()) {
120 absl::flat_hash_set<const LogicalBuffer*> instr_uses;
121 for (auto* operand : instruction->operands()) {
122 points_to_analysis.GetPointsToSet(operand).ForEachElement(
123 [&](const ShapeIndex& /*index*/,
124 const PointsToSet::BufferList& buffers) {
125 instr_uses.insert(buffers.begin(), buffers.end());
126 });
127 }
128 buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
129 instr_uses.begin(), instr_uses.end());
130 }
131
132 // Create map containing the number of unscheduled uses (hlo instructions)
133 // of each logical buffer.
134 unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
135 for (auto* instruction : computation->instructions()) {
136 for (auto* buffer :
137 points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
138 unscheduled_use_count_[buffer] = 0;
139 }
140 }
141 for (auto* instruction : computation->instructions()) {
142 for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
143 ++unscheduled_use_count_[buffer];
144 }
145 }
146
147 // Buffers live out of the computation have an implicit use at the end of
148 // the computation.
149 for (const LogicalBuffer* live_out_buffer :
150 points_to_analysis.GetPointsToSet(computation->root_instruction())
151 .CreateFlattenedSet()) {
152 ++unscheduled_use_count_[live_out_buffer];
153 }
154 }
155
156 // Returns whether the memory used by the given buffer should be ignored by
157 // the scheduling heuristic.
IgnoreBuffer(const LogicalBuffer & buffer)158 static bool IgnoreBuffer(const LogicalBuffer& buffer) {
159 return IgnoreInstruction(*buffer.instruction());
160 }
161
162 // An entry in the worklist used by CreateSchedule. Corresponds to one
163 // HloInstruction, plus some cached metadata, saved for the purposes of making
164 // BytesFreedIfScheduled fast.
165 struct ReadyListEntry {
166 HloInstruction* instruction;
167
168 // The total size of all buffers defined by this instruction.
169 int64_t bytes_defined;
170
171 // For each buffer B used by this instruction, we keep a pair (B, U), where
172 // U is the number of uses of B that have not yet been scheduled. This pair
173 // is a pointer into the unscheduled_use_count_ map, so it gets updated for
174 // free when we update counts in the map.
175 std::vector<const std::pair<const LogicalBuffer* const, int64_t>*>
176 used_buffer_unscheduled_use_counts;
177 };
178
179 // Creates a ReadyListEntry for the given instruction.
MakeReadyListEntry(HloInstruction * instruction)180 ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
181 ReadyListEntry entry;
182 entry.instruction = instruction;
183
184 entry.bytes_defined = 0;
185 for (auto* buffer :
186 points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
187 if (!IgnoreBuffer(*buffer)) {
188 entry.bytes_defined += size_function_(*buffer);
189 }
190 }
191
192 for (auto* buffer : buffer_uses_.at(instruction)) {
193 if (IgnoreBuffer(*buffer)) {
194 continue;
195 }
196 auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
197 CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
198 entry.used_buffer_unscheduled_use_counts.push_back(
199 &*unscheduled_use_count_it);
200 }
201 return entry;
202 }
203
204 // Returns the number of bytes freed *after* the HLO instruction finishes.
205 // The current List algorithm only considers two states for an instruction:
206 // right before it runs, and after it finishes. We don't represent memory
207 // usage during the execution of an instruction. But if the instruction calls
208 // subcomputations, they are only live during the instruction's execution.
209 // We end up counting the memory used by subcomputations as memory "defined"
210 // by the instruction. This is not entirely accurate, but it is more accurate
211 // than not taking subcomputations into account at all. In the future, we may
212 // improve accounting for subcomputation memory (b/65409243).
BytesFreedIfScheduled(const ReadyListEntry & entry)213 int64_t BytesFreedIfScheduled(const ReadyListEntry& entry) {
214 auto instruction = entry.instruction;
215 auto opcode = instruction->opcode();
216
217 // Scheduling the outfeed early and the infeed late gives more time to the
218 // communicating processor to do its work.
219 if (opcode == HloOpcode::kOutfeed &&
220 !instruction->outfeed_config().empty()) {
221 return INT_MAX;
222 }
223 if (opcode == HloOpcode::kInfeed && !instruction->infeed_config().empty()) {
224 return INT_MIN;
225 }
226
227 int64_t freed_bytes = 0;
228 for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
229 auto buffer = kv->first;
230 auto use_count = kv->second;
231 if (use_count == 1) {
232 freed_bytes += size_function_(*buffer);
233 }
234 }
235 // We only count the memory usage of the largest subcomputation, instead of
236 // adding them all, because subcomputations won't execute in parallel.
237 int64_t max_subcomputation_bytes = 0;
238 for (const auto* c : instruction->called_computations()) {
239 auto it = memory_by_computation_.find(c);
240 if (it != memory_by_computation_.end()) {
241 int64_t subcomputation_bytes = it->second;
242 if (subcomputation_bytes > max_subcomputation_bytes) {
243 max_subcomputation_bytes = subcomputation_bytes;
244 }
245 }
246 }
247 int64_t bytes_defined;
248 if (max_subcomputation_bytes > 0 &&
249 (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
250 opcode == HloOpcode::kConditional)) {
251 // The output buffer of while/call/conditional is always aliased with the
252 // output buffer of the root instruction in the body. Don't double count.
253 bytes_defined = max_subcomputation_bytes;
254 } else {
255 bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
256 }
257 return freed_bytes - bytes_defined;
258 }
259
260 // Constructs the scheduling priority of the given instruction.
GetPriority(const ReadyListEntry & entry)261 Priority GetPriority(const ReadyListEntry& entry) {
262 // Try to cluster scalars as close together as possible so that if they are
263 // in unfused hlos, they can still live in machine registers without
264 // excessive spilling.
265 if (ShapeUtil::IsEffectiveScalar(entry.instruction->shape())) {
266 return {std::numeric_limits<int64_t>::max(),
267 std::numeric_limits<int64_t>::max()};
268 }
269 return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
270 }
271
CreateSchedule()272 HloInstructionSequence CreateSchedule() {
273 HloInstructionSequence schedule;
274
275 // Populate the ready list with instructions which have no operands or
276 // control predecessors.
277 absl::flat_hash_map<const HloInstruction*, int64_t> unscheduled_pred_count;
278 for (auto* instruction : computation_->instructions()) {
279 // TODO(b/34466113): Replace this and above with successors() or
280 // predecessors() when these methods are added to HloInstruction.
281 for (HloInstruction* user : instruction->users()) {
282 unscheduled_pred_count[user]++;
283 }
284 for (HloInstruction* succ : instruction->control_successors()) {
285 unscheduled_pred_count[succ]++;
286 }
287 }
288
289 // Use a multimap to sort ReadyListEntry according to their priority.
290 std::multimap<Priority, ReadyListEntry> ready_queue;
291
292 // Map of ready instructions to their iterators in ready_queue.
293 absl::flat_hash_map<const HloInstruction*,
294 std::multimap<Priority, ReadyListEntry>::iterator>
295 ready_instructions;
296
297 auto add_to_ready_queue = [&](HloInstruction* inst) {
298 auto entry = MakeReadyListEntry(inst);
299 auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
300 ready_instructions[inst] = it;
301 };
302
303 for (auto* instruction : computation_->instructions()) {
304 if (instruction->operands().empty() &&
305 instruction->control_predecessors().empty()) {
306 add_to_ready_queue(instruction);
307 }
308 }
309
310 while (!ready_queue.empty()) {
311 // Remove the selected instruction from the ready list and add it to the
312 // schedule.
313 auto best_it = ready_queue.end();
314 --best_it;
315 HloInstruction* best = best_it->second.instruction;
316 VLOG(2) << "Schedule instruction: " << best->ToShortString()
317 << " Bytes freed: " << best_it->first.first;
318 ready_queue.erase(best_it);
319 ready_instructions.erase(best);
320 schedule.push_back(best);
321 scheduled_instructions_.insert(best);
322
323 bool adjust_ready_queue = false;
324 // Update the unscheduled uses of the logical buffers.
325 for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
326 int64_t& count = unscheduled_use_count_[buffer];
327 CHECK_GT(count, 0);
328 --count;
329 if (count == 1) {
330 adjust_ready_queue = true;
331 }
332 }
333
334 // Add new instructions to ready list.
335 auto update_pred_count = [&](HloInstruction* inst) {
336 int64_t pred_count = --unscheduled_pred_count.at(inst);
337 CHECK_GE(pred_count, 0);
338 if (pred_count == 0) {
339 add_to_ready_queue(inst);
340 }
341 };
342 // TODO(b/34466113): Replace this and above with successors() or
343 // predecessors() when these methods are added to HloInstruction.
344 for (HloInstruction* user : best->users()) {
345 update_pred_count(user);
346 }
347 for (HloInstruction* succ : best->control_successors()) {
348 update_pred_count(succ);
349 }
350 // The unscheduled use count for a buffer has changed to 1, so the
351 // priorities of some ready instructions may go up. We update them in the
352 // ready queue, so that they can appear earlier.
353 if (adjust_ready_queue) {
354 for (HloInstruction* operand : best->operands()) {
355 for (HloInstruction* operand_user : operand->users()) {
356 auto ready_instructions_it = ready_instructions.find(operand_user);
357 if (ready_instructions_it == ready_instructions.end()) {
358 continue;
359 }
360 auto ready_queue_it = ready_instructions_it->second;
361 auto& entry = ready_queue_it->second;
362 Priority new_priority = GetPriority(entry);
363 if (new_priority == ready_queue_it->first) {
364 continue;
365 }
366 // Create a new entry in ready_queue, then update
367 // ready_instructions[operand_user] to refer to the new entry.
368 ready_instructions_it->second =
369 ready_queue.emplace(new_priority, std::move(entry));
370 // Remove the old entry in ready_queue.
371 ready_queue.erase(ready_queue_it);
372 }
373 }
374 }
375 }
376 CHECK_EQ(schedule.size(), computation_->instruction_count());
377 CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());
378
379 return schedule;
380 }
381
382 HloComputation* computation_;
383 const TuplePointsToAnalysis& points_to_analysis_;
384 const BufferValue::SizeFunction& size_function_;
385 // Computations are analyzed in post-order. When scheduling an instruction
386 // that includes subcomputations, such as a while loop, we use this map to
387 // look up the memory needed by subcomputations.
388 const absl::flat_hash_map<const HloComputation*, int64_t>&
389 memory_by_computation_;
390
391 // A map containing the LogicalBuffers that each instruction uses.
392 absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
393 buffer_uses_;
394
395 // A map containing the count of unscheduled HLOs which using a particular
396 // LogicalBuffer.
397 absl::flat_hash_map<const LogicalBuffer*, int64_t> unscheduled_use_count_;
398
399 // Set of instructions which have been scheduled.
400 absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
401 };
402
SumLogicalBufferSizes(const TuplePointsToAnalysis::BufferDefinitionVector & buffers,const BufferValue::SizeFunction & size_function)403 int64_t SumLogicalBufferSizes(
404 const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
405 const BufferValue::SizeFunction& size_function) {
406 int64_t size = 0;
407 for (const LogicalBuffer* buffer : buffers) {
408 size += size_function(*buffer);
409 }
410 return size;
411 }
412
ScheduleComputationHelper(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_function,const MemorySchedulerAlgorithm & algorithm,const absl::flat_hash_map<const HloComputation *,int64_t> & memory_by_computation,const MemorySchedulerPostprocessor & postprocessor,int64_t * peak_memory)413 StatusOr<HloInstructionSequence> ScheduleComputationHelper(
414 HloComputation* computation,
415 const TuplePointsToAnalysis& points_to_analysis,
416 const HloAliasAnalysis& alias_analysis,
417 const BufferValue::SizeFunction& size_function,
418 const MemorySchedulerAlgorithm& algorithm,
419 const absl::flat_hash_map<const HloComputation*, int64_t>&
420 memory_by_computation,
421 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
422 VLOG(2) << "Computation: " << computation->name();
423
424 if (algorithm) {
425 return algorithm(computation, points_to_analysis, alias_analysis,
426 size_function, memory_by_computation, postprocessor,
427 peak_memory);
428 }
429 return DefaultMemoryScheduler(computation, points_to_analysis, alias_analysis,
430 size_function, memory_by_computation,
431 postprocessor, peak_memory);
432 }
433
434 } // namespace
435
DFSMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64_t> & memory_by_computation,const MemorySchedulerPostprocessor & postprocessor,int64_t * peak_memory)436 StatusOr<HloInstructionSequence> DFSMemoryScheduler(
437 HloComputation* computation,
438 const TuplePointsToAnalysis& points_to_analysis,
439 const HloAliasAnalysis& alias_analysis,
440 const BufferValue::SizeFunction& size_function,
441 const absl::flat_hash_map<const HloComputation*, int64_t>&
442 memory_by_computation,
443 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
444 // These variables are a hack to prevent overflows.
445 int64_t cumulative_total_size = 0;
446 int64_t total_hlos = computation->instruction_count();
447 struct Stats {
448 // Transitively includes the count of all nodes that lead to it.
449 int64_t extra_users = 0;
450 // Transitively includes the sizes of all nodes that lead to it.
451 int64_t total_sizes = 0;
452 };
453 absl::flat_hash_map<const HloInstruction*, Stats> stats_map;
454 stats_map.reserve(computation->instruction_count());
455
456 for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
457 auto& stats = stats_map[hlo];
458 if (ListScheduler::IgnoreInstruction(*hlo)) {
459 continue;
460 }
461 // This ordering is based on DFS post-order, with a heuristic to decide
462 // which operand to visit first. The heuristic is based on 'extra_users',
463 // which is simply users-1 for each instruction. By subtracting 1, we're
464 // saying that instructions with no users or a single user don't count;
465 // instructions with lots of fan-out will be visited earlier.
466 stats.extra_users = hlo->users().empty() ? 0 : hlo->users().size() - 1;
467 int64_t logical_buffer_size = SumLogicalBufferSizes(
468 points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
469 stats.total_sizes = logical_buffer_size;
470 cumulative_total_size += logical_buffer_size;
471 absl::flat_hash_set<const HloInstruction*> unique_operands(
472 hlo->operands().begin(), hlo->operands().end());
473 for (const HloInstruction* operand : unique_operands) {
474 auto& operand_stats = stats_map.at(operand);
475 stats.extra_users += operand_stats.extra_users;
476 stats.total_sizes += operand_stats.total_sizes;
477 }
478 // stats.total_sizes transitively includes the sizes of all nodes that
479 // lead to it. But computation is a DAG, so we are double-counting nodes,
480 // which can lead to overflows for large programs.
481 // cumulative_total_size caps the size to prevent overflows.
482 // Same for total_hlos: it prevents overflows on very large and branchy
483 // models, where the number of paths is exponential to the number of nodes.
484 // NOTE(dimvar): this is quite ugly and should be changed. It's unclear
485 // why we care about transitive sizes; when scheduling a node, its input
486 // and output buffers should be all that matters, not its "history".
487 stats.total_sizes = std::min(stats.total_sizes, cumulative_total_size);
488 stats.extra_users = std::min(stats.extra_users, total_hlos);
489 }
490 CHECK_EQ(stats_map.size(), computation->instruction_count());
491
492 // Construct a total order based on DFS post-order, visiting operands in
493 // decreasing cumulative extra user order, and next by cumulative size, with a
494 // tiebreaker by name for determinism.
495 HloInstructionSequence sequence;
496 FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
497 sequence.push_back(hlo);
498 return OkStatus();
499 });
500 visitor.ReserveVisitStates(computation->instruction_count());
501 TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
502 &visitor, [&stats_map](const HloInstruction* a, const HloInstruction* b) {
503 auto& stats_a = stats_map.at(a);
504 auto& stats_b = stats_map.at(b);
505 if (stats_a.extra_users != stats_b.extra_users) {
506 return stats_a.extra_users > stats_b.extra_users;
507 }
508 if (stats_a.total_sizes != stats_b.total_sizes) {
509 return stats_a.total_sizes > stats_b.total_sizes;
510 }
511 return a->name() < b->name();
512 }));
513 if (postprocessor) {
514 sequence = postprocessor(sequence);
515 }
516 CHECK_EQ(sequence.size(), computation->instruction_count());
517 if (peak_memory) {
518 TF_ASSIGN_OR_RETURN(
519 *peak_memory, HeapSimulator::MinimumMemoryForComputation(
520 *computation, sequence, alias_analysis, size_function,
521 &memory_by_computation));
522 }
523 return sequence;
524 } // namespace xla
525
ComputationSchedulerToModuleScheduler(const MemorySchedulerAlgorithm & computation_scheduler,const MemorySchedulerPostprocessor & postprocessor)526 ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
527 const MemorySchedulerAlgorithm& computation_scheduler,
528 const MemorySchedulerPostprocessor& postprocessor) {
529 return [computation_scheduler, postprocessor](
530 const HloModule* module,
531 const TuplePointsToAnalysis& points_to_analysis,
532 const HloAliasAnalysis& alias_analysis,
533 const LogicalBuffer::SizeFunction& size_func,
534 const absl::flat_hash_set<absl::string_view>& execution_threads,
535 int64_t* peak_memory) -> StatusOr<HloSchedule> {
536 HloSchedule schedule(module);
537 absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
538 for (auto* computation :
539 module->MakeComputationPostOrder(execution_threads)) {
540 if (!computation->IsFusionComputation()) {
541 TF_ASSIGN_OR_RETURN(
542 HloInstructionSequence computation_sequence,
543 ScheduleComputationHelper(
544 computation, points_to_analysis, alias_analysis, size_func,
545 computation_scheduler, memory_by_computation, postprocessor,
546 /*peak_memory=*/nullptr));
547 schedule.set_sequence(computation, std::move(computation_sequence));
548 }
549 }
550 if (peak_memory) {
551 TF_ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule(
552 schedule, size_func));
553 }
554 return std::move(schedule);
555 };
556 }
557
ListMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64_t> & memory_by_computation,const MemorySchedulerPostprocessor & postprocessor,int64_t * peak_memory)558 StatusOr<HloInstructionSequence> ListMemoryScheduler(
559 HloComputation* computation,
560 const TuplePointsToAnalysis& points_to_analysis,
561 const HloAliasAnalysis& alias_analysis,
562 const BufferValue::SizeFunction& size_function,
563 const absl::flat_hash_map<const HloComputation*, int64_t>&
564 memory_by_computation,
565 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
566 TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence,
567 ListScheduler::Run(computation, points_to_analysis,
568 size_function, memory_by_computation));
569 if (postprocessor) {
570 sequence = postprocessor(sequence);
571 }
572 if (peak_memory) {
573 TF_ASSIGN_OR_RETURN(
574 *peak_memory, HeapSimulator::MinimumMemoryForComputation(
575 *computation, sequence, alias_analysis, size_function,
576 &memory_by_computation));
577 }
578 return sequence;
579 }
580
PostOrderMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64_t> & memory_by_computation,const MemorySchedulerPostprocessor & postprocessor,int64_t * peak_memory)581 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
582 HloComputation* computation,
583 const TuplePointsToAnalysis& points_to_analysis,
584 const HloAliasAnalysis& alias_analysis,
585 const BufferValue::SizeFunction& size_function,
586 const absl::flat_hash_map<const HloComputation*, int64_t>&
587 memory_by_computation,
588 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
589 HloInstructionSequence sequence(computation->MakeInstructionPostOrder());
590 if (postprocessor) {
591 sequence = postprocessor(sequence);
592 }
593 if (peak_memory) {
594 TF_ASSIGN_OR_RETURN(
595 *peak_memory, HeapSimulator::MinimumMemoryForComputation(
596 *computation, sequence, alias_analysis, size_function,
597 &memory_by_computation));
598 }
599 return sequence;
600 }
601
DefaultMemoryScheduler(HloComputation * computation,const TuplePointsToAnalysis & points_to_analysis,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_function,const absl::flat_hash_map<const HloComputation *,int64_t> & memory_by_computation,const MemorySchedulerPostprocessor & postprocessor,int64_t * peak_memory)602 StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
603 HloComputation* computation,
604 const TuplePointsToAnalysis& points_to_analysis,
605 const HloAliasAnalysis& alias_analysis,
606 const BufferValue::SizeFunction& size_function,
607 const absl::flat_hash_map<const HloComputation*, int64_t>&
608 memory_by_computation,
609 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
610 // We try a few schedulers and choose whichever returns a lower min-memory,
611 // not accounting for fragmentation.
612 // - List is a scheduler that uses greedy heuristics.
613 // - DFS visits HLOs in postorder, with a heuristic to decide the order of
614 // children.
615 // - Postorder does not use any heuristics.
616 // List wins for most of our benchmarks; postorder-based schedulers win for
617 // some RNNs.
618 int64_t list_memory;
619 TF_ASSIGN_OR_RETURN(
620 HloInstructionSequence list_sequence,
621 ListMemoryScheduler(computation, points_to_analysis, alias_analysis,
622 size_function, memory_by_computation, postprocessor,
623 &list_memory));
624 VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
625
626 int64_t dfs_memory;
627 TF_ASSIGN_OR_RETURN(
628 HloInstructionSequence dfs_sequence,
629 DFSMemoryScheduler(computation, points_to_analysis, alias_analysis,
630 size_function, memory_by_computation, postprocessor,
631 &dfs_memory));
632 VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
633
634 int64_t post_order_memory;
635 TF_ASSIGN_OR_RETURN(
636 HloInstructionSequence post_order_sequence,
637 PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis,
638 size_function, memory_by_computation,
639 postprocessor, &post_order_memory));
640 VLOG(2) << "Min-memory post order sequence: "
641 << HumanReadableNumBytes(post_order_memory);
642
643 auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
644 if (peak_memory) {
645 *peak_memory = min_memory;
646 }
647
648 if (min_memory == list_memory) {
649 VLOG(2) << "Chose min-memory list sequence: "
650 << HumanReadableNumBytes(list_memory);
651 return list_sequence;
652 } else if (min_memory == dfs_memory) {
653 VLOG(2) << "Chose min-memory dfs sequence: "
654 << HumanReadableNumBytes(dfs_memory);
655 return dfs_sequence;
656 } else {
657 VLOG(2) << "Chose min-memory post_order sequence: "
658 << HumanReadableNumBytes(post_order_memory);
659 return post_order_sequence;
660 }
661 }
662
DefaultModuleScheduler(const HloModule * module,const TuplePointsToAnalysis & points_to_analysis,const HloAliasAnalysis & alias_analysis,const BufferValue::SizeFunction & size_function,const absl::flat_hash_set<absl::string_view> & execution_threads,int64_t * peak_memory)663 StatusOr<HloSchedule> DefaultModuleScheduler(
664 const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
665 const HloAliasAnalysis& alias_analysis,
666 const BufferValue::SizeFunction& size_function,
667 const absl::flat_hash_set<absl::string_view>& execution_threads,
668 int64_t* peak_memory) {
669 // We try a few schedulers and choose whichever returns a lower min-memory,
670 // not accounting for fragmentation.
671 // - List is a scheduler that uses greedy heuristics.
672 // - DFS visits HLOs in postorder, with a heuristic to decide the order of
673 // children.
674 // - Postorder does not use any heuristics.
675 // List wins for most of our benchmarks; postorder-based schedulers win for
676 // some RNNs.
677 int64_t list_memory;
678 TF_ASSIGN_OR_RETURN(
679 HloSchedule list_sequence,
680 ComputationSchedulerToModuleScheduler(ListMemoryScheduler, {})(
681 module, points_to_analysis, alias_analysis, size_function,
682 execution_threads, &list_memory));
683
684 VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
685
686 int64_t dfs_memory;
687 TF_ASSIGN_OR_RETURN(
688 HloSchedule dfs_sequence,
689 ComputationSchedulerToModuleScheduler(DFSMemoryScheduler, {})(
690 module, points_to_analysis, alias_analysis, size_function,
691 execution_threads, &dfs_memory));
692 VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
693
694 int64_t post_order_memory;
695 TF_ASSIGN_OR_RETURN(
696 HloSchedule post_order_sequence,
697 ComputationSchedulerToModuleScheduler(PostOrderMemoryScheduler, {})(
698 module, points_to_analysis, alias_analysis, size_function,
699 execution_threads, &post_order_memory));
700 VLOG(2) << "Min-memory post order sequence: "
701 << HumanReadableNumBytes(post_order_memory);
702
703 auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
704 if (peak_memory) {
705 *peak_memory = min_memory;
706 }
707
708 if (min_memory == list_memory) {
709 VLOG(2) << "Chose min-memory list sequence: "
710 << HumanReadableNumBytes(list_memory);
711 return list_sequence;
712 } else if (min_memory == dfs_memory) {
713 VLOG(2) << "Chose min-memory dfs sequence: "
714 << HumanReadableNumBytes(dfs_memory);
715 return dfs_sequence;
716 } else {
717 VLOG(2) << "Chose min-memory post_order sequence: "
718 << HumanReadableNumBytes(post_order_memory);
719 return post_order_sequence;
720 }
721 }
722
ScheduleModule(const HloModule * module,const BufferValue::SizeFunction & size_function,const ModuleSchedulerAlgorithm & algorithm,const absl::flat_hash_set<absl::string_view> & execution_threads,int64_t * peak_memory)723 StatusOr<HloSchedule> ScheduleModule(
724 const HloModule* module, const BufferValue::SizeFunction& size_function,
725 const ModuleSchedulerAlgorithm& algorithm,
726 const absl::flat_hash_set<absl::string_view>& execution_threads,
727 int64_t* peak_memory) {
728 TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
729 TuplePointsToAnalysis::Run(module));
730 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
731 HloAliasAnalysis::Run(module));
732
733 TF_ASSIGN_OR_RETURN(HloSchedule schedule,
734 (algorithm ? algorithm : DefaultModuleScheduler)(
735 module, *points_to_analysis, *alias_analysis,
736 size_function, execution_threads, peak_memory));
737
738 TF_RETURN_IF_ERROR(schedule.Verify());
739
740 return std::move(schedule);
741 }
742
ScheduleComputation(HloComputation * computation,const BufferValue::SizeFunction & size_function,const MemorySchedulerPostprocessor & postprocessor)743 StatusOr<HloInstructionSequence> ScheduleComputation(
744 HloComputation* computation, const BufferValue::SizeFunction& size_function,
745 const MemorySchedulerPostprocessor& postprocessor) {
746 CHECK(!computation->IsFusionComputation());
747 TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
748 TuplePointsToAnalysis::Run(computation->parent()));
749 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
750 HloAliasAnalysis::Run(computation->parent()));
751 absl::flat_hash_map<const HloComputation*, int64_t> empty_map;
752 return ScheduleComputationHelper(
753 computation, *points_to_analysis, *alias_analysis, size_function,
754 /*algorithm=*/nullptr, empty_map, postprocessor,
755 /*peak_memory=*/nullptr);
756 }
757
HloMemoryScheduler(const BufferValue::SizeFunction & size_function,const ModuleSchedulerAlgorithm & algorithm)758 HloMemoryScheduler::HloMemoryScheduler(
759 const BufferValue::SizeFunction& size_function,
760 const ModuleSchedulerAlgorithm& algorithm)
761 : size_function_(size_function), algorithm_(algorithm) {}
762
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)763 StatusOr<bool> HloMemoryScheduler::Run(
764 HloModule* module,
765 const absl::flat_hash_set<absl::string_view>& execution_threads) {
766 TF_ASSIGN_OR_RETURN(
767 HloSchedule schedule,
768 ScheduleModule(module, size_function_, algorithm_, execution_threads));
769 TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
770 return true;
771 }
772
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)773 StatusOr<bool> HloTrivialScheduler::Run(
774 HloModule* module,
775 const absl::flat_hash_set<absl::string_view>& execution_threads) {
776 HloSchedule schedule(module);
777 for (HloComputation* computation :
778 module->MakeComputationPostOrder(execution_threads)) {
779 if (!computation->IsFusionComputation()) {
780 HloInstructionSequence& computation_sequence =
781 schedule.GetOrCreateSequence(computation);
782 FunctionVisitor visitor(
783 [&computation_sequence](HloInstruction* instruction) {
784 computation_sequence.push_back(instruction);
785 return OkStatus();
786 });
787 visitor.ReserveVisitStates(computation->instruction_count());
788 TF_RETURN_IF_ERROR(computation->Accept(&visitor));
789 }
790 }
791 TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
792 return true;
793 }
794
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)795 StatusOr<bool> HloDescheduler::Run(
796 HloModule* module,
797 const absl::flat_hash_set<absl::string_view>& execution_threads) {
798 bool changed = module->has_schedule();
799 module->clear_schedule();
800 return changed;
801 }
802
803 } // namespace xla
804