/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Analysis for determining the possible set of values for all positions
// (instructions and ShapeIndexes) in the HLO module. The analysis is
// module-scoped, tracking values across computation boundaries.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_

#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// Identifies one array input of an HloInstruction.
struct HloOperandIndex {
  // The operand number in which the array value appears.
  int64_t operand_number;

  // The shape index within the operand in which the array value appears.
  ShapeIndex operand_index;

  bool operator==(const HloOperandIndex& other) const {
    return operand_number == other.operand_number &&
           operand_index == other.operand_index;
  }

  bool operator!=(const HloOperandIndex& other) const {
    return !(*this == other);
  }
};

// Analysis which identifies all HLO values and their uses in an HLO module.
class HloDataflowAnalysis {
 public:
  // Infrastructure for passing may-alias hints: HLO passes can populate the
  // may-alias table. If an empty optional is returned, the default rules are
  // used.
  //
  // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be
  // overridden using backend-specific overrides.
  //
  // The first parameter of the function is the instruction, the second
  // parameter is an operand of the instruction, and the third parameter is
  // the output index of the instruction.
  using CanShareBuffer = std::function<std::optional<bool>(
      const HloInstruction* instr, const HloInstruction* operand,
      const ShapeIndex& user_index)>;
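
  // Example (an illustrative sketch, not part of the analysis itself): a
  // backend could install a hint that lets a fusion reuse its operand 0
  // buffer and defers to the default rules everywhere else. The opcode
  // choice here is purely hypothetical.
  //
  //   HloDataflowAnalysis::CanShareBuffer backend_hint =
  //       [](const HloInstruction* instr, const HloInstruction* operand,
  //          const ShapeIndex& user_index) -> std::optional<bool> {
  //     if (instr->opcode() == HloOpcode::kFusion &&
  //         operand == instr->operand(0)) {
  //       return true;          // The backend knows this fusion writes in place.
  //     }
  //     return std::nullopt;    // Defer to the default may-alias rules.
  //   };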

  // Runs dataflow analysis on the given module. Parameters:
  //
  //   ssa_form : If true then new values are defined at the merge points of
  //     kWhile instructions. Abusing nomenclature somewhat, we call these
  //     "phi values". The merge is formed by the init value and loop
  //     backedge. The SSA form is minimal in that a new phi value is defined
  //     only if the merge point is reachable by multiple different values.
  //     The SSA form is also loop-closed in that no value defined inside a
  //     loop (while body) is used outside of the loop. An example use of this
  //     ssa_form mode is to reason about live range interference of buffers.
  //
  //     If ssa_form is false, then merge points do not define new values.
  //     Rather, the HloValueSet for the merge point contains the union of the
  //     merged HloValues.
  //
  //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
  //     a new HLO value in the analysis. If false then Bitcast forwards the
  //     value of its operand.
  static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
      const HloModule& module, bool ssa_form = false,
      bool bitcast_defines_value = false,
      const CanShareBuffer& can_share_buffer = nullptr);

  // Returns true if 'instruction' defines an HLO value at the given shape
  // index of its output.
  bool ValueIsDefinedAt(const HloInstruction* instruction,
                        const ShapeIndex& index = {}) const;

  // Returns the HloValue defined by 'instruction' at the given shape index of
  // its output.
  //
  // Precondition: ValueIsDefinedAt is true for this instruction and index.
  const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
                                    const ShapeIndex& index = {}) const;
  HloValue& GetValueDefinedAt(const HloInstruction* instruction,
                              const ShapeIndex& index = {});

  // Returns the InstructionValueSet for the given instruction.
  const InstructionValueSet& GetInstructionValueSet(
      const HloInstruction* instruction) const;
  InstructionValueSet& GetInstructionValueSet(
      const HloInstruction* instruction);

  // Returns all values that are contained in the output of this instruction in
  // a flattened set.
  HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const;

  // Returns the HloValueSet for the given instruction at the given index or
  // the given position.
  const HloValueSet& GetValueSet(const HloInstruction* instruction,
                                 const ShapeIndex& index = {}) const;
  const HloValueSet& GetValueSet(const HloPosition& position) const;
  HloValueSet& GetValueSet(const HloPosition& position);
  HloValueSet& GetValueSet(const HloInstruction* instruction,
                           const ShapeIndex& index = {});

  // Returns the unique value in the HloValueSet at the given instruction and
  // shape index. CHECK-fails if the value set does not contain exactly one
  // value.
  const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
                                   const ShapeIndex& index = {}) const {
    return GetValueSet(instruction, index).GetUniqueValue();
  }
  HloValue& GetUniqueValueAt(const HloInstruction* instruction,
                             const ShapeIndex& index = {}) {
    return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
  }

  // Returns the HloValue with the given Id.
  const HloValue& GetValue(HloValue::Id value_id) const;
  HloValue& GetValue(HloValue::Id value_id);

  // Returns the total number of HloValues.
  int64_t value_count() const { return values_.size(); }

  // Returns a vector of all HloValues stably sorted by HloValue::Id.
  const std::vector<HloValue*>& values() const { return values_vector_; }
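
  // Example of typical use (an illustrative sketch; 'module' is assumed to be
  // an HloModule*, 'instr' a tuple-shaped HloInstruction*, and the enclosing
  // function is assumed to return a Status so TF_ASSIGN_OR_RETURN applies):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       std::unique_ptr<HloDataflowAnalysis> analysis,
  //       HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
  //   if (analysis->ValueIsDefinedAt(instr, /*index=*/{0})) {
  //     const HloValue& value = analysis->GetValueDefinedAt(instr, {0});
  //     // ... use 'value' (e.g., inspect its positions and uses) ...
  //   }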

  // Returns the call graph used for computing the dataflow.
  const CallGraph& call_graph() const { return *call_graph_; }

  std::string ToString() const;

  // Returns true if 'user' cannot possibly use the buffer at 'index' in
  // 'operand'. Returns false otherwise.
  //
  // 'operand' does not have to be an operand of 'user'. This can be the case
  // with indirect uses.
  bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                               const ShapeIndex& index,
                               const HloInstruction* user) const;

  // Returns true if 'user' (at 'user_index') can share a buffer with its
  // operand 'operand' (at 'operand_index'). Returns false otherwise.
  //
  // REQUIRES: 'operand' is an operand of 'user'.
  bool CanShareOperandBufferWithUser(HloInstruction* operand,
                                     const ShapeIndex& operand_index,
                                     HloInstruction* user,
                                     const ShapeIndex& user_index) const;

  const HloModule& module() const { return module_; }

  // Returns true if the operation is an in-place operation and its operand 0
  // must alias with the output.
  static bool IsInPlaceOperation(HloOpcode opcode);

  // Returns true if the operation is the start/done of an asynchronous
  // operation, where the buffer used/produced by the op needs to stay alive
  // until the asynchronous operation completes.
  static bool IsAsynchronousOperationStart(HloOpcode opcode);
  static bool IsAsynchronousOperationDone(HloOpcode opcode);

  // Returns the pairs of inputs and outputs that must share the same buffer,
  // according to the aliasing rules for that instruction.
  //
  // This function only considers array values as inputs and outputs, so
  // when tuples are present it "sees through" to the array values inside. The
  // HloOperandIndex describing the input contains not only the operand number
  // but also a shape index describing its position inside a nested tuple shape
  // (if any). Similarly, the output is described by a shape index into the
  // nested tuple shape (if any) of the output value.
  //
  // For example, for this hypothetical op:
  //   %foo = (f32[1], (f32[2], f32[3]))
  //              op((f32[4], f32[5]) %arg0, f32[6] %arg1)
  //
  // ... the results can include any of the 3 * 3 = 9 possible pairs of
  // input and output arrays.
  static std::vector<std::pair<HloOperandIndex, ShapeIndex>>
  GetInPlaceInputOutputPairs(const HloInstruction* instruction);
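
  // Example (an illustrative sketch): for a dynamic-update-slice instruction
  // 'dus', whose output must alias operand 0, iterating the result would
  // typically see a single pair mapping operand 0 at shape index {} to output
  // shape index {}. The exact pairs are instruction-specific.
  //
  //   for (const auto& [operand_index, output_index] :
  //        HloDataflowAnalysis::GetInPlaceInputOutputPairs(dus)) {
  //     // operand_index.operand_number == 0, operand_index.operand_index == {}
  //     // output_index == {}
  //   }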

 private:
  static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);

  HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                      bool bitcast_defines_value = false,
                      const CanShareBuffer& can_share_buffer = nullptr);

  // Phi handling in SSA form:
  //
  // 1. During value propagation (the Propagate function), the analysis always
  //    creates phi values when it sees multiple inputs merging at the same
  //    point. It records those phi values, as well as their inputs, in a phi
  //    graph.
  //
  // 2. After value propagation, the analysis can run certain optimizations
  //    (OptimizePhiValues) on the phi graph to prune unnecessary phi nodes.
  //
  // Note that this applies in SSA form, and both of the functions are
  // guaranteed to terminate.
  void OptimizePhiValues();

  // Returns a new HloValue defined at the given instruction and shape index.
  HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
                        bool is_phi);

  // Marks the HloValue with the given ID for deletion.
  void MarkValueForDeletion(HloValue::Id value_id);

  // Deletes all HloValues marked for deletion. Should be called after
  // propagation is complete.
  void DeleteMarkedValues();

  // Constructs and initializes the InstructionValueSets of all instructions to
  // contain exactly the HloValues defined by each instruction. These values
  // can then be propagated throughout the HLO graph by calling Propagate.
  Status InitializeInstructionValueSets();

  // Updates the value set of the given instruction based on the values flowing
  // into the instruction (operands and cross-computation dataflow).
  bool UpdateInstructionValueSet(HloInstruction* instruction);

  // Updates the value set for a particular instruction type. Returns whether
  // the instruction value set changed.
  bool UpdateBitcastValueSet(HloInstruction* bitcast);
  bool UpdateCallValueSet(HloInstruction* call);
  bool UpdateConditionalValueSet(HloInstruction* conditional);
  bool UpdateCopyValueSet(HloInstruction* copy);
  bool UpdateCustomCallValueSet(HloInstruction* custom_call);
  bool UpdateDomainValueSet(HloInstruction* domain);
  bool UpdateGetTupleElementValueSet(HloInstruction* gte);
  bool UpdateParameterValueSet(HloInstruction* parameter);
  // Async op propagation rules (see the sketch after these declarations):
  //  - The operand of async-start flows to the corresponding parameter of the
  //    async wrapped computation and to index {0, operand_number} of the
  //    async-start and async-update outputs.
  //  - The root of the async wrapped computation flows to index {1} of
  //    async-start and async-update and to index {} of async-done.
  //  - The contexts in indices {2+} of async-start flow to the same indices of
  //    async-update.
  //
  // As a result, the operands/outputs of async-start and async-done
  // instructions share the same values as the parameters/roots of the async
  // wrapped computation.
  bool UpdateAsyncStartValueSet(HloInstruction* async_start);
  bool UpdateAsyncUpdateValueSet(HloInstruction* async_update);
  bool UpdateAsyncDoneValueSet(HloInstruction* async_done);
  bool UpdateCopyStartValueSet(HloInstruction* copy_start);
  bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
  bool UpdateOptimizationBarrierValueSet(HloInstruction* barrier);
  bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
  bool UpdateSendValueSet(HloInstruction* send);
  bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size);
  bool UpdateTupleValueSet(HloInstruction* tuple);
  bool UpdateWhileValueSet(HloInstruction* xla_while);
  bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
  bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start);
  bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done);
  bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done);
  bool UpdateCollectivePermuteStartValueSet(
      HloInstruction* collective_permute_start);
  bool UpdateCollectivePermuteDoneValueSet(
      HloInstruction* collective_permute_done);
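
  // Illustrative sketch of the async rules above (the syntax and the context
  // shape are approximate, not exact HLO text):
  //
  //   %wrapped (p: f32[8]) -> f32[8] { ... }
  //   %start = ((f32[8]), f32[8], ...) async-start(f32[8] %x), calls=%wrapped
  //   %done = f32[8] async-done(%start)
  //
  // The value of %x appears at index {0, 0} of %start and as the value of
  // %wrapped's parameter; the value of %wrapped's root appears at index {1}
  // of %start and at index {} of %done.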

  // Propagates the dataflow through the module. In particular, it propagates
  // the HloValueSet from each defining instruction to the users of that
  // instruction.
  void Propagate();

  // Returns the result of the SSA Phi function applied to the given inputs at
  // the given instruction.
  bool Phi(HloInstruction* instruction,
           absl::Span<const InstructionValueSet* const> inputs);

  // Updates the positions of the HloValues in the output of the given
  // instruction. This should be called after the instruction value set of
  // 'instruction' has been changed. 'prev_value_set' must point to the
  // previous state of the value set prior to the change; it may be null if
  // this is the first time positions are being computed. The previous state
  // is necessary to efficiently remove positions which have been eliminated
  // due to changes in the instruction's InstructionValueSet.
  void UpdatePositionsOfValuesAt(
      HloInstruction* instruction, const InstructionValueSet& new_value_set,
      const InstructionValueSet* prev_value_set = nullptr);

  // Verifies various invariants of the dataflow analysis.
  Status Verify() const;

  const HloModule& module_;
  const bool ssa_form_;
  const bool bitcast_defines_value_;

  std::unique_ptr<CallGraph> call_graph_;

  // The map of all HloValues in the module. We pass around pointers to the
  // mapped HloValues, so the underlying container must keep them valid despite
  // mutations touching other map entries.
  absl::flat_hash_map<HloValue::Id, std::unique_ptr<HloValue>> values_;

  // A map from instruction to InstructionValueSet.
  absl::flat_hash_map<const HloInstruction*,
                      std::unique_ptr<InstructionValueSet>>
      value_sets_;

  // Values marked for deletion during construction. We don't delete them
  // immediately because references to them may remain in ValueSets temporarily
  // during propagation. After construction, these values are deleted.
  std::vector<HloValue::Id> value_ids_to_delete_;

  // A vector containing all HloValues sorted by HloValue::Id.
  std::vector<HloValue*> values_vector_;

  // The Id to use for the next HloValue.
  HloValue::Id next_value_id_ = 0;

  // An explicit graph holding phi values and edges.
  PhiGraph phi_graph_;

  // Backend-specific function that decides whether an instruction can share
  // a buffer with its operand.
  CanShareBuffer can_share_buffer_ = nullptr;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_