1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 18 19 #include <functional> 20 #include <string> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 25 #include "tensorflow/compiler/xla/service/hlo_computation.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 28 #include "tensorflow/compiler/xla/shape_util.h" 29 #include "tensorflow/compiler/xla/statusor.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 32 namespace xla { 33 34 // HloCostAnalysis traverses an HLO graph and calculates the amount of 35 // computations required for the graph. Each HLO instruction handler provides 36 // the computation cost of the instruction, and the values are accumulated 37 // during the traversal for the entire graph. We treat normal floating point 38 // operations separately from transcendental operations. 39 class HloCostAnalysis : public ConstDfsHloVisitor { 40 public: 41 // Each HLO is associated to a vector of properties with the indices given 42 // below. Sub-classes can add further properties. 43 // MSVC 14.0 limitation requires the consts. 44 typedef std::map<std::string, float, std::less<>> Properties; 45 // shape_size is a function which returns the size in bytes of the top-level 46 // buffer of a shape. 47 using ShapeSizeFunction = std::function<int64_t(const Shape&)>; 48 49 static constexpr const char kFlopsKey[] = "flops"; 50 static constexpr const char kTranscendentalsKey[] = "transcendentals"; 51 static constexpr const char kBytesAccessedKey[] = "bytes accessed"; 52 static constexpr const char kOptimalSecondsKey[] = "optimal_seconds"; 53 54 // A struct to encapsulate hardware-related options. This includes the shape 55 // size function, which is used to encode hardware-specific padding and per 56 // second rates of FLOPs, bytes per second (available bandwidth), and 57 // transcendentals per second. 58 struct Options { 59 // Function which computes the size of the top-level of a given shape (not 60 // including nested elements, if any). If null then bytes_accessed methods 61 // return an error. 62 ShapeSizeFunction shape_size; 63 // How much of each property can be processed per second. E.g. if the 64 // property is bytes accessed, this is the number of bytes that can be 65 // processed per second. Is empty if no rates have been set. 66 Properties per_second_rates = {}; 67 68 // Set the rates used to calculate the time taken by the computation. set_flops_per_secondOptions69 void set_flops_per_second(float value) { 70 per_second_rates[kFlopsKey] = value; 71 } set_transcendentals_per_secondOptions72 void set_transcendentals_per_second(float value) { 73 per_second_rates[kTranscendentalsKey] = value; 74 } set_bytes_per_secondOptions75 void set_bytes_per_second(float value) { 76 per_second_rates[kBytesAccessedKey] = value; 77 } 78 79 // Returns the specified per-second rate used by cost analysis. per_second_rateOptions80 const float per_second_rate(const std::string& key) const { 81 return GetProperty(key, per_second_rates); 82 } 83 }; 84 85 explicit HloCostAnalysis(const Options& options); 86 explicit HloCostAnalysis(ShapeSizeFunction shape_size, 87 const Properties& per_second_rates = {}); 88 89 Status HandleElementwiseUnary(const HloInstruction* hlo) override; 90 Status HandleElementwiseBinary(const HloInstruction* hlo) override; 91 Status HandleConstant(const HloInstruction* constant) override; 92 Status HandleIota(const HloInstruction* iota) override; 93 Status HandleGetTupleElement( 94 const HloInstruction* get_tuple_element) override; 95 Status HandleSelect(const HloInstruction* hlo) override; 96 Status HandleCompare(const HloInstruction* compare) override; 97 Status HandleClamp(const HloInstruction* clamp) override; 98 Status HandleReducePrecision(const HloInstruction* hlo) override; 99 Status HandleConcatenate(const HloInstruction* concatenate) override; 100 Status HandleAsyncStart(const HloInstruction* async_start) override; 101 Status HandleAsyncUpdate(const HloInstruction* async_update) override; 102 Status HandleAsyncDone(const HloInstruction* async_done) override; 103 Status HandleCopyStart(const HloInstruction* send) override; 104 Status HandleCopyDone(const HloInstruction* send_done) override; 105 Status HandleSend(const HloInstruction* send) override; 106 Status HandleSendDone(const HloInstruction* send_done) override; 107 Status HandleRecv(const HloInstruction* recv) override; 108 Status HandleRecvDone(const HloInstruction* recv_done) override; 109 Status HandleConvert(const HloInstruction* convert) override; 110 Status HandleCopy(const HloInstruction* copy) override; 111 Status HandleDomain(const HloInstruction* domain) override; 112 Status HandleDot(const HloInstruction* dot) override; 113 Status HandleConvolution(const HloInstruction* convolution) override; 114 Status HandleFft(const HloInstruction* fft) override; 115 Status HandleTriangularSolve(const HloInstruction* hlo) override; 116 Status HandleCholesky(const HloInstruction* hlo) override; 117 Status HandleOptimizationBarrier(const HloInstruction* hlo) override; 118 Status HandleAllGather(const HloInstruction* hlo) override; 119 Status HandleAllGatherStart(const HloInstruction* hlo) override; 120 Status HandleAllGatherDone(const HloInstruction* hlo) override; 121 Status HandleAllReduce(const HloInstruction* crs) override; 122 Status HandleReduceScatter(const HloInstruction* hlo) override; 123 Status HandleAllReduceStart(const HloInstruction* hlo) override; 124 Status HandleAllReduceDone(const HloInstruction* hlo) override; 125 Status HandleAllToAll(const HloInstruction* hlo) override; 126 Status HandleCollectivePermute(const HloInstruction* hlo) override; 127 Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; 128 Status HandleCollectivePermuteDone(const HloInstruction* hlo) override; 129 Status HandleReplicaId(const HloInstruction* hlo) override; 130 Status HandlePartitionId(const HloInstruction* hlo) override; 131 Status HandleInfeed(const HloInstruction* infeed) override; 132 Status HandleOutfeed(const HloInstruction* outfeed) override; 133 Status HandleRng(const HloInstruction* random) override; 134 Status HandleRngBitGenerator(const HloInstruction* random) override; 135 Status HandleRngGetAndUpdateState(const HloInstruction* random) override; 136 Status HandleReverse(const HloInstruction* reverse) override; 137 Status HandleSort(const HloInstruction* sort) override; 138 Status HandleParameter(const HloInstruction* parameter) override; 139 Status HandleReduce(const HloInstruction* reduce) override; 140 Status HandleBatchNormTraining( 141 const HloInstruction* batch_norm_training) override; 142 Status HandleBatchNormInference( 143 const HloInstruction* batch_norm_inference) override; 144 Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; 145 Status HandleFusion(const HloInstruction* fusion) override; 146 Status HandleCall(const HloInstruction* call) override; 147 Status HandleCustomCall(const HloInstruction* custom_call) override; 148 Status HandleSlice(const HloInstruction* slice) override; 149 Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; 150 Status HandleDynamicUpdateSlice( 151 const HloInstruction* dynamic_update_slice) override; 152 Status HandleTuple(const HloInstruction* tuple) override; 153 Status HandleMap(const HloInstruction* map) override; 154 Status HandleReduceWindow(const HloInstruction* reduce_window) override; 155 Status HandleSelectAndScatter(const HloInstruction* instruction) override; 156 Status HandleBitcast(const HloInstruction* bitcast) override; 157 Status HandleBroadcast(const HloInstruction* broadcast) override; 158 Status HandlePad(const HloInstruction* pad) override; 159 Status HandleReshape(const HloInstruction* reshape) override; 160 Status HandleDynamicReshape(const HloInstruction* reshape) override; 161 Status HandleAddDependency(const HloInstruction* add_dependency) override; 162 Status HandleAfterAll(const HloInstruction* token) override; 163 Status HandleTranspose(const HloInstruction* transpose) override; 164 Status HandleWhile(const HloInstruction* xla_while) override; 165 Status HandleConditional(const HloInstruction* conditional) override; 166 Status HandleGather(const HloInstruction* gather) override; 167 Status HandleScatter(const HloInstruction* hlo) override; 168 Status HandleGetDimensionSize(const HloInstruction* get_size) override; 169 Status HandleSetDimensionSize(const HloInstruction* set_size) override; 170 Status FinishVisit(const HloInstruction* root) override; 171 172 Status Preprocess(const HloInstruction* hlo) override; 173 Status Postprocess(const HloInstruction* hlo) override; 174 175 // Decorates shape_size_ by returning 0 immediately if the shape does not have 176 // a layout. 177 int64_t GetShapeSize(const Shape& shape) const; 178 179 // Returns properties for the computation. 180 float flop_count() const; 181 float transcendental_count() const; 182 float bytes_accessed() const; 183 float optimal_seconds() const; 184 185 // Returns the respective cost computed for a particular HLO instruction, or 0 186 // if the HLO was not found to have a cost in the analysis. 187 // 188 // Note that the cost for sub HLO instructions are also returned if asked. For 189 // example, body and condition of a while, fused instructions within a 190 // fusion, or the add instruction of a reduce. 191 int64_t flop_count(const HloInstruction& hlo) const; 192 int64_t transcendental_count(const HloInstruction& hlo) const; 193 int64_t bytes_accessed(const HloInstruction& hlo) const; 194 int64_t operand_bytes_accessed(const HloInstruction& hlo, int64_t operand_num, 195 ShapeIndex index = {}) const; 196 int64_t output_bytes_accessed(const HloInstruction& hlo, 197 ShapeIndex index = {}) const; 198 float optimal_seconds(const HloInstruction& hlo) const; 199 200 // Get bytes read/written by this HLO. If memory_space is provided, it returns 201 // the bytes read/written from/to the given memory space only. 202 int64_t GetBytesRead( 203 const HloInstruction& hlo, 204 std::optional<int64_t> memory_space = std::nullopt) const; 205 int64_t GetBytesWritten( 206 const HloInstruction& hlo, 207 std::optional<int64_t> memory_space = std::nullopt) const; 208 properties()209 const Properties& properties() const { return properties_sum_; } property(const std::string & key)210 const float property(const std::string& key) const { 211 return GetProperty(key, properties()); 212 } 213 214 // Returns the specified per-second rate used by cost analysis. per_second_rate(absl::string_view key)215 const float per_second_rate(absl::string_view key) const { 216 return GetProperty(key, options_.per_second_rates); 217 } 218 219 // Return the key that is used to index into Properties for the specified 220 // input/output at the shape index. 221 static std::string GetOperandBytesAccessedKey(int64_t operand_num, 222 ShapeIndex index = {}); 223 static std::string GetOutputBytesAccessedKey(ShapeIndex index = {}); 224 225 // Returns the estimated convolution flops. 226 virtual int64_t GetConvolutionFlops(const HloInstruction* convolution); 227 // Same as above but with parameters for shapes to allow for backends to 228 // refine these. 229 static int64_t GetConvolutionFlops(const HloInstruction* convolutions, 230 const Shape& lhs_shape, 231 const Shape& rhs_shape, 232 const Shape& result_shape); 233 234 // Returns the estimated dot flops. 235 static int64_t GetDotFlops(const Shape& lhs_shape, const Shape& result_shape, 236 const DotDimensionNumbers& dnums); 237 238 protected: 239 typedef absl::flat_hash_map<const HloInstruction*, Properties> 240 HloToProperties; 241 242 // An FMA counts as two floating point operations in these analyzes. 243 static constexpr int64_t kFmaFlops = 2; 244 245 // Creates a nested instance of HloCostAnalysis using the same Options. 246 virtual std::unique_ptr<HloCostAnalysis> CreateNestedCostAnalysis(); 247 248 // Returns the properties computed from visiting the computation rooted at the 249 // given hlo. The cost of visited sub HLO instructions is saved to 250 // hlo_properties_, which will be used by functions such as 251 // flop_count(hlo_instruction) to return cost of a particular HLO instruction. 252 StatusOr<Properties> ProcessSubcomputation(HloComputation* computation); 253 254 // Utility function to handle all element-wise operations. 255 Status HandleElementwiseOp(const HloInstruction* hlo_instruction); 256 257 // Returns the default value if the key is not present in the 258 // properties. Otherwise, returns the value that the key maps to from the 259 // properties parameter. 260 static float GetProperty(absl::string_view key, const Properties& properties, 261 float default_value = 0.0f); 262 263 // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key 264 // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that 265 // the key maps to in the properties of the given hlo. 266 static float GetPropertyForHlo(const HloInstruction& hlo, 267 const std::string& key, 268 const HloToProperties& hlo_to_properties); 269 270 // Traverses a fusion operand to find the actual bytes accessed by the fusion 271 // node. 272 int64_t FusionParameterReadBytes(const HloInstruction* hlo) const; 273 274 // Set bytes accessed by the specified operand and shape index. 275 void SetOperandBytesAccessed(int64_t operand_num, float value); 276 void SetOperandBytesAccessed(int64_t operand_num, ShapeIndex index, 277 float value); 278 279 // Set bytes accessed by the output at the shape index. 280 void SetOutputBytesAccessed(float value); 281 void SetOutputBytesAccessed(ShapeIndex index, float value); 282 283 HloToProperties hlo_properties_; 284 285 // If true, the time taken will be computed from the rates for each property 286 // and the total time will be the maximum time, which is the time of the 287 // bottleneck. 288 bool current_should_compute_bottleneck_time_; 289 290 // The properties of the currently visited instruction. A HandleFoo method can 291 // modify these to change the default values computed in Preprocess. 292 Properties current_properties_; 293 294 // The sum of the properties of all HLOs in the computation. 295 Properties properties_sum_; 296 297 // The hardware-specific options that contains things like the shape size 298 // function and per-second rates. 299 Options options_; 300 301 HloCostAnalysis(const HloCostAnalysis&) = delete; 302 HloCostAnalysis& operator=(const HloCostAnalysis&) = delete; 303 }; 304 305 } // namespace xla 306 307 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ 308