xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_cost_analysis.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
18 
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 
34 // HloCostAnalysis traverses an HLO graph and calculates the amount of
35 // computations required for the graph. Each HLO instruction handler provides
36 // the computation cost of the instruction, and the values are accumulated
37 // during the traversal for the entire graph. We treat normal floating point
38 // operations separately from transcendental operations.
39 class HloCostAnalysis : public ConstDfsHloVisitor {
40  public:
41   // Each HLO is associated to a vector of properties with the indices given
42   // below. Sub-classes can add further properties.
43   // MSVC 14.0 limitation requires the consts.
44   typedef std::map<std::string, float, std::less<>> Properties;
45   // shape_size is a function which returns the size in bytes of the top-level
46   // buffer of a shape.
47   using ShapeSizeFunction = std::function<int64_t(const Shape&)>;
48 
49   static constexpr const char kFlopsKey[] = "flops";
50   static constexpr const char kTranscendentalsKey[] = "transcendentals";
51   static constexpr const char kBytesAccessedKey[] = "bytes accessed";
52   static constexpr const char kOptimalSecondsKey[] = "optimal_seconds";
53 
54   // A struct to encapsulate hardware-related options. This includes the shape
55   // size function, which is used to encode hardware-specific padding and per
56   // second rates of FLOPs, bytes per second (available bandwidth), and
57   // transcendentals per second.
58   struct Options {
59     // Function which computes the size of the top-level of a given shape (not
60     // including nested elements, if any). If null then bytes_accessed methods
61     // return an error.
62     ShapeSizeFunction shape_size;
63     // How much of each property can be processed per second. E.g. if the
64     // property is bytes accessed, this is the number of bytes that can be
65     // processed per second. Is empty if no rates have been set.
66     Properties per_second_rates = {};
67 
68     // Set the rates used to calculate the time taken by the computation.
set_flops_per_secondOptions69     void set_flops_per_second(float value) {
70       per_second_rates[kFlopsKey] = value;
71     }
set_transcendentals_per_secondOptions72     void set_transcendentals_per_second(float value) {
73       per_second_rates[kTranscendentalsKey] = value;
74     }
set_bytes_per_secondOptions75     void set_bytes_per_second(float value) {
76       per_second_rates[kBytesAccessedKey] = value;
77     }
78 
79     // Returns the specified per-second rate used by cost analysis.
per_second_rateOptions80     const float per_second_rate(const std::string& key) const {
81       return GetProperty(key, per_second_rates);
82     }
83   };
84 
85   explicit HloCostAnalysis(const Options& options);
86   explicit HloCostAnalysis(ShapeSizeFunction shape_size,
87                            const Properties& per_second_rates = {});
88 
89   Status HandleElementwiseUnary(const HloInstruction* hlo) override;
90   Status HandleElementwiseBinary(const HloInstruction* hlo) override;
91   Status HandleConstant(const HloInstruction* constant) override;
92   Status HandleIota(const HloInstruction* iota) override;
93   Status HandleGetTupleElement(
94       const HloInstruction* get_tuple_element) override;
95   Status HandleSelect(const HloInstruction* hlo) override;
96   Status HandleCompare(const HloInstruction* compare) override;
97   Status HandleClamp(const HloInstruction* clamp) override;
98   Status HandleReducePrecision(const HloInstruction* hlo) override;
99   Status HandleConcatenate(const HloInstruction* concatenate) override;
100   Status HandleAsyncStart(const HloInstruction* async_start) override;
101   Status HandleAsyncUpdate(const HloInstruction* async_update) override;
102   Status HandleAsyncDone(const HloInstruction* async_done) override;
103   Status HandleCopyStart(const HloInstruction* send) override;
104   Status HandleCopyDone(const HloInstruction* send_done) override;
105   Status HandleSend(const HloInstruction* send) override;
106   Status HandleSendDone(const HloInstruction* send_done) override;
107   Status HandleRecv(const HloInstruction* recv) override;
108   Status HandleRecvDone(const HloInstruction* recv_done) override;
109   Status HandleConvert(const HloInstruction* convert) override;
110   Status HandleCopy(const HloInstruction* copy) override;
111   Status HandleDomain(const HloInstruction* domain) override;
112   Status HandleDot(const HloInstruction* dot) override;
113   Status HandleConvolution(const HloInstruction* convolution) override;
114   Status HandleFft(const HloInstruction* fft) override;
115   Status HandleTriangularSolve(const HloInstruction* hlo) override;
116   Status HandleCholesky(const HloInstruction* hlo) override;
117   Status HandleOptimizationBarrier(const HloInstruction* hlo) override;
118   Status HandleAllGather(const HloInstruction* hlo) override;
119   Status HandleAllGatherStart(const HloInstruction* hlo) override;
120   Status HandleAllGatherDone(const HloInstruction* hlo) override;
121   Status HandleAllReduce(const HloInstruction* crs) override;
122   Status HandleReduceScatter(const HloInstruction* hlo) override;
123   Status HandleAllReduceStart(const HloInstruction* hlo) override;
124   Status HandleAllReduceDone(const HloInstruction* hlo) override;
125   Status HandleAllToAll(const HloInstruction* hlo) override;
126   Status HandleCollectivePermute(const HloInstruction* hlo) override;
127   Status HandleCollectivePermuteStart(const HloInstruction* hlo) override;
128   Status HandleCollectivePermuteDone(const HloInstruction* hlo) override;
129   Status HandleReplicaId(const HloInstruction* hlo) override;
130   Status HandlePartitionId(const HloInstruction* hlo) override;
131   Status HandleInfeed(const HloInstruction* infeed) override;
132   Status HandleOutfeed(const HloInstruction* outfeed) override;
133   Status HandleRng(const HloInstruction* random) override;
134   Status HandleRngBitGenerator(const HloInstruction* random) override;
135   Status HandleRngGetAndUpdateState(const HloInstruction* random) override;
136   Status HandleReverse(const HloInstruction* reverse) override;
137   Status HandleSort(const HloInstruction* sort) override;
138   Status HandleParameter(const HloInstruction* parameter) override;
139   Status HandleReduce(const HloInstruction* reduce) override;
140   Status HandleBatchNormTraining(
141       const HloInstruction* batch_norm_training) override;
142   Status HandleBatchNormInference(
143       const HloInstruction* batch_norm_inference) override;
144   Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override;
145   Status HandleFusion(const HloInstruction* fusion) override;
146   Status HandleCall(const HloInstruction* call) override;
147   Status HandleCustomCall(const HloInstruction* custom_call) override;
148   Status HandleSlice(const HloInstruction* slice) override;
149   Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override;
150   Status HandleDynamicUpdateSlice(
151       const HloInstruction* dynamic_update_slice) override;
152   Status HandleTuple(const HloInstruction* tuple) override;
153   Status HandleMap(const HloInstruction* map) override;
154   Status HandleReduceWindow(const HloInstruction* reduce_window) override;
155   Status HandleSelectAndScatter(const HloInstruction* instruction) override;
156   Status HandleBitcast(const HloInstruction* bitcast) override;
157   Status HandleBroadcast(const HloInstruction* broadcast) override;
158   Status HandlePad(const HloInstruction* pad) override;
159   Status HandleReshape(const HloInstruction* reshape) override;
160   Status HandleDynamicReshape(const HloInstruction* reshape) override;
161   Status HandleAddDependency(const HloInstruction* add_dependency) override;
162   Status HandleAfterAll(const HloInstruction* token) override;
163   Status HandleTranspose(const HloInstruction* transpose) override;
164   Status HandleWhile(const HloInstruction* xla_while) override;
165   Status HandleConditional(const HloInstruction* conditional) override;
166   Status HandleGather(const HloInstruction* gather) override;
167   Status HandleScatter(const HloInstruction* hlo) override;
168   Status HandleGetDimensionSize(const HloInstruction* get_size) override;
169   Status HandleSetDimensionSize(const HloInstruction* set_size) override;
170   Status FinishVisit(const HloInstruction* root) override;
171 
172   Status Preprocess(const HloInstruction* hlo) override;
173   Status Postprocess(const HloInstruction* hlo) override;
174 
175   // Decorates shape_size_ by returning 0 immediately if the shape does not have
176   // a layout.
177   int64_t GetShapeSize(const Shape& shape) const;
178 
179   // Returns properties for the computation.
180   float flop_count() const;
181   float transcendental_count() const;
182   float bytes_accessed() const;
183   float optimal_seconds() const;
184 
185   // Returns the respective cost computed for a particular HLO instruction, or 0
186   // if the HLO was not found to have a cost in the analysis.
187   //
188   // Note that the cost for sub HLO instructions are also returned if asked. For
189   // example, body and condition of a while, fused instructions within a
190   // fusion, or the add instruction of a reduce.
191   int64_t flop_count(const HloInstruction& hlo) const;
192   int64_t transcendental_count(const HloInstruction& hlo) const;
193   int64_t bytes_accessed(const HloInstruction& hlo) const;
194   int64_t operand_bytes_accessed(const HloInstruction& hlo, int64_t operand_num,
195                                  ShapeIndex index = {}) const;
196   int64_t output_bytes_accessed(const HloInstruction& hlo,
197                                 ShapeIndex index = {}) const;
198   float optimal_seconds(const HloInstruction& hlo) const;
199 
200   // Get bytes read/written by this HLO. If memory_space is provided, it returns
201   // the bytes read/written from/to the given memory space only.
202   int64_t GetBytesRead(
203       const HloInstruction& hlo,
204       std::optional<int64_t> memory_space = std::nullopt) const;
205   int64_t GetBytesWritten(
206       const HloInstruction& hlo,
207       std::optional<int64_t> memory_space = std::nullopt) const;
208 
properties()209   const Properties& properties() const { return properties_sum_; }
property(const std::string & key)210   const float property(const std::string& key) const {
211     return GetProperty(key, properties());
212   }
213 
214   // Returns the specified per-second rate used by cost analysis.
per_second_rate(absl::string_view key)215   const float per_second_rate(absl::string_view key) const {
216     return GetProperty(key, options_.per_second_rates);
217   }
218 
219   // Return the key that is used to index into Properties for the specified
220   // input/output at the shape index.
221   static std::string GetOperandBytesAccessedKey(int64_t operand_num,
222                                                 ShapeIndex index = {});
223   static std::string GetOutputBytesAccessedKey(ShapeIndex index = {});
224 
225   // Returns the estimated convolution flops.
226   virtual int64_t GetConvolutionFlops(const HloInstruction* convolution);
227   // Same as above but with parameters for shapes to allow for backends to
228   // refine these.
229   static int64_t GetConvolutionFlops(const HloInstruction* convolutions,
230                                      const Shape& lhs_shape,
231                                      const Shape& rhs_shape,
232                                      const Shape& result_shape);
233 
234   // Returns the estimated dot flops.
235   static int64_t GetDotFlops(const Shape& lhs_shape, const Shape& result_shape,
236                              const DotDimensionNumbers& dnums);
237 
238  protected:
239   typedef absl::flat_hash_map<const HloInstruction*, Properties>
240       HloToProperties;
241 
242   // An FMA counts as two floating point operations in these analyzes.
243   static constexpr int64_t kFmaFlops = 2;
244 
245   // Creates a nested instance of HloCostAnalysis using the same Options.
246   virtual std::unique_ptr<HloCostAnalysis> CreateNestedCostAnalysis();
247 
248   // Returns the properties computed from visiting the computation rooted at the
249   // given hlo. The cost of visited sub HLO instructions is saved to
250   // hlo_properties_, which will be used by functions such as
251   // flop_count(hlo_instruction) to return cost of a particular HLO instruction.
252   StatusOr<Properties> ProcessSubcomputation(HloComputation* computation);
253 
254   // Utility function to handle all element-wise operations.
255   Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
256 
257   // Returns the default value if the key is not present in the
258   // properties. Otherwise, returns the value that the key maps to from the
259   // properties parameter.
260   static float GetProperty(absl::string_view key, const Properties& properties,
261                            float default_value = 0.0f);
262 
263   // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key
264   // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that
265   // the key maps to in the properties of the given hlo.
266   static float GetPropertyForHlo(const HloInstruction& hlo,
267                                  const std::string& key,
268                                  const HloToProperties& hlo_to_properties);
269 
270   // Traverses a fusion operand to find the actual bytes accessed by the fusion
271   // node.
272   int64_t FusionParameterReadBytes(const HloInstruction* hlo) const;
273 
274   // Set bytes accessed by the specified operand and shape index.
275   void SetOperandBytesAccessed(int64_t operand_num, float value);
276   void SetOperandBytesAccessed(int64_t operand_num, ShapeIndex index,
277                                float value);
278 
279   // Set bytes accessed by the output at the shape index.
280   void SetOutputBytesAccessed(float value);
281   void SetOutputBytesAccessed(ShapeIndex index, float value);
282 
283   HloToProperties hlo_properties_;
284 
285   // If true, the time taken will be computed from the rates for each property
286   // and the total time will be the maximum time, which is the time of the
287   // bottleneck.
288   bool current_should_compute_bottleneck_time_;
289 
290   // The properties of the currently visited instruction. A HandleFoo method can
291   // modify these to change the default values computed in Preprocess.
292   Properties current_properties_;
293 
294   // The sum of the properties of all HLOs in the computation.
295   Properties properties_sum_;
296 
297   // The hardware-specific options that contains things like the shape size
298   // function and per-second rates.
299   Options options_;
300 
301   HloCostAnalysis(const HloCostAnalysis&) = delete;
302   HloCostAnalysis& operator=(const HloCostAnalysis&) = delete;
303 };
304 
305 }  // namespace xla
306 
307 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
308