xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_reachability.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
18 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <list>
#include <memory>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
32 
33 namespace xla {
34 
35 // A class for representing reachability between HloInstructions.
36 //
37 // It has an adjacency matrix and it is up to the user of the class to set the
38 // adjacency matrix such that it represents reachability, i.e. such that it is
39 // transitive. That the graph be transitive is thus not an invariant of this
40 // class, but it is required for the name of the class and its methods to make
41 // sense.
42 class HloReachabilityMap {
43  public:
44   // An opaque index that clients can use to make repeated operations for the
45   // same instruction faster, by calling GetIndex once for the instruction,
46   // and then calling the variants of other interfaces that take Index arguments
47   // rather than HloInstruction* arguments.
48   struct Index {
49    public:
50     bool operator==(Index other) const { return v == other.v; }
51     bool operator!=(Index other) const { return v != other.v; }
52 
53    private:
54     friend class HloReachabilityMap;
55 
56     // Index assigned for a particular instruction.  The value is used to index
57     // into the vector of BitVectors and the BitVectors themselves.
58     int v;
59   };
60   // Sets up a graph with no edges and where the nodes correspond to the given
61   // instructions.
62   explicit HloReachabilityMap(
63       absl::Span<const HloInstruction* const> instructions);
64 
65   // Computes and returns the reachability between HLO instructions in the
66   // computation. The returned HloReachabilityMap is constructed such that
67   // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a
68   // directed path (from producer to consumer) from 'a' to 'b'. Both data
69   // dependencies (operands) and control dependencies are considered for
70   // reachability. Trivially an instruction is reachable from itself.
71   static std::unique_ptr<HloReachabilityMap> Build(
72       const HloComputation* computation);
73 
74   // Similar to the above Build operation except that it tries to identify
75   // paths between instructions that do not contain control instructions
76   // and multiple operands, i.e., b is_reachable a == true iff
77   // b = f(f(f(f(f(a), constant), constant), constant).
78   // Further, the only ops allowed in a path are basic math operations such
79   // as add, sub, mul, div.
80   static std::unique_ptr<HloReachabilityMap> BuildWithRestrictions(
81       const HloComputation* computation,
82       absl::FunctionRef<void(const HloInstruction*,
83                              std::vector<HloInstruction*>*)>
84           add_dependencies);
85 
86   // Set the reachability set of 'instruction' to the union of the reachability
87   // sets of 'inputs'. Upon return, IsReachable(x, instruction) where
88   // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true
89   // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from
90   // itself. Returns whether the reachability set of 'instruction' changed.
91   //
92   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
93   // vector in the internal graph of this HloReachabilityMap for the given
94   // instruction and does not transitively update any other part of the
95   // adjacency matrix.
96   bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs,
97                               const HloInstruction* instruction);
98 
99   // As above, but faster because it does not check if the reachability changed.
100   void FastSetReachabilityToUnion(
101       absl::Span<const HloInstruction* const> inputs,
102       const HloInstruction* instruction);
103   // As above, but use Index instead if it's already looked up which is even
104   // faster since no hash map lookup will occur.
105   void FastSetReachabilityToUnion(absl::Span<const Index> input_indices,
106                                   Index index);
107 
GetIndex(const HloInstruction * instruction)108   Index GetIndex(const HloInstruction* instruction) const {
109     Index i;
110     i.v = FindOrDie(indices_, GetKey(instruction));
111     return i;
112   }
113 
114   // Sets entry so that IsReachable(a, b) will return true
115   //
116   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
117   // matrix in the internal graph of this HloReachabilityMap to have an edge
118   // from a to b and does not transitively update any other part of the
119   // adjacency matrix.
SetReachable(const HloInstruction * a,const HloInstruction * b)120   void SetReachable(const HloInstruction* a, const HloInstruction* b) {
121     SetReachable(GetIndex(a), GetIndex(b));
122   }
123   void SetReachable(Index a, Index b);
124 
125   // Updates the given reachability map after the immediate predecessor set
126   // (operands and control predecessors) of 'instruction' has changed.
127   void UpdateReachabilityThroughInstruction(const HloInstruction* instruction);
128 
129   // Returns true if "b" is reachable from "a"
130   //
131   // Note that this function only correctly answers queries about reachability
132   // if the set of edges that have been provided to this class are transitive.
IsReachable(const HloInstruction * a,const HloInstruction * b)133   bool IsReachable(const HloInstruction* a, const HloInstruction* b) const {
134     return IsReachable(GetIndex(a), GetIndex(b));
135   }
IsReachable(Index a,Index b)136   bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); }
137 
138   // Returns true if "b" is reachable from "a" or "a" is reachable from "b"
139   //
140   // Note that this function only correctly answers queries about reachability
141   // if the set of edges that have been provided to this class are transitive.
IsConnected(const HloInstruction * a,const HloInstruction * b)142   bool IsConnected(const HloInstruction* a, const HloInstruction* b) const {
143     return IsConnected(GetIndex(a), GetIndex(b));
144   }
IsConnected(Index a,Index b)145   bool IsConnected(Index a, Index b) const {
146     return IsReachable(a, b) || IsReachable(b, a);
147   }
148 
149   // Checks if an instruction is in the Reachability map.
IsPresent(const HloInstruction * a)150   bool IsPresent(const HloInstruction* a) const {
151     return indices_.contains(GetKey(a));
152   }
153 
154   // Replace the instruction "original" with "replacement" in the reachability
155   // map.
156   void Replace(const HloInstruction* original,
157                const HloInstruction* replacement);
158 
159  private:
160   // A bit-vector implementation specialized for this use case which provides a
161   // fast bitwise OR operation not available in tensorflow::gtl::BitMap.
162   class BitVector {
163    public:
164     BitVector() = default;
BitVector(size_t size)165     BitVector(size_t size)
166         : size_(size), vector_((size + kBits - 1) / kBits, 0) {}
167 
168     // Return the bit at the given index.
Get(size_t index)169     bool Get(size_t index) const {
170       DCHECK(index >= 0 && index < size_);
171       return vector_[index / kBits] & (1ull << (index % kBits));
172     }
173 
174     // Set the bit at the given index.
Set(size_t index)175     void Set(size_t index) {
176       DCHECK(index >= 0 && index < size_);
177       vector_[index / kBits] |= 1ull << (index % kBits);
178     }
179 
180     // Set this bitvector to the Logical OR of this bitvector and 'other'.
OrWith(const BitVector & other)181     void OrWith(const BitVector& other) {
182       for (size_t i = 0; i < vector_.size(); ++i) {
183         vector_[i] |= other.vector_[i];
184       }
185     }
186 
187     // Set the bitvector to all zeros.
SetToZero()188     void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); }
189 
190     bool operator==(const BitVector& other) const {
191       return vector_ == other.vector_;
192     }
193     bool operator!=(const BitVector& other) const {
194       return vector_ != other.vector_;
195     }
196 
197    private:
198     using Word = uint64_t;
199     static constexpr size_t kBits = 64;
200 
201     // Number of bits in the bitvector.
202     size_t size_;
203 
204     std::vector<Word> vector_;
205   };
206 
207   // Return the bitvector storing the reachability-to of the given instruction.
GetBitVector(const HloInstruction * instruction)208   const BitVector& GetBitVector(const HloInstruction* instruction) const {
209     return GetBitVector(GetIndex(instruction));
210   }
GetBitVector(const HloInstruction * instruction)211   BitVector& GetBitVector(const HloInstruction* instruction) {
212     return GetBitVector(GetIndex(instruction));
213   }
214 
GetBitVector(Index index)215   const BitVector& GetBitVector(Index index) const {
216     return bit_vectors_[index.v];
217   }
GetBitVector(Index index)218   BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; }
219 
220   // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
221   void SetReachabilityToUnionHelper(
222       absl::Span<const HloInstruction* const> inputs, Index index);
223   void SetReachabilityToUnionHelper(absl::Span<const Index> input_indices,
224                                     Index index);
225 
GetKey(const HloInstruction * instruction)226   uint64_t GetKey(const HloInstruction* instruction) const {
227     uint64_t unique_id = absl::bit_cast<uint32_t>(instruction->unique_id());
228     uint64_t module_id =
229         absl::bit_cast<uint32_t>(instruction->parent()->parent()->unique_id());
230     return (module_id << 32) | unique_id;
231   }
232   // Return the index of the given instruction.
GetIndexInternal(const HloInstruction * instruction)233   int GetIndexInternal(const HloInstruction* instruction) const {
234     return FindOrDie(indices_, GetKey(instruction));
235   }
236 
237   // The number of instructions in the reachability map.
238   const size_t size_;
239 
240   // Dense assignment from HloInstruction::unique_id to number. These numbers
241   // index into the bit_vectors_ vector and into the bits within a BitVector.
242   absl::flat_hash_map<uint64_t, int> indices_;
243 
244   // Bitvectors holding the reachability to each instruction. The bit vector for
245   // instruction X includes ones for each instruction which X is reachable from.
246   std::vector<BitVector> bit_vectors_;
247 
248   // A temporary used by SetReachabilityToUnion to avoid an allocation with each
249   // call to the method.
250   BitVector tmp_bit_vector_;
251 };
252 
253 }  // namespace xla
254 
255 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
256