xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Analysis/userange_analysis.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H_
18 
19 #include <vector>
20 
21 #include "mlir/Analysis/Liveness.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/Value.h"
24 #include "mlir/Transforms/BufferUtils.h"
25 
26 namespace mlir {
27 
28 /// Represents an analysis for computing the useranges of all alloc values
29 /// inside a given function operation. The analysis uses liveness information to
30 /// compute intervals starting at the first and ending with the last use of
31 /// every alloc value.
32 class UserangeAnalysis {
33  public:
34   /// A typedef declaration of an UseInterval, which represents an interval as a
35   /// pair of begin to end.
36   using UseInterval = std::pair<size_t, size_t>;
37   using IntervalVector = SmallVector<UseInterval, 8>;
38 
39   UserangeAnalysis(Operation *op, const BufferPlacementAllocs &allocs,
40                    const BufferViewFlowAnalysis &aliases);
41 
42   /// Returns the index of the first operation that uses the given value.
43   /// Returns an empty Optional if the value has no uses.
getFirstUseIndex(Value value)44   llvm::Optional<size_t> getFirstUseIndex(Value value) const {
45     auto &intervals = useIntervalMap.find(value)->second;
46     return intervals.empty() ? llvm::None
47                              : llvm::Optional<size_t>(intervals.begin()->first);
48   }
49 
50   /// Checks if the use intervals of the given values interfere.
51   bool rangesInterfere(Value itemA, Value itemB) const;
52 
53   /// Merges the userange of itemB into the userange of itemA.
54   /// Note: This assumes that there is no interference between the two
55   /// ranges.
56   void unionRanges(Value itemA, Value itemB);
57 
58   /// Dumps the liveness information to the given stream.
59   void dump(raw_ostream &os);
60 
61  private:
62   using ValueSetT = BufferViewFlowAnalysis::ValueSetT;
63   using OperationListT = Liveness::OperationListT;
64 
65   /// Builds an IntervalVector corresponding to the given OperationList.
66   IntervalVector computeInterval(Value value,
67                                  const Liveness::OperationListT &operationList);
68 
69   /// Checks each operand of the operation for its memory effects and separates
70   /// them into read and write. Operands with read or write effects are added
71   /// to the opReadWriteMap.
72   void gatherMemoryEffects(Operation *op);
73 
74   /// Computes the ID for the operation. If the operation contains operands
75   /// which have read effects, the returning ID will be odd.
76   size_t computeID(Value v, Operation *op) const;
77 
78   /// Merge two IntervalVectors into a new IntervalVector. Return a pair with
79   /// the resulting IntervalVector and a boolean if there were interferences
80   /// during merging.
81   std::pair<IntervalVector, bool> intervalMerge(
82       const IntervalVector &intervalA, const IntervalVector &intervalB) const;
83 
84   /// Performs an interval union of the interval vectors from the given values.
85   /// Returns an empty Optional if there is an interval interference.
86   bool intervalUnion(Value itemA, Value itemB) const;
87 
88   /// Performs an interval subtraction => A = A - B.
89   /// Note: This assumes that all intervals of b are included in some interval
90   ///       of a.
91   void intervalSubtract(IntervalVector &a, const IntervalVector &b) const;
92 
93   /// Maps each Operation to a unique ID according to the program sequence.
94   DenseMap<Operation *, size_t> operationIds;
95 
96   /// Maps a value to its use range interval.
97   DenseMap<Value, IntervalVector> useIntervalMap;
98 
99   /// Maps an Operation to a pair of read and write Operands.
100   DenseMap<Operation *, std::pair<SmallPtrSet<Value, 2>, SmallPtrSet<Value, 2>>>
101       opReadWriteMap;
102 
103   /// Maps aliasValues to their use ranges. This is necessary to prevent
104   /// recomputations of the use range intervals of the aliases.
105   DenseMap<Value, OperationListT> aliasUseranges;
106 
107   /// Cache the alias lists for all values to avoid recomputation.
108   BufferViewFlowAnalysis::ValueMapT aliasCache;
109 
110   /// The current liveness info.
111   Liveness liveness;
112 };
113 
114 }  // namespace mlir
115 
116 #endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H_
117