/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include <functional>
#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// HLO pass which rematerializes instructions to reduce peak memory use, where
// memory use is defined as the total size of all live HLO instruction
// values. Parameters and constants are included in memory use estimates.
//
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
class HloRematerialization : public HloModulePass {
 public:
  using ShapeSizeFunction = std::function<int64_t(const Shape&)>;

  using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;
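
  // Example (an illustrative sketch; ShapeUtil::ByteSizeOf and the pointer
  // size are assumptions for the example, not requirements of this header):
  //
  //   ShapeSizeFunction size_fn = [](const Shape& shape) {
  //     return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
  //   };
  //   CompactShapeFunction compact_fn =
  //       [](const Shape& shape) -> StatusOr<Shape> { return shape; };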

  // Helper struct that communicates the before/after sizes for the
  // rematerialization process.
  struct RematerializationSizes {
    int64_t before_bytes = -1;
    int64_t after_bytes = -1;
  };

  // Mode in which the rematerialization algorithm should be run.
  enum class RematerializationMode {
    kRecomputeOnly,        // Only consider the kRecompute RematStrategy.
    kCompressOnly,         // Only consider the kCompress RematStrategy.
    kRecomputeAndCompress  // Consider both kRecompute and kCompress.
  };

  // Enum to specify whether this rematerialization pass occurs before or after
  // multi-output fusion.
  enum class RematerializationPass {
    kPreFusion,  // Rematerialization pass before multi-output fusion.
    kPostFusion  // Rematerialization pass after multi-output fusion.
  };

  static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }

  // Constructor parameters:
  //
  //   size_function: Function which returns the size in bytes of the top-level
  //     buffer of the given shape.
  //
  //   memory_limit_bytes: The threshold number of bytes to reduce memory use
  //     to via rematerialization. The size of aliased outputs should be
  //     subtracted from this.
  //
  //   sizes: Pointer to data structure which records the peak memory usage of
  //     the HLO module before/after rematerialization. Values are set during
  //     Run(). Can be nullptr.
  //
  //   compact_shape_function: Function which returns the compact form of a
  //     shape. If nullptr is provided, a default identity function is used.
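  //
  // Example usage (an illustrative sketch; the surrounding module, the chosen
  // limit, and the TF_ASSIGN_OR_RETURN context are assumptions, not mandated
  // by this header):
  //
  //   HloRematerialization::RematerializationSizes sizes;
  //   HloRematerialization remat(
  //       [](const Shape& shape) {
  //         // Assumes ShapeUtil::ByteSizeOf; any byte-size function works.
  //         return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
  //       },
  //       /*memory_limit_bytes=*/int64_t{16} << 30, &sizes,
  //       HloRematerialization::RematerializationPass::kPreFusion,
  //       /*block_size_limit=*/1, /*block_rematerialization_factor=*/1);
  //   TF_ASSIGN_OR_RETURN(bool changed, remat.Run(module));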
  explicit HloRematerialization(
      const ShapeSizeFunction& size_function, int64_t memory_limit_bytes,
      RematerializationSizes* sizes, RematerializationPass pass_location,
      int block_size_limit, int block_rematerialization_factor,
      CompactShapeFunction compact_shape_function = nullptr,
      RematerializationMode mode = RematerializationMode::kRecomputeAndCompress,
      int64_t min_remat_size = 0)
      : size_function_(size_function),
        memory_limit_bytes_(memory_limit_bytes),
        sizes_(sizes),
        pass_location_(pass_location),
        block_size_limit_(block_size_limit),
        block_rematerialization_factor_(block_rematerialization_factor),
        compact_shape_function_(compact_shape_function == nullptr
                                    ? DefaultCompactShapeFunction
                                    : std::move(compact_shape_function)),
        mode_(mode),
        min_remat_size_(min_remat_size) {}
  ~HloRematerialization() override = default;

  absl::string_view name() const override { return "rematerialization"; }

  // Gets the next available channel id and increments the count.
  int64_t NextChannelId() { return next_channel_id_++; }

  // Gets the peak memory for the computation.
  int64_t ComputationPeakMemory(const HloComputation* computation) const {
    return computation_peak_memory_.at(computation);
  }

  // Runs rematerialization on the given module. Requires that the module has a
  // schedule set (HloModule::has_schedule() is true) before running. Returns
  // whether any instructions were rematerialized. If memory use is already
  // below the limit specified in the constructor then no instructions are
  // rematerialized and false is returned.
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

 protected:
  // Rematerializes instructions within the given computation. The schedule
  // gives the order in which the computation's instructions will be emitted in
  // the backend. Rematerialized instructions will be added to the HLO
  // computation and inserted into the schedule.
  StatusOr<bool> RematerializeComputation(HloComputation* computation,
                                          HloSchedule* schedule,
                                          int64_t memory_limit_bytes,
                                          int64_t min_remat_size) {
    return RematerializeComputation(computation, schedule, memory_limit_bytes,
                                    min_remat_size, /*execution_threads=*/{});
  }

  virtual StatusOr<bool> RematerializeComputation(
      HloComputation* computation, HloSchedule* schedule,
      int64_t memory_limit_bytes, int64_t min_remat_size,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values
  // at any program point. 'order' is the order in which the HLO instructions
  // will be emitted, which is used to determine the lifespans of HLO values.
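  //
  // For example (an illustrative sketch, not taken from this codebase): given
  // a sequence {a, b = f(a), c = g(b)} of 8-byte values where each value dies
  // after its last use, the live sets are {a}, {a, b}, and {b, c}, so the
  // peak memory is 16 bytes.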
  StatusOr<int64_t> ComputePeakMemory(
      const HloComputation* computation, const HloInstructionSequence& order,
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;

  // Returns the peak memory usage of the called computations for the given
  // instruction. Zero is returned if the instruction calls no computations.
  StatusOr<int64_t> CalledComputationsMemoryUsage(
      const HloInstruction* instruction,
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;

  // Returns true if `thread` is considered included within the given
  // `execution_threads`.
  bool IsExecutionThreadIncluded(
      const absl::flat_hash_set<absl::string_view>& execution_threads,
      absl::string_view thread) const;

  // Selects an algorithm to use for HLO scheduling.
  MemorySchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // The threshold number of bytes to reduce memory use to via
  // rematerialization.
  const int64_t memory_limit_bytes_;

  // Pointer to data structure which records the peak memory usage of the HLO
  // module before/after rematerialization.
  RematerializationSizes* sizes_;

  // Specifies whether this rematerialization pass occurs before or after
  // multi-output fusion.
  RematerializationPass pass_location_;

  // Maximum number of consecutive instructions to consider for
  // rematerialization.
  int block_size_limit_;

  // Controls the amount of effort spent trying to find large blocks for
  // rematerialization. Larger values lead to longer compilation times in
  // return for potentially reduced memory consumption.
  int block_rematerialization_factor_ = 1;

  // Converts a shape into compact form; returns the same shape if the shape is
  // already considered compact.
  const CompactShapeFunction compact_shape_function_;

  // Call graph of the hlo_module.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  absl::flat_hash_map<const HloComputation*, int64_t> computation_peak_memory_;

  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization
  // applied. Rematerialization is only applied once per computation.
  absl::flat_hash_set<const HloComputation*> rematerialized_computations_;

  // Count of the total instructions rematerialized.
  int64_t instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can be different from instructions_rematerialized_
  // because some rematerializations are effectively moves in the HLO
  // schedule. In these cases, the rematerialization instruction replaces all
  // uses of the original instruction and the original instruction is
  // dead. Hence, no net instructions were added.
  int64_t net_instructions_added_ = 0;

  // Size of the largest block that has been rematerialized. This is actually
  // an upper bound (within a factor of 2) on the block size.
  int max_rematerialized_block_size_ = 0;

  RematerializationMode mode_;

  int64_t min_remat_size_;

  // Tracks the next available channel id to assign to rematerialized channel
  // instructions.
  int64_t next_channel_id_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_