/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// HLO pass which rematerializes instructions to reduce peak memory use, where
// memory use is defined as the total size of all live HLO instruction
// values. Parameters and constants are included in memory use estimates.
//
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
class HloRematerialization : public HloModulePass {
 public:
  using ShapeSizeFunction = std::function<int64_t(const Shape&)>;

  using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;

  // Helper struct that communicates the before / after sizes for the
  // rematerialization process.
  struct RematerializationSizes {
    int64_t before_bytes = -1;
    int64_t after_bytes = -1;
  };

  // Mode in which the rematerialization algorithm should be run.
  enum class RematerializationMode {
    kRecomputeOnly,        // Only consider the kRecompute RematStrategy.
    kCompressOnly,         // Only consider the kCompress RematStrategy.
    kRecomputeAndCompress  // Consider both kRecompute and kCompress.
  };

  // Enum to specify whether this rematerialization pass occurs before or
  // after multi-output fusion.
  enum class RematerializationPass {
    kPreFusion,  // Rematerialization pass before multi-output fusion.
    kPostFusion  // Rematerialization pass after multi-output fusion.
  };

  // Default compact-shape function: the identity, i.e. every shape is already
  // considered compact.
  static Shape DefaultCompactShapeFunction(const Shape& shape) {
    return shape;
  }
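  // Example (an illustrative sketch, not part of this header's API): a
  // ShapeSizeFunction is typically backed by ShapeUtil::ByteSizeOf with the
  // backend's pointer size, e.g.
  //
  //   HloRematerialization::ShapeSizeFunction size_fn =
  //       [](const Shape& shape) {
  //         return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
  //       };
  //
  // A CompactShapeFunction may return a more compact layout for a shape;
  // returning the shape unchanged, as DefaultCompactShapeFunction does,
  // effectively makes the kCompress strategy a no-op.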
  // Constructor parameters:
  //
  //   size_function: Function which returns the size in bytes of the
  //     top-level buffer of the given shape.
  //
  //   memory_limit_bytes: The threshold number of bytes to reduce memory use
  //     to via rematerialization. The size of aliased outputs should be
  //     subtracted from this.
  //
  //   sizes: Pointer to a data structure which records the peak memory usage
  //     of the HLO module before/after rematerialization. Values are set
  //     during Run(). Can be nullptr.
  //
  //   compact_shape_function: Function which returns the compact form of a
  //     shape. If nullptr is provided, a default identity function is used.
  explicit HloRematerialization(
      const ShapeSizeFunction& size_function, int64_t memory_limit_bytes,
      RematerializationSizes* sizes, RematerializationPass pass_location,
      int block_size_limit, int block_rematerialization_factor,
      CompactShapeFunction compact_shape_function = nullptr,
      RematerializationMode mode =
          RematerializationMode::kRecomputeAndCompress,
      int64_t min_remat_size = 0)
      : size_function_(size_function),
        memory_limit_bytes_(memory_limit_bytes),
        sizes_(sizes),
        pass_location_(pass_location),
        block_size_limit_(block_size_limit),
        block_rematerialization_factor_(block_rematerialization_factor),
        compact_shape_function_(compact_shape_function == nullptr
                                    ? DefaultCompactShapeFunction
                                    : std::move(compact_shape_function)),
        mode_(mode),
        min_remat_size_(min_remat_size) {}
  ~HloRematerialization() override = default;

  absl::string_view name() const override { return "rematerialization"; }

  // Get the next available channel id and increment count.
  int64_t NextChannelId() { return next_channel_id_++; }

  // Get the peak memory for the computation.
  int64_t ComputationPeakMemory(const HloComputation* computation) const {
    return computation_peak_memory_.at(computation);
  }

  // Runs rematerialization on the given module. Requires that the module has
  // a schedule set (HloModule::has_schedule() is true) before running.
  // Returns whether any instructions were rematerialized. If memory use is
  // already below the limit specified in the constructor, then no
  // instructions are rematerialized and false is returned.
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads)
      override;
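  // Typical usage (an illustrative sketch; the memory limit and block
  // parameters below are arbitrary example values, and size_fn is a
  // ShapeSizeFunction as sketched above):
  //
  //   HloRematerialization::RematerializationSizes sizes;
  //   HloRematerialization remat(
  //       size_fn, /*memory_limit_bytes=*/int64_t{1} << 30, &sizes,
  //       HloRematerialization::RematerializationPass::kPreFusion,
  //       /*block_size_limit=*/1, /*block_rematerialization_factor=*/1);
  //   TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));
  //
  // The module must already have a schedule (e.g. one produced by
  // HloMemoryScheduler) before Run() is called.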
 protected:
  // Rematerializes instructions within the given computation. 'order' is the
  // order in which the computation's instructions will be emitted in the
  // backend. Rematerialized instructions will be added to the HLO computation
  // and inserted into 'order'.
  StatusOr<bool> RematerializeComputation(HloComputation* computation,
                                          HloSchedule* schedule,
                                          int64_t memory_limit_bytes,
                                          int64_t min_remat_size) {
    return RematerializeComputation(computation, schedule, memory_limit_bytes,
                                    min_remat_size, /*execution_threads=*/{});
  }

  virtual StatusOr<bool> RematerializeComputation(
      HloComputation* computation, HloSchedule* schedule,
      int64_t memory_limit_bytes, int64_t min_remat_size,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values
  // at any program point. 'order' is the order in which the HLO instructions
  // will be emitted, which is used to determine the lifespans of HLO values.
  StatusOr<int64_t> ComputePeakMemory(
      const HloComputation* computation, const HloInstructionSequence& order,
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;
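  // For example (illustrative), given the schedule
  //   a = parameter, f32[1024]   (4 KiB)
  //   b = exp(a)                 (4 KiB)
  //   c = add(a, b)              (4 KiB)
  // all three values are live while 'c' executes ('a' and 'b' as operands,
  // 'c' as the value being defined), so the computation's peak memory is
  // 12 KiB.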
  // Returns the peak memory usage of the computations called by the given
  // instruction. Zero is returned if the instruction calls no computations.
  StatusOr<int64_t> CalledComputationsMemoryUsage(
      const HloInstruction* instruction,
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;

  // Returns true if `thread` is considered included within the given
  // `execution_threads`.
  bool IsExecutionThreadIncluded(
      const absl::flat_hash_set<absl::string_view>& execution_threads,
      absl::string_view thread) const;

  // Selects an algorithm to use for HLO scheduling.
  MemorySchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // The threshold number of bytes to reduce memory use to via
  // rematerialization.
  const int64_t memory_limit_bytes_;

  // Pointer to a data structure which records the peak memory usage of the
  // HLO module before/after rematerialization.
  RematerializationSizes* sizes_;

  // Specifies whether this rematerialization pass occurs before or after
  // multi-output fusion.
  RematerializationPass pass_location_;

  // Maximum number of consecutive instructions to consider for
  // rematerialization.
  int block_size_limit_;

  // Controls the amount of effort spent trying to find large blocks for
  // rematerialization. Larger values lead to longer compilation times in
  // return for potentially reduced memory consumption.
  int block_rematerialization_factor_ = 1;

  // Converts a shape into its compact form; returns the same shape if the
  // shape is already considered compact.
  const CompactShapeFunction compact_shape_function_;

  // Call graph of the hlo_module.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from a sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  absl::flat_hash_map<const HloComputation*, int64_t> computation_peak_memory_;

  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization applied.
  // Rematerialization is only applied once per computation.
  absl::flat_hash_set<const HloComputation*> rematerialized_computations_;

  // Count of the total instructions rematerialized.
  int64_t instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can be different from
  // instructions_rematerialized_ because some rematerializations are
  // effectively moves in the HLO schedule. In these cases, the
  // rematerialization instruction replaces all uses of the original
  // instruction, and the original instruction becomes dead. Hence, no net
  // instructions were added.
  int64_t net_instructions_added_ = 0;

  // Size of the largest block that has been rematerialized. This is actually
  // an upper bound (within a factor of 2) on the block size.
  int max_rematerialized_block_size_ = 0;

  RematerializationMode mode_;

  int64_t min_remat_size_;

  // Tracks the next available channel id to assign to rematerialized channel
  // instructions.
  int64_t next_channel_id_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_