xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Analysis for determining the possible set of values for all positions
17 // (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
18 // tracking values across computation boundaries.
19 
20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
21 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
22 
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
41 
42 namespace xla {
43 
44 // Identifies one array input of an HloInstruction.
45 struct HloOperandIndex {
46   // The operand number in which the array value appears.
47   int64_t operand_number;
48 
49   // The shape index within the operand in which the array value appears.
50   ShapeIndex operand_index;
51 
52   bool operator==(const HloOperandIndex& other) const {
53     return operand_number == other.operand_number &&
54            operand_index == other.operand_index;
55   }
56 
57   bool operator!=(const HloOperandIndex& other) const {
58     return !(*this == other);
59   }
60 };
61 
62 // Analysis which identifies all HLO values and their uses in an HLO module.
63 class HloDataflowAnalysis {
64  public:
65   // Infrastructure for passing may-alias hints: HLO passes can populate the
66   // may-alias table. If an empty optional is returned, default rules are used.
67   //
68   // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be
69   // overriden using backend-specific overrides.
70   //
71   // The first parameter of the function should be the instruction, the
72   // second parameter should be an operand of the instruction. The third
73   // parameter should be the output index of the instruction.
74   using CanShareBuffer = std::function<std::optional<bool>(
75       const HloInstruction* instr, const HloInstruction* operand,
76       const ShapeIndex& user_index)>;
77 
78   // Runs dataflow analysis on the given module. Parameters:
79   //
80   //   ssa_form : If true then new values are defined at the merge points of
81   //     kWhile instructions. Abusing nomenclature somewhat, we call these "phi
82   //     values".  The merge is formed by the init value and loop backedge. The
83   //     SSA form is minimal in that a new phi value is defined only if the
84   //     merge point is reachable by multiple different values. The SSA form is
85   //     also in loop-closed form in that no values defined inside of a loop
86   //     (while body) is used outside of the loop. Example use of this ssa_form
87   //     mode is to reason about live range interference of buffers.
88   //
89   //     If ssa_form is false, then merge points do not define new
90   //     values. Rather, the HloValueSet for the merge point contains the union
91   //     of the merged HloValues.
92   //
93   //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
94   //     a new HLO value in the analysis. If false then Bitcast forwards the
95   //     value of its operand.
96   static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
97       const HloModule& module, bool ssa_form = false,
98       bool bitcast_defines_value = false,
99       const CanShareBuffer& can_share_buffer = nullptr);
100 
101   // Returns true if 'instruction' defines an HLO value at the given shape index
102   // of its output.
103   bool ValueIsDefinedAt(const HloInstruction* instruction,
104                         const ShapeIndex& index = {}) const;
105 
106   // Returns the HloValue defined by 'instruction' at the given shape index of
107   // its output.
108   //
109   // Precondition: ValueIsDefinedAt is true for this instruction and index.
110   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
111                                     const ShapeIndex& index = {}) const;
112   HloValue& GetValueDefinedAt(const HloInstruction* instruction,
113                               const ShapeIndex& index = {});
114 
115   // Returns the InstructionValueSet for the given instruction.
116   const InstructionValueSet& GetInstructionValueSet(
117       const HloInstruction* instruction) const;
118   InstructionValueSet& GetInstructionValueSet(
119       const HloInstruction* instruction);
120 
121   // Returns all values that are contained in the output of this instruction in
122   // a flattened set.
123   HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const;
124 
125   // Returns the HloValueSet for the given instruction at the given index or the
126   // given position.
127   const HloValueSet& GetValueSet(const HloInstruction* instruction,
128                                  const ShapeIndex& index = {}) const;
129   const HloValueSet& GetValueSet(const HloPosition& position) const;
130   HloValueSet& GetValueSet(const HloPosition& position);
131   HloValueSet& GetValueSet(const HloInstruction* instruction,
132                            const ShapeIndex& index = {});
133 
134   // Returns the unique value in the HloValueSet at the given instruction and
135   // shape index. CHECKs if the value set does not contain a exactly one value.
136   const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
137                                    const ShapeIndex& index = {}) const {
138     return GetValueSet(instruction, index).GetUniqueValue();
139   }
140   HloValue& GetUniqueValueAt(const HloInstruction* instruction,
141                              const ShapeIndex& index = {}) {
142     return GetValue(GetValueSet(instruction, index).GetUniqueValue().id());
143   }
144 
145   // Returns the HloValue with the given Id.
146   const HloValue& GetValue(HloValue::Id value_id) const;
147   HloValue& GetValue(HloValue::Id value_id);
148 
149   // Returns the total number of HloValues.
value_count()150   int64_t value_count() const { return values_.size(); }
151 
152   // Returns a vector of all HloValues stabily sorted by HloValue::Id.
values()153   const std::vector<HloValue*>& values() const { return values_vector_; }
154 
155   // Returns the call graph used for computing the dataflow.
call_graph()156   const CallGraph& call_graph() const { return *call_graph_; }
157 
158   std::string ToString() const;
159 
160   // Returns true if 'user' cannot possibly use the buffer at 'index' in
161   // 'operand'. Returns false otherwise.
162   //
163   // 'operand' does not have to be an operand of 'user'. This can be the
164   // case with indirect uses.
165   bool DoesNotUseOperandBuffer(const HloInstruction* operand,
166                                const ShapeIndex& index,
167                                const HloInstruction* user) const;
168 
169   // Returns true if 'user' (at 'user_index') can share a buffer with its
170   // operand 'operand' (at 'operand_index'). Returns false otherwise.
171   //
172   // REQUIRES: 'operand' is an operand of 'user'.
173   bool CanShareOperandBufferWithUser(HloInstruction* operand,
174                                      const ShapeIndex& operand_index,
175                                      HloInstruction* user,
176                                      const ShapeIndex& user_index) const;
177 
module()178   const HloModule& module() const { return module_; }
179 
180   // Returns true if the operation is an in-place operation and its operand 0
181   // must alias with the output.
182   static bool IsInPlaceOperation(HloOpcode opcode);
183 
184   // Returns true if the operation is the start/done of an asynchronous
185   // operation, where the buffer used/produced by the op needs to stay alive
186   // until the asynchronous operation completes.
187   static bool IsAsynchronousOperationStart(HloOpcode opcode);
188   static bool IsAsynchronousOperationDone(HloOpcode opcode);
189 
190   // Returns the pairs of inputs and outputs that must share the same buffer,
191   // according to the aliasing rules for that instruction.
192   //
193   // This function only considers array values as inputs and outputs, so
194   // when tuples are present it "sees through" to the array values inside. The
195   // HloUse describing the input parameter contains not only the operand number
196   // but also a shape index describing its position inside a nested tuple shape
197   // (if any). Similarly, the output parameter is described by a shape index
198   // into the nested tuple shape (if any) of the output value.
199   //
200   // For example, for this hypothetical op:
201   //   %foo = (f32[1], (f32[2], f32[3]))
202   //              op((f32[4], f32[5]) %arg0, f32[6] %arg1)
203   //
204   // ... the results can include any of the 3 * 3 = 9 possible pairs of
205   // input and output arrays.
206   static std::vector<std::pair<HloOperandIndex, ShapeIndex>>
207   GetInPlaceInputOutputPairs(const HloInstruction* instruction);
208 
209  private:
210   static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
211 
212   HloDataflowAnalysis(const HloModule& module, bool ssa_form,
213                       bool bitcast_defines_value = false,
214                       const CanShareBuffer& can_share_buffer = nullptr);
215 
216   // 1. During value propagation (Propagate function), always create phi
217   // values once it see multiple inputs merging at the same point. It then
218   // records those phi values as well as their inputs in a phi graph.
219   //
220   // 2. Post value propagation, Dataflow analysis can then do certain
221   // optimization(OptimizePhiValues) on the phi graph to prune uncessary phi
222   // nodes.
223   //
224   // Note that this applies in SSA form, and Both of the functions are
225   // guaranteed to exit.
226   //
227   void OptimizePhiValues();
228 
229   // Returns a new HloValue defined at the given instruction and shape index.
230   HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
231                         bool is_phi);
232 
233   // Marks the HloValue with the given ID for deletion.
234   void MarkValueForDeletion(HloValue::Id value_id);
235 
236   // Deletes all HloValues marked for deletion. Should be called after
237   // propagation is complete.
238   void DeleteMarkedValues();
239 
240   // Constructs and initializes the InstructionValueSets of all instructions to
241   // contain exactly the HloValues defined by each instruction. These values can
242   // then propagated throughout the HLO graph by calling Propagate.
243   Status InitializeInstructionValueSets();
244 
245   // Updates the value set of the given instruction based on the values flowing
246   // into the instruction (operands and cross-computation dataflow).
247   bool UpdateInstructionValueSet(HloInstruction* instruction);
248 
249   // Updates the value set for a particular instruction type. Returns whether
250   // the instruction value set changed.
251   bool UpdateBitcastValueSet(HloInstruction* bitcast);
252   bool UpdateCallValueSet(HloInstruction* call);
253   bool UpdateConditionalValueSet(HloInstruction* conditional);
254   bool UpdateCopyValueSet(HloInstruction* copy);
255   bool UpdateCustomCallValueSet(HloInstruction* custom_call);
256   bool UpdateDomainValueSet(HloInstruction* domain);
257   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
258   bool UpdateParameterValueSet(HloInstruction* parameter);
259   // Async op propagation rules:
260   //  - Operand of async-start to parameter of async wrapped computation and at
261   //    index {0, operand_number} of async-start and async-update outputs.
262   //  - Root of async wrapped computation to index {1} of async-start and
263   //    async-update and index {} of async-done.
264   //  - The contexts in indices {2+} of async-start to the same indices of
265   //    async-update.
266   //
267   // As a result of this, the operands/outputs of async-start and async-done
268   // instructions share the same values as the parameters/roots of the async
269   // wrapped computation.
270   bool UpdateAsyncStartValueSet(HloInstruction* async_start);
271   bool UpdateAsyncUpdateValueSet(HloInstruction* async_update);
272   bool UpdateAsyncDoneValueSet(HloInstruction* async_done);
273   bool UpdateCopyStartValueSet(HloInstruction* copy_start);
274   bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
275   bool UpdateOptimizationBarrierValueSet(HloInstruction* barrier);
276   bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
277   bool UpdateSendValueSet(HloInstruction* send);
278   bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size);
279   bool UpdateTupleValueSet(HloInstruction* tuple);
280   bool UpdateWhileValueSet(HloInstruction* xla_while);
281   bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
282   bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start);
283   bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done);
284   bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done);
285   bool UpdateCollectivePermuteStartValueSet(
286       HloInstruction* collective_permute_start);
287   bool UpdateCollectivePermuteDoneValueSet(
288       HloInstruction* collective_permute_done);
289 
290   // Propagates the dataflow through the module. In particular, it propagates
291   // the HloValueSet from its defining instruction to the users of the
292   // instructions.
293   void Propagate();
294 
295   // Returns the result of the SSA Phi function applied to the given inputs at
296   // the given instruction.
297   bool Phi(HloInstruction* instruction,
298            absl::Span<const InstructionValueSet* const> inputs);
299 
300   // Updates the positions of the HloValues in the output of the given
301   // instruction. This should be called after the instruction value set of
302   // 'instruction' has been changed. 'prev_value_set' must point to the previous
303   // state of the value set prior to the change. 'prev_value_set' may be null if
304   // this is the first time positions are being computed. The previous state is
305   // necessary to efficiently remove positions which have been eliminated due to
306   // changes in the instructions' InstructionValueSet.
307   void UpdatePositionsOfValuesAt(
308       HloInstruction* instruction, const InstructionValueSet& new_value_set,
309       const InstructionValueSet* prev_value_set = nullptr);
310 
311   // Verifies various invariants of the dataflow analysis.
312   Status Verify() const;
313 
314   const HloModule& module_;
315   const bool ssa_form_;
316   const bool bitcast_defines_value_;
317 
318   std::unique_ptr<CallGraph> call_graph_;
319 
320   // The map of all HloValues in the module. We pass around pointers to the
321   // mapped HloValues, so the underlying container must keep them valid despite
322   // mutations touching other map entries.
323   absl::flat_hash_map<HloValue::Id, std::unique_ptr<HloValue>> values_;
324 
325   // A map from instruction to InstructionValueSet.
326   absl::flat_hash_map<const HloInstruction*,
327                       std::unique_ptr<InstructionValueSet>>
328       value_sets_;
329 
330   // Values marked for deletion during construction. We don't delete them
331   // immediately because references to them may remain in ValueSets temporarily
332   // during propagation. After construction, these values are deleted.
333   std::vector<HloValue::Id> value_ids_to_delete_;
334 
335   // A vector containing all HloValues sorted by HloValue::Id.
336   std::vector<HloValue*> values_vector_;
337 
338   // The Id to use for the next HloValue.
339   HloValue::Id next_value_id_ = 0;
340 
341   // An explicit graph holding phi values and edges.
342   PhiGraph phi_graph_;
343 
344   // Backend specific function that decides whether an instruction can share
345   // a buffer with its operand.
346   CanShareBuffer can_share_buffer_ = nullptr;
347 };
348 
349 }  // namespace xla
350 
351 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
352