/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"

#include <algorithm>
#include <climits>  // for INT_MAX, INT_MIN
#include <limits>
#include <map>
#include <queue>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// Class implementing a list scheduler of HLO instructions which produces a
// sequence that minimizes memory usage by preferring to schedule nodes that
// free larger buffers and define smaller outputs.
//
// Note that the list scheduler is a greedy algorithm and cannot guarantee a
// globally optimal solution. As a counterexample, consider the following
// graph:
//
//      +--> B ===> C -------+
// A -> |                    |
//      |                    v
//      +--> D ---> F=======>G
//      |           ^
//      |           |
//      +--> E -----+
//
//  --> : Buffer with size 1
//  ==> : Buffer with size 2
//
// The list scheduler will always try to defer scheduling B in a greedy way
// since its output buffer is bigger than its input. The sequence it creates
// will be:
//   A D E F B C G
// , which has a maximum memory usage of 6 (B is alive while F is executing).
//
// An optimal way to schedule the previous graph is:
//   A B C D E F G
// , which has a maximum memory usage of 5 (when F is executing).
//
class ListScheduler {
 public:
  // Constructs and returns a memory-minimizing sequence of the HLO
  // instructions in the given HLO computation.
  static StatusOr<HloInstructionSequence> Run(
      HloComputation* computation,
      const TuplePointsToAnalysis& points_to_analysis,
      const BufferValue::SizeFunction& size_function,
      const absl::flat_hash_map<const HloComputation*, int64_t>&
          memory_by_computation) {
    ListScheduler scheduler(computation, points_to_analysis, size_function,
                            memory_by_computation);
    return scheduler.CreateSchedule();
  }

  // Returns whether the memory used by the given HLO should be ignored by the
  // scheduling heuristic.
  static bool IgnoreInstruction(const HloInstruction& instruction) {
    return instruction.opcode() == HloOpcode::kParameter ||
           instruction.opcode() == HloOpcode::kConstant;
  }

 private:
  // The scheduling priority of an instruction is determined first by the
  // number of bytes freed by scheduling it, and second (as a tie-breaker) by
  // its number of users. This is represented as a std::pair containing these
  // two values (first element is the bytes freed). std::pair provides the
  // necessary comparison operators.
  using Priority = std::pair<int64_t, int64_t>;
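  // For example (illustrative values only): Priority{16, 3} compares greater
  // than Priority{16, 2}, which compares greater than Priority{8, 100};
  // bytes freed dominates, and the user count only breaks ties.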

  ListScheduler(HloComputation* computation,
                const TuplePointsToAnalysis& points_to_analysis,
                const BufferValue::SizeFunction& size_function,
                const absl::flat_hash_map<const HloComputation*, int64_t>&
                    memory_by_computation)
      : computation_(computation),
        points_to_analysis_(points_to_analysis),
        size_function_(size_function),
        memory_by_computation_(memory_by_computation) {
    // Create a map containing the LogicalBuffer uses for each HLO
    // instruction. An HLO instruction "uses" a LogicalBuffer if the
    // LogicalBuffer is in an operand of the instruction as indicated by
    // points-to analysis.
    for (auto* instruction : computation->instructions()) {
      absl::flat_hash_set<const LogicalBuffer*> instr_uses;
      for (auto* operand : instruction->operands()) {
        points_to_analysis.GetPointsToSet(operand).ForEachElement(
            [&](const ShapeIndex& /*index*/,
                const PointsToSet::BufferList& buffers) {
              instr_uses.insert(buffers.begin(), buffers.end());
            });
      }
      buffer_uses_[instruction] = std::vector<const LogicalBuffer*>(
          instr_uses.begin(), instr_uses.end());
    }

    // Create a map containing the number of unscheduled uses (HLO
    // instructions) of each logical buffer.
    unscheduled_use_count_.reserve(points_to_analysis.num_logical_buffers());
    for (auto* instruction : computation->instructions()) {
      for (auto* buffer :
           points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
        unscheduled_use_count_[buffer] = 0;
      }
    }
    for (auto* instruction : computation->instructions()) {
      for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) {
        ++unscheduled_use_count_[buffer];
      }
    }

    // Buffers live out of the computation have an implicit use at the end of
    // the computation.
    for (const LogicalBuffer* live_out_buffer :
         points_to_analysis.GetPointsToSet(computation->root_instruction())
             .CreateFlattenedSet()) {
      ++unscheduled_use_count_[live_out_buffer];
    }
  }

  // Returns whether the memory used by the given buffer should be ignored by
  // the scheduling heuristic.
  static bool IgnoreBuffer(const LogicalBuffer& buffer) {
    return IgnoreInstruction(*buffer.instruction());
  }

  // An entry in the worklist used by CreateSchedule.  Corresponds to one
  // HloInstruction, plus some cached metadata, saved for the purposes of
  // making BytesFreedIfScheduled fast.
  struct ReadyListEntry {
    HloInstruction* instruction;

    // The total size of all buffers defined by this instruction.
    int64_t bytes_defined;

    // For each buffer B used by this instruction, we keep a pair (B, U), where
    // U is the number of uses of B that have not yet been scheduled. This pair
    // is a pointer into the unscheduled_use_count_ map, so it gets updated for
    // free when we update counts in the map.
    std::vector<const std::pair<const LogicalBuffer* const, int64_t>*>
        used_buffer_unscheduled_use_counts;
  };

  // Creates a ReadyListEntry for the given instruction.
  ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) {
    ReadyListEntry entry;
    entry.instruction = instruction;

    entry.bytes_defined = 0;
    for (auto* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (!IgnoreBuffer(*buffer)) {
        entry.bytes_defined += size_function_(*buffer);
      }
    }

    for (auto* buffer : buffer_uses_.at(instruction)) {
      if (IgnoreBuffer(*buffer)) {
        continue;
      }
      auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer);
      CHECK(unscheduled_use_count_it != unscheduled_use_count_.end());
      entry.used_buffer_unscheduled_use_counts.push_back(
          &*unscheduled_use_count_it);
    }
    return entry;
  }

  // Returns the number of bytes freed *after* the HLO instruction finishes.
  // The current List algorithm only considers two states for an instruction:
  // right before it runs, and after it finishes. We don't represent memory
  // usage during the execution of an instruction. But if the instruction calls
  // subcomputations, they are only live during the instruction's execution.
  // We end up counting the memory used by subcomputations as memory "defined"
  // by the instruction. This is not entirely accurate, but it is more accurate
  // than not taking subcomputations into account at all. In the future, we may
  // improve accounting for subcomputation memory (b/65409243).
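  //
  // For example (hypothetical sizes): if scheduling an instruction retires
  // the last unscheduled use of a 16-byte buffer while defining an 8-byte
  // output, this returns 16 - 8 = 8.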
  int64_t BytesFreedIfScheduled(const ReadyListEntry& entry) {
    auto instruction = entry.instruction;
    auto opcode = instruction->opcode();

    // Scheduling the outfeed early and the infeed late gives the
    // communicating processor more time to do its work.
    if (opcode == HloOpcode::kOutfeed &&
        !instruction->outfeed_config().empty()) {
      return INT_MAX;
    }
    if (opcode == HloOpcode::kInfeed && !instruction->infeed_config().empty()) {
      return INT_MIN;
    }

    int64_t freed_bytes = 0;
    for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
      auto buffer = kv->first;
      auto use_count = kv->second;
      if (use_count == 1) {
        freed_bytes += size_function_(*buffer);
      }
    }
    // We only count the memory usage of the largest subcomputation, instead of
    // adding them all, because subcomputations won't execute in parallel.
    int64_t max_subcomputation_bytes = 0;
    for (const auto* c : instruction->called_computations()) {
      auto it = memory_by_computation_.find(c);
      if (it != memory_by_computation_.end()) {
        int64_t subcomputation_bytes = it->second;
        if (subcomputation_bytes > max_subcomputation_bytes) {
          max_subcomputation_bytes = subcomputation_bytes;
        }
      }
    }
    int64_t bytes_defined;
    if (max_subcomputation_bytes > 0 &&
        (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
         opcode == HloOpcode::kConditional)) {
      // The output buffer of while/call/conditional is always aliased with the
      // output buffer of the root instruction in the body. Don't double count.
      bytes_defined = max_subcomputation_bytes;
    } else {
      bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
    }
    return freed_bytes - bytes_defined;
  }

  // Constructs the scheduling priority of the given instruction.
  Priority GetPriority(const ReadyListEntry& entry) {
    // Try to cluster scalars as close together as possible so that if they are
    // in unfused hlos, they can still live in machine registers without
    // excessive spilling.
    if (ShapeUtil::IsEffectiveScalar(entry.instruction->shape())) {
      return {std::numeric_limits<int64_t>::max(),
              std::numeric_limits<int64_t>::max()};
    }
    return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
  }

  HloInstructionSequence CreateSchedule() {
    HloInstructionSequence schedule;

    // Populate the ready list with instructions which have no operands or
    // control predecessors.
    absl::flat_hash_map<const HloInstruction*, int64_t> unscheduled_pred_count;
    for (auto* instruction : computation_->instructions()) {
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : instruction->users()) {
        unscheduled_pred_count[user]++;
      }
      for (HloInstruction* succ : instruction->control_successors()) {
        unscheduled_pred_count[succ]++;
      }
    }

    // Use a multimap to sort ReadyListEntry according to their priority.
    std::multimap<Priority, ReadyListEntry> ready_queue;
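    // Because a multimap keeps its keys in ascending order, the
    // highest-priority entry is always the last element, which is what the
    // scheduling loop below pops via --ready_queue.end().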

    // Map of ready instructions to their iterators in ready_queue.
    absl::flat_hash_map<const HloInstruction*,
                        std::multimap<Priority, ReadyListEntry>::iterator>
        ready_instructions;

    auto add_to_ready_queue = [&](HloInstruction* inst) {
      auto entry = MakeReadyListEntry(inst);
      auto it = ready_queue.emplace(GetPriority(entry), std::move(entry));
      ready_instructions[inst] = it;
    };

    for (auto* instruction : computation_->instructions()) {
      if (instruction->operands().empty() &&
          instruction->control_predecessors().empty()) {
        add_to_ready_queue(instruction);
      }
    }

    while (!ready_queue.empty()) {
      // Remove the selected instruction from the ready list and add it to the
      // schedule.
      auto best_it = ready_queue.end();
      --best_it;
      HloInstruction* best = best_it->second.instruction;
      VLOG(2) << "Schedule instruction: " << best->ToShortString()
              << " Bytes freed: " << best_it->first.first;
      ready_queue.erase(best_it);
      ready_instructions.erase(best);
      schedule.push_back(best);
      scheduled_instructions_.insert(best);

      bool adjust_ready_queue = false;
      // Update the unscheduled uses of the logical buffers.
      for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
        int64_t& count = unscheduled_use_count_[buffer];
        CHECK_GT(count, 0);
        --count;
        if (count == 1) {
          adjust_ready_queue = true;
        }
      }

      // Add new instructions to the ready list.
      auto update_pred_count = [&](HloInstruction* inst) {
        int64_t pred_count = --unscheduled_pred_count.at(inst);
        CHECK_GE(pred_count, 0);
        if (pred_count == 0) {
          add_to_ready_queue(inst);
        }
      };
      // TODO(b/34466113): Replace this and above with successors() or
      // predecessors() when these methods are added to HloInstruction.
      for (HloInstruction* user : best->users()) {
        update_pred_count(user);
      }
      for (HloInstruction* succ : best->control_successors()) {
        update_pred_count(succ);
      }
      // The unscheduled use count for a buffer has changed to 1, so the
      // priorities of some ready instructions may go up. We update them in the
      // ready queue, so that they can appear earlier.
      if (adjust_ready_queue) {
        for (HloInstruction* operand : best->operands()) {
          for (HloInstruction* operand_user : operand->users()) {
            auto ready_instructions_it = ready_instructions.find(operand_user);
            if (ready_instructions_it == ready_instructions.end()) {
              continue;
            }
            auto ready_queue_it = ready_instructions_it->second;
            auto& entry = ready_queue_it->second;
            Priority new_priority = GetPriority(entry);
            if (new_priority == ready_queue_it->first) {
              continue;
            }
            // Create a new entry in ready_queue, then update
            // ready_instructions[operand_user] to refer to the new entry.
            ready_instructions_it->second =
                ready_queue.emplace(new_priority, std::move(entry));
            // Remove the old entry in ready_queue.
            ready_queue.erase(ready_queue_it);
          }
        }
      }
    }
    CHECK_EQ(schedule.size(), computation_->instruction_count());
    CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count());

    return schedule;
  }

  HloComputation* computation_;
  const TuplePointsToAnalysis& points_to_analysis_;
  const BufferValue::SizeFunction& size_function_;
  // Computations are analyzed in post-order. When scheduling an instruction
  // that includes subcomputations, such as a while loop, we use this map to
  // look up the memory needed by subcomputations.
  const absl::flat_hash_map<const HloComputation*, int64_t>&
      memory_by_computation_;

  // A map containing the LogicalBuffers that each instruction uses.
  absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
      buffer_uses_;

  // A map containing the count of unscheduled HLOs that use a particular
  // LogicalBuffer.
  absl::flat_hash_map<const LogicalBuffer*, int64_t> unscheduled_use_count_;

  // Set of instructions which have been scheduled.
  absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};

int64_t SumLogicalBufferSizes(
    const TuplePointsToAnalysis::BufferDefinitionVector& buffers,
    const BufferValue::SizeFunction& size_function) {
  int64_t size = 0;
  for (const LogicalBuffer* buffer : buffers) {
    size += size_function(*buffer);
  }
  return size;
}

StatusOr<HloInstructionSequence> ScheduleComputationHelper(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const MemorySchedulerAlgorithm& algorithm,
    const absl::flat_hash_map<const HloComputation*, int64_t>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
  VLOG(2) << "Computation: " << computation->name();

  if (algorithm) {
    return algorithm(computation, points_to_analysis, alias_analysis,
                     size_function, memory_by_computation, postprocessor,
                     peak_memory);
  }
  return DefaultMemoryScheduler(computation, points_to_analysis, alias_analysis,
                                size_function, memory_by_computation,
                                postprocessor, peak_memory);
}

}  // namespace

StatusOr<HloInstructionSequence> DFSMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64_t>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
  // These variables are a hack to prevent overflows.
  int64_t cumulative_total_size = 0;
  int64_t total_hlos = computation->instruction_count();
  struct Stats {
    // Transitively includes the count of all nodes that lead to it.
    int64_t extra_users = 0;
    // Transitively includes the sizes of all nodes that lead to it.
    int64_t total_sizes = 0;
  };
  absl::flat_hash_map<const HloInstruction*, Stats> stats_map;
  stats_map.reserve(computation->instruction_count());

  for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
    auto& stats = stats_map[hlo];
    if (ListScheduler::IgnoreInstruction(*hlo)) {
      continue;
    }
    // This ordering is based on DFS post-order, with a heuristic to decide
    // which operand to visit first.  The heuristic is based on 'extra_users',
    // which is simply users-1 for each instruction.  By subtracting 1, we're
    // saying that instructions with no users or a single user don't count;
    // instructions with lots of fan-out will be visited earlier.
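    // For example, an instruction with three users contributes two extra
    // users of its own, on top of whatever its operands accumulated.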
    stats.extra_users = hlo->users().empty() ? 0 : hlo->users().size() - 1;
    int64_t logical_buffer_size = SumLogicalBufferSizes(
        points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
    stats.total_sizes = logical_buffer_size;
    cumulative_total_size += logical_buffer_size;
    absl::flat_hash_set<const HloInstruction*> unique_operands(
        hlo->operands().begin(), hlo->operands().end());
    for (const HloInstruction* operand : unique_operands) {
      auto& operand_stats = stats_map.at(operand);
      stats.extra_users += operand_stats.extra_users;
      stats.total_sizes += operand_stats.total_sizes;
    }
    // stats.total_sizes transitively includes the sizes of all nodes that
    // lead to it. But the computation is a DAG, so we are double-counting
    // nodes, which can lead to overflows for large programs.
    // cumulative_total_size caps the size to prevent overflows.
    // Same for total_hlos: it prevents overflows on very large and branchy
    // models, where the number of paths is exponential in the number of nodes.
    // NOTE(dimvar): this is quite ugly and should be changed. It's unclear
    // why we care about transitive sizes; when scheduling a node, its input
    // and output buffers should be all that matters, not its "history".
    stats.total_sizes = std::min(stats.total_sizes, cumulative_total_size);
    stats.extra_users = std::min(stats.extra_users, total_hlos);
  }
  CHECK_EQ(stats_map.size(), computation->instruction_count());

  // Construct a total order based on DFS post-order, visiting operands in
  // decreasing cumulative extra user order, and next by cumulative size, with a
  // tiebreaker by name for determinism.
  HloInstructionSequence sequence;
  FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
    sequence.push_back(hlo);
    return OkStatus();
  });
  visitor.ReserveVisitStates(computation->instruction_count());
  TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder(
      &visitor, [&stats_map](const HloInstruction* a, const HloInstruction* b) {
        auto& stats_a = stats_map.at(a);
        auto& stats_b = stats_map.at(b);
        if (stats_a.extra_users != stats_b.extra_users) {
          return stats_a.extra_users > stats_b.extra_users;
        }
        if (stats_a.total_sizes != stats_b.total_sizes) {
          return stats_a.total_sizes > stats_b.total_sizes;
        }
        return a->name() < b->name();
      }));
  if (postprocessor) {
    sequence = postprocessor(sequence);
  }
  CHECK_EQ(sequence.size(), computation->instruction_count());
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis, size_function,
                          &memory_by_computation));
  }
  return sequence;
}

ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
    const MemorySchedulerAlgorithm& computation_scheduler,
    const MemorySchedulerPostprocessor& postprocessor) {
  return [computation_scheduler, postprocessor](
             const HloModule* module,
             const TuplePointsToAnalysis& points_to_analysis,
             const HloAliasAnalysis& alias_analysis,
             const LogicalBuffer::SizeFunction& size_func,
             const absl::flat_hash_set<absl::string_view>& execution_threads,
             int64_t* peak_memory) -> StatusOr<HloSchedule> {
    HloSchedule schedule(module);
    absl::flat_hash_map<const HloComputation*, int64_t> memory_by_computation;
    for (auto* computation :
         module->MakeComputationPostOrder(execution_threads)) {
      if (!computation->IsFusionComputation()) {
        TF_ASSIGN_OR_RETURN(
            HloInstructionSequence computation_sequence,
            ScheduleComputationHelper(
                computation, points_to_analysis, alias_analysis, size_func,
                computation_scheduler, memory_by_computation, postprocessor,
                /*peak_memory=*/nullptr));
        schedule.set_sequence(computation, std::move(computation_sequence));
      }
    }
    if (peak_memory) {
      TF_ASSIGN_OR_RETURN(*peak_memory, HeapSimulator::MinimumMemoryForModule(
                                            schedule, size_func));
    }
    return std::move(schedule);
  };
}

StatusOr<HloInstructionSequence> ListMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64_t>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
  TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence,
                      ListScheduler::Run(computation, points_to_analysis,
                                         size_function, memory_by_computation));
  if (postprocessor) {
    sequence = postprocessor(sequence);
  }
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis, size_function,
                          &memory_by_computation));
  }
  return sequence;
}

StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64_t>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
  HloInstructionSequence sequence(computation->MakeInstructionPostOrder());
  if (postprocessor) {
    sequence = postprocessor(sequence);
  }
  if (peak_memory) {
    TF_ASSIGN_OR_RETURN(
        *peak_memory, HeapSimulator::MinimumMemoryForComputation(
                          *computation, sequence, alias_analysis, size_function,
                          &memory_by_computation));
  }
  return sequence;
}

StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
    HloComputation* computation,
    const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_map<const HloComputation*, int64_t>&
        memory_by_computation,
    const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) {
  // We try a few schedulers and choose whichever returns the lowest
  // min-memory, not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  int64_t list_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence list_sequence,
      ListMemoryScheduler(computation, points_to_analysis, alias_analysis,
                          size_function, memory_by_computation, postprocessor,
                          &list_memory));
  VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);

  int64_t dfs_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence dfs_sequence,
      DFSMemoryScheduler(computation, points_to_analysis, alias_analysis,
                         size_function, memory_by_computation, postprocessor,
                         &dfs_memory));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  int64_t post_order_memory;
  TF_ASSIGN_OR_RETURN(
      HloInstructionSequence post_order_sequence,
      PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis,
                               size_function, memory_by_computation,
                               postprocessor, &post_order_memory));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
  if (peak_memory) {
    *peak_memory = min_memory;
  }

  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}

StatusOr<HloSchedule> DefaultModuleScheduler(
    const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
    const HloAliasAnalysis& alias_analysis,
    const BufferValue::SizeFunction& size_function,
    const absl::flat_hash_set<absl::string_view>& execution_threads,
    int64_t* peak_memory) {
  // We try a few schedulers and choose whichever returns the lowest
  // min-memory, not accounting for fragmentation.
  // - List is a scheduler that uses greedy heuristics.
  // - DFS visits HLOs in postorder, with a heuristic to decide the order of
  //   children.
  // - Postorder does not use any heuristics.
  // List wins for most of our benchmarks; postorder-based schedulers win for
  // some RNNs.
  int64_t list_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule list_sequence,
      ComputationSchedulerToModuleScheduler(ListMemoryScheduler, {})(
          module, points_to_analysis, alias_analysis, size_function,
          execution_threads, &list_memory));

  VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);

  int64_t dfs_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule dfs_sequence,
      ComputationSchedulerToModuleScheduler(DFSMemoryScheduler, {})(
          module, points_to_analysis, alias_analysis, size_function,
          execution_threads, &dfs_memory));
  VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);

  int64_t post_order_memory;
  TF_ASSIGN_OR_RETURN(
      HloSchedule post_order_sequence,
      ComputationSchedulerToModuleScheduler(PostOrderMemoryScheduler, {})(
          module, points_to_analysis, alias_analysis, size_function,
          execution_threads, &post_order_memory));
  VLOG(2) << "Min-memory post order sequence: "
          << HumanReadableNumBytes(post_order_memory);

  auto min_memory = std::min({dfs_memory, post_order_memory, list_memory});
  if (peak_memory) {
    *peak_memory = min_memory;
  }

  if (min_memory == list_memory) {
    VLOG(2) << "Chose min-memory list sequence: "
            << HumanReadableNumBytes(list_memory);
    return list_sequence;
  } else if (min_memory == dfs_memory) {
    VLOG(2) << "Chose min-memory dfs sequence: "
            << HumanReadableNumBytes(dfs_memory);
    return dfs_sequence;
  } else {
    VLOG(2) << "Chose min-memory post_order sequence: "
            << HumanReadableNumBytes(post_order_memory);
    return post_order_sequence;
  }
}

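// A hypothetical usage sketch (illustrative only; the lambda below is an
// assumed example size function, not part of this file): schedule a module
// with the default scheduler, assuming 8-byte pointers for tuple elements.
//
//   TF_ASSIGN_OR_RETURN(
//       HloSchedule schedule,
//       ScheduleModule(module, [](const BufferValue& buffer) {
//         return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
//       }));
//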
StatusOr<HloSchedule> ScheduleModule(
    const HloModule* module, const BufferValue::SizeFunction& size_function,
    const ModuleSchedulerAlgorithm& algorithm,
    const absl::flat_hash_set<absl::string_view>& execution_threads,
    int64_t* peak_memory) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module));

  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      (algorithm ? algorithm : DefaultModuleScheduler)(
                          module, *points_to_analysis, *alias_analysis,
                          size_function, execution_threads, peak_memory));

  TF_RETURN_IF_ERROR(schedule.Verify());

  return std::move(schedule);
}

StatusOr<HloInstructionSequence> ScheduleComputation(
    HloComputation* computation, const BufferValue::SizeFunction& size_function,
    const MemorySchedulerPostprocessor& postprocessor) {
  CHECK(!computation->IsFusionComputation());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                      TuplePointsToAnalysis::Run(computation->parent()));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(computation->parent()));
  absl::flat_hash_map<const HloComputation*, int64_t> empty_map;
  return ScheduleComputationHelper(
      computation, *points_to_analysis, *alias_analysis, size_function,
      /*algorithm=*/nullptr, empty_map, postprocessor,
      /*peak_memory=*/nullptr);
}

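// Hypothetical pipeline usage of this pass (a sketch; size_fn stands for any
// BufferValue::SizeFunction supplied by the caller):
//
//   HloPassPipeline pipeline("scheduling");
//   pipeline.AddPass<HloMemoryScheduler>(size_fn);
//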
HloMemoryScheduler::HloMemoryScheduler(
    const BufferValue::SizeFunction& size_function,
    const ModuleSchedulerAlgorithm& algorithm)
    : size_function_(size_function), algorithm_(algorithm) {}

StatusOr<bool> HloMemoryScheduler::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  TF_ASSIGN_OR_RETURN(
      HloSchedule schedule,
      ScheduleModule(module, size_function_, algorithm_, execution_threads));
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}

StatusOr<bool> HloTrivialScheduler::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  HloSchedule schedule(module);
  for (HloComputation* computation :
       module->MakeComputationPostOrder(execution_threads)) {
    if (!computation->IsFusionComputation()) {
      HloInstructionSequence& computation_sequence =
          schedule.GetOrCreateSequence(computation);
      FunctionVisitor visitor(
          [&computation_sequence](HloInstruction* instruction) {
            computation_sequence.push_back(instruction);
            return OkStatus();
          });
      visitor.ReserveVisitStates(computation->instruction_count());
      TF_RETURN_IF_ERROR(computation->Accept(&visitor));
    }
  }
  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  return true;
}

StatusOr<bool> HloDescheduler::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = module->has_schedule();
  module->clear_schedule();
  return changed;
}

}  // namespace xla