/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_

#include <cstdint>
#include <functional>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Postprocessor of the HloInstructionSequence. This is an opt-in postprocessing
// function to MemorySchedulerAlgorithm to enforce certain hlo schedule
// constraints desired for custom-calls.
38 using MemorySchedulerPostprocessor = 39 std::function<HloInstructionSequence(const HloInstructionSequence&)>; 40 41 // A memory scheduler computes an execution sequence for the HLO instructions in 42 // 'computation' that minimizes peak memory, given a points-to analysis result 43 // that describes buffer aliasing, together with a target-specific size function 44 // that maps a tensor's logical size to its padded size. peak_memory (may be 45 // nullptr) is set to the peak memory of the resulting schedule according to the 46 // HeapSimulator. 47 // 48 // TODO(yunxing): Cleanup usage of TuplePointsToAnalysis. 49 typedef std::function<StatusOr<HloInstructionSequence>( 50 HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, 51 const LogicalBuffer::SizeFunction&, 52 const absl::flat_hash_map<const HloComputation*, int64_t>&, 53 const MemorySchedulerPostprocessor&, 54 /*peak_memory*/ int64_t*)> 55 MemorySchedulerAlgorithm; 56 57 // Scheduler for the entire module. 58 typedef std::function<StatusOr<HloSchedule>( 59 const HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, 60 const LogicalBuffer::SizeFunction&, 61 const absl::flat_hash_set<absl::string_view>& execution_threads, 62 /*peak_memory*/ int64_t*)> 63 ModuleSchedulerAlgorithm; 64 65 // Lift a computation scheduler into a module scheduler by calling the 66 // computation scheduler on all computations in a module. 
67 ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( 68 const MemorySchedulerAlgorithm&, const MemorySchedulerPostprocessor& = {}); 69 70 // List scheduler 71 StatusOr<HloInstructionSequence> ListMemoryScheduler( 72 HloComputation* computation, 73 const TuplePointsToAnalysis& points_to_analysis, 74 const HloAliasAnalysis& alias_analysis, 75 const LogicalBuffer::SizeFunction& size_function, 76 const absl::flat_hash_map<const HloComputation*, int64_t>& 77 memory_by_computation, 78 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); 79 80 // DFS-order scheduler 81 StatusOr<HloInstructionSequence> DFSMemoryScheduler( 82 HloComputation* computation, 83 const TuplePointsToAnalysis& points_to_analysis, 84 const HloAliasAnalysis& alias_analysis, 85 const LogicalBuffer::SizeFunction& size_function, 86 const absl::flat_hash_map<const HloComputation*, int64_t>& 87 memory_by_computation, 88 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); 89 90 // Naive Post Order scheduler 91 StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( 92 HloComputation* computation, 93 const TuplePointsToAnalysis& points_to_analysis, 94 const HloAliasAnalysis& alias_analysis, 95 const LogicalBuffer::SizeFunction& size_function, 96 const absl::flat_hash_map<const HloComputation*, int64_t>& 97 memory_by_computation, 98 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); 99 100 // The default scheduling algorithm. Runs the list scheduler, the DFS scheduler, 101 // and the post-order scheduler and chooses whichever returns a lower min- 102 // memory, not accounting for fragmentation. peak_memory (may be nullptr) is set 103 // to the peak memory of the resulting schedule according to the HeapSimulator. 
104 StatusOr<HloInstructionSequence> DefaultMemoryScheduler( 105 HloComputation* computation, 106 const TuplePointsToAnalysis& points_to_analysis, 107 const HloAliasAnalysis& alias_analysis, 108 const LogicalBuffer::SizeFunction& size_function, 109 const absl::flat_hash_map<const HloComputation*, int64_t>& 110 memory_by_computation, 111 const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); 112 113 StatusOr<HloSchedule> DefaultModuleScheduler( 114 const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, 115 const HloAliasAnalysis& alias_analysis, 116 const LogicalBuffer::SizeFunction& size_function, 117 const absl::flat_hash_set<absl::string_view>& execution_threads, 118 int64_t* peak_memory); 119 120 // Returns an HloSchedule which seeks to minimize the memory required for the 121 // module. size_function is the function returning the number of bytes required 122 // for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak 123 // memory (according to the HeapSimulator) of all computations in the module. 124 StatusOr<HloSchedule> ScheduleModule( 125 const HloModule* module, const LogicalBuffer::SizeFunction& size_function, 126 const ModuleSchedulerAlgorithm& algorithm = {}, 127 const absl::flat_hash_set<absl::string_view>& execution_threads = {}, 128 int64_t* peak_memory = nullptr); 129 130 // Computes the schedule for a single computation. 131 // Currently only used by the GPU backend. 132 StatusOr<HloInstructionSequence> ScheduleComputation( 133 HloComputation* computation, 134 const LogicalBuffer::SizeFunction& size_function, 135 const MemorySchedulerPostprocessor& postprocessor); 136 137 // A pass which schedules the HLO instructions in a module. The HloModule's 138 // schedule field is set to the resulting HloSchedule using 139 // HloModule::set_schedule. 
140 class HloMemoryScheduler : public HloModulePass { 141 public: 142 // size_function is the function returning the number of bytes required for a 143 // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not 144 // specified, then DefaultMemoryScheduler is used. 145 HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, 146 const ModuleSchedulerAlgorithm& algorithm = {}); 147 148 ~HloMemoryScheduler() override = default; 149 name()150 absl::string_view name() const override { return "hlo-memory-scheduler"; } 151 152 using HloPassInterface::Run; 153 StatusOr<bool> Run( 154 HloModule* module, 155 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 156 157 private: 158 LogicalBuffer::SizeFunction size_function_; 159 160 ModuleSchedulerAlgorithm algorithm_; 161 }; 162 163 // A pass which produces a naive, but correct schedule. The schedule is produced 164 // using a DFS traversal of the graph with no attempt to minimize memory use. 165 class HloTrivialScheduler : public HloModulePass { 166 public: name()167 absl::string_view name() const override { return "hlo-trivial-scheduler"; } 168 169 using HloPassInterface::Run; 170 StatusOr<bool> Run( 171 HloModule* module, 172 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 173 }; 174 175 // A trivial pass which clears the schedule currently set on the 176 // HloModule. After this pass runs HloModule::has_schedule will return false. 177 class HloDescheduler : public HloModulePass { 178 public: 179 HloDescheduler() = default; 180 ~HloDescheduler() override = default; name()181 absl::string_view name() const override { return "hlo-descheduler"; } 182 183 using HloPassInterface::Run; 184 StatusOr<bool> Run( 185 HloModule* module, 186 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 187 }; 188 189 } // namespace xla 190 191 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ 192