xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_memory_scheduler.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
18 
#include <cstdint>
#include <functional>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
32 
33 namespace xla {
34 
35 // Postprocessor of the HloInstructionSequence. This is an opt-in postprocessing
36 // function to MemorySchedulerAlgorithm to enforce certain hlo schedule
37 // constraints desired for custom-calls.
38 using MemorySchedulerPostprocessor =
39     std::function<HloInstructionSequence(const HloInstructionSequence&)>;
40 
41 // A memory scheduler computes an execution sequence for the HLO instructions in
42 // 'computation' that minimizes peak memory, given a points-to analysis result
43 // that describes buffer aliasing, together with a target-specific size function
44 // that maps a tensor's logical size to its padded size. peak_memory (may be
45 // nullptr) is set to the peak memory of the resulting schedule according to the
46 // HeapSimulator.
47 //
48 // TODO(yunxing): Cleanup usage of TuplePointsToAnalysis.
49 typedef std::function<StatusOr<HloInstructionSequence>(
50     HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
51     const LogicalBuffer::SizeFunction&,
52     const absl::flat_hash_map<const HloComputation*, int64_t>&,
53     const MemorySchedulerPostprocessor&,
54     /*peak_memory*/ int64_t*)>
55     MemorySchedulerAlgorithm;
56 
57 // Scheduler for the entire module.
58 typedef std::function<StatusOr<HloSchedule>(
59     const HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&,
60     const LogicalBuffer::SizeFunction&,
61     const absl::flat_hash_set<absl::string_view>& execution_threads,
62     /*peak_memory*/ int64_t*)>
63     ModuleSchedulerAlgorithm;
64 
65 // Lift a computation scheduler into a module scheduler by calling the
66 // computation scheduler on all computations in a module.
67 ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler(
68     const MemorySchedulerAlgorithm&, const MemorySchedulerPostprocessor& = {});
69 
70 // List scheduler
71 StatusOr<HloInstructionSequence> ListMemoryScheduler(
72     HloComputation* computation,
73     const TuplePointsToAnalysis& points_to_analysis,
74     const HloAliasAnalysis& alias_analysis,
75     const LogicalBuffer::SizeFunction& size_function,
76     const absl::flat_hash_map<const HloComputation*, int64_t>&
77         memory_by_computation,
78     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
79 
80 // DFS-order scheduler
81 StatusOr<HloInstructionSequence> DFSMemoryScheduler(
82     HloComputation* computation,
83     const TuplePointsToAnalysis& points_to_analysis,
84     const HloAliasAnalysis& alias_analysis,
85     const LogicalBuffer::SizeFunction& size_function,
86     const absl::flat_hash_map<const HloComputation*, int64_t>&
87         memory_by_computation,
88     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
89 
90 // Naive Post Order scheduler
91 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
92     HloComputation* computation,
93     const TuplePointsToAnalysis& points_to_analysis,
94     const HloAliasAnalysis& alias_analysis,
95     const LogicalBuffer::SizeFunction& size_function,
96     const absl::flat_hash_map<const HloComputation*, int64_t>&
97         memory_by_computation,
98     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
99 
100 // The default scheduling algorithm. Runs the list scheduler, the DFS scheduler,
101 // and the post-order scheduler and chooses whichever returns a lower min-
102 // memory, not accounting for fragmentation. peak_memory (may be nullptr) is set
103 // to the peak memory of the resulting schedule according to the HeapSimulator.
104 StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
105     HloComputation* computation,
106     const TuplePointsToAnalysis& points_to_analysis,
107     const HloAliasAnalysis& alias_analysis,
108     const LogicalBuffer::SizeFunction& size_function,
109     const absl::flat_hash_map<const HloComputation*, int64_t>&
110         memory_by_computation,
111     const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory);
112 
113 StatusOr<HloSchedule> DefaultModuleScheduler(
114     const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
115     const HloAliasAnalysis& alias_analysis,
116     const LogicalBuffer::SizeFunction& size_function,
117     const absl::flat_hash_set<absl::string_view>& execution_threads,
118     int64_t* peak_memory);
119 
120 // Returns an HloSchedule which seeks to minimize the memory required for the
121 // module. size_function is the function returning the number of bytes required
122 // for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak
123 // memory (according to the HeapSimulator) of all computations in the module.
124 StatusOr<HloSchedule> ScheduleModule(
125     const HloModule* module, const LogicalBuffer::SizeFunction& size_function,
126     const ModuleSchedulerAlgorithm& algorithm = {},
127     const absl::flat_hash_set<absl::string_view>& execution_threads = {},
128     int64_t* peak_memory = nullptr);
129 
130 // Computes the schedule for a single computation.
131 // Currently only used by the GPU backend.
132 StatusOr<HloInstructionSequence> ScheduleComputation(
133     HloComputation* computation,
134     const LogicalBuffer::SizeFunction& size_function,
135     const MemorySchedulerPostprocessor& postprocessor);
136 
137 // A pass which schedules the HLO instructions in a module. The HloModule's
138 // schedule field is set to the resulting HloSchedule using
139 // HloModule::set_schedule.
140 class HloMemoryScheduler : public HloModulePass {
141  public:
142   // size_function is the function returning the number of bytes required for a
143   // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
144   // specified, then DefaultMemoryScheduler is used.
145   HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
146                      const ModuleSchedulerAlgorithm& algorithm = {});
147 
148   ~HloMemoryScheduler() override = default;
149 
name()150   absl::string_view name() const override { return "hlo-memory-scheduler"; }
151 
152   using HloPassInterface::Run;
153   StatusOr<bool> Run(
154       HloModule* module,
155       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
156 
157  private:
158   LogicalBuffer::SizeFunction size_function_;
159 
160   ModuleSchedulerAlgorithm algorithm_;
161 };
162 
163 // A pass which produces a naive, but correct schedule. The schedule is produced
164 // using a DFS traversal of the graph with no attempt to minimize memory use.
165 class HloTrivialScheduler : public HloModulePass {
166  public:
name()167   absl::string_view name() const override { return "hlo-trivial-scheduler"; }
168 
169   using HloPassInterface::Run;
170   StatusOr<bool> Run(
171       HloModule* module,
172       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
173 };
174 
175 // A trivial pass which clears the schedule currently set on the
176 // HloModule. After this pass runs HloModule::has_schedule will return false.
177 class HloDescheduler : public HloModulePass {
178  public:
179   HloDescheduler() = default;
180   ~HloDescheduler() override = default;
name()181   absl::string_view name() const override { return "hlo-descheduler"; }
182 
183   using HloPassInterface::Run;
184   StatusOr<bool> Run(
185       HloModule* module,
186       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
187 };
188 
189 }  // namespace xla
190 
191 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
192