1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 18 19 #include <cstdio> 20 #include <list> 21 #include <vector> 22 23 #include "absl/base/casts.h" 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/map_util.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/core/lib/core/status.h" 32 33 namespace xla { 34 35 // A class for representing reachability between HloInstructions. 36 // 37 // It has an adjacency matrix and it is up to the user of the class to set the 38 // adjacency matrix such that it represents reachability, i.e. such that it is 39 // transitive. That the graph be transitive is thus not an invariant of this 40 // class, but it is required for the name of the class and its methods to make 41 // sense. 42 class HloReachabilityMap { 43 public: 44 // An opaque index that clients can use to make repeated operations for the 45 // same instruction faster, by calling GetIndex once for the instruction, 46 // and then calling the variants of other interfaces that take Index arguments 47 // rather than HloInstruction* arguments. 48 struct Index { 49 public: 50 bool operator==(Index other) const { return v == other.v; } 51 bool operator!=(Index other) const { return v != other.v; } 52 53 private: 54 friend class HloReachabilityMap; 55 56 // Index assigned for a particular instruction. The value is used to index 57 // into the vector of BitVectors and the BitVectors themselves. 58 int v; 59 }; 60 // Sets up a graph with no edges and where the nodes correspond to the given 61 // instructions. 62 explicit HloReachabilityMap( 63 absl::Span<const HloInstruction* const> instructions); 64 65 // Computes and returns the reachability between HLO instructions in the 66 // computation. The returned HloReachabilityMap is constructed such that 67 // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a 68 // directed path (from producer to consumer) from 'a' to 'b'. Both data 69 // dependencies (operands) and control dependencies are considered for 70 // reachability. Trivially an instruction is reachable from itself. 71 static std::unique_ptr<HloReachabilityMap> Build( 72 const HloComputation* computation); 73 74 // Similar to the above Build operation except that it tries to identify 75 // paths between instructions that do not contain control instructions 76 // and multiple operands, i.e., b is_reachable a == true iff 77 // b = f(f(f(f(f(a), constant), constant), constant). 78 // Further, the only ops allowed in a path are basic math operations such 79 // as add, sub, mul, div. 80 static std::unique_ptr<HloReachabilityMap> BuildWithRestrictions( 81 const HloComputation* computation, 82 absl::FunctionRef<void(const HloInstruction*, 83 std::vector<HloInstruction*>*)> 84 add_dependencies); 85 86 // Set the reachability set of 'instruction' to the union of the reachability 87 // sets of 'inputs'. Upon return, IsReachable(x, instruction) where 88 // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true 89 // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from 90 // itself. Returns whether the reachability set of 'instruction' changed. 91 // 92 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 93 // vector in the internal graph of this HloReachabilityMap for the given 94 // instruction and does not transitively update any other part of the 95 // adjacency matrix. 96 bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs, 97 const HloInstruction* instruction); 98 99 // As above, but faster because it does not check if the reachability changed. 100 void FastSetReachabilityToUnion( 101 absl::Span<const HloInstruction* const> inputs, 102 const HloInstruction* instruction); 103 // As above, but use Index instead if it's already looked up which is even 104 // faster since no hash map lookup will occur. 105 void FastSetReachabilityToUnion(absl::Span<const Index> input_indices, 106 Index index); 107 GetIndex(const HloInstruction * instruction)108 Index GetIndex(const HloInstruction* instruction) const { 109 Index i; 110 i.v = FindOrDie(indices_, GetKey(instruction)); 111 return i; 112 } 113 114 // Sets entry so that IsReachable(a, b) will return true 115 // 116 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 117 // matrix in the internal graph of this HloReachabilityMap to have an edge 118 // from a to b and does not transitively update any other part of the 119 // adjacency matrix. SetReachable(const HloInstruction * a,const HloInstruction * b)120 void SetReachable(const HloInstruction* a, const HloInstruction* b) { 121 SetReachable(GetIndex(a), GetIndex(b)); 122 } 123 void SetReachable(Index a, Index b); 124 125 // Updates the given reachability map after the immediate predecessor set 126 // (operands and control predecessors) of 'instruction' has changed. 127 void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); 128 129 // Returns true if "b" is reachable from "a" 130 // 131 // Note that this function only correctly answers queries about reachability 132 // if the set of edges that have been provided to this class are transitive. IsReachable(const HloInstruction * a,const HloInstruction * b)133 bool IsReachable(const HloInstruction* a, const HloInstruction* b) const { 134 return IsReachable(GetIndex(a), GetIndex(b)); 135 } IsReachable(Index a,Index b)136 bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); } 137 138 // Returns true if "b" is reachable from "a" or "a" is reachable from "b" 139 // 140 // Note that this function only correctly answers queries about reachability 141 // if the set of edges that have been provided to this class are transitive. IsConnected(const HloInstruction * a,const HloInstruction * b)142 bool IsConnected(const HloInstruction* a, const HloInstruction* b) const { 143 return IsConnected(GetIndex(a), GetIndex(b)); 144 } IsConnected(Index a,Index b)145 bool IsConnected(Index a, Index b) const { 146 return IsReachable(a, b) || IsReachable(b, a); 147 } 148 149 // Checks if an instruction is in the Reachability map. IsPresent(const HloInstruction * a)150 bool IsPresent(const HloInstruction* a) const { 151 return indices_.contains(GetKey(a)); 152 } 153 154 // Replace the instruction "original" with "replacement" in the reachability 155 // map. 156 void Replace(const HloInstruction* original, 157 const HloInstruction* replacement); 158 159 private: 160 // A bit-vector implementation specialized for this use case which provides a 161 // fast bitwise OR operation not available in tensorflow::gtl::BitMap. 162 class BitVector { 163 public: 164 BitVector() = default; BitVector(size_t size)165 BitVector(size_t size) 166 : size_(size), vector_((size + kBits - 1) / kBits, 0) {} 167 168 // Return the bit at the given index. Get(size_t index)169 bool Get(size_t index) const { 170 DCHECK(index >= 0 && index < size_); 171 return vector_[index / kBits] & (1ull << (index % kBits)); 172 } 173 174 // Set the bit at the given index. Set(size_t index)175 void Set(size_t index) { 176 DCHECK(index >= 0 && index < size_); 177 vector_[index / kBits] |= 1ull << (index % kBits); 178 } 179 180 // Set this bitvector to the Logical OR of this bitvector and 'other'. OrWith(const BitVector & other)181 void OrWith(const BitVector& other) { 182 for (size_t i = 0; i < vector_.size(); ++i) { 183 vector_[i] |= other.vector_[i]; 184 } 185 } 186 187 // Set the bitvector to all zeros. SetToZero()188 void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } 189 190 bool operator==(const BitVector& other) const { 191 return vector_ == other.vector_; 192 } 193 bool operator!=(const BitVector& other) const { 194 return vector_ != other.vector_; 195 } 196 197 private: 198 using Word = uint64_t; 199 static constexpr size_t kBits = 64; 200 201 // Number of bits in the bitvector. 202 size_t size_; 203 204 std::vector<Word> vector_; 205 }; 206 207 // Return the bitvector storing the reachability-to of the given instruction. GetBitVector(const HloInstruction * instruction)208 const BitVector& GetBitVector(const HloInstruction* instruction) const { 209 return GetBitVector(GetIndex(instruction)); 210 } GetBitVector(const HloInstruction * instruction)211 BitVector& GetBitVector(const HloInstruction* instruction) { 212 return GetBitVector(GetIndex(instruction)); 213 } 214 GetBitVector(Index index)215 const BitVector& GetBitVector(Index index) const { 216 return bit_vectors_[index.v]; 217 } GetBitVector(Index index)218 BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; } 219 220 // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. 221 void SetReachabilityToUnionHelper( 222 absl::Span<const HloInstruction* const> inputs, Index index); 223 void SetReachabilityToUnionHelper(absl::Span<const Index> input_indices, 224 Index index); 225 GetKey(const HloInstruction * instruction)226 uint64_t GetKey(const HloInstruction* instruction) const { 227 uint64_t unique_id = absl::bit_cast<uint32_t>(instruction->unique_id()); 228 uint64_t module_id = 229 absl::bit_cast<uint32_t>(instruction->parent()->parent()->unique_id()); 230 return (module_id << 32) | unique_id; 231 } 232 // Return the index of the given instruction. GetIndexInternal(const HloInstruction * instruction)233 int GetIndexInternal(const HloInstruction* instruction) const { 234 return FindOrDie(indices_, GetKey(instruction)); 235 } 236 237 // The number of instructions in the reachability map. 238 const size_t size_; 239 240 // Dense assignment from HloInstruction::unique_id to number. These numbers 241 // index into the bit_vectors_ vector and into the bits within a BitVector. 242 absl::flat_hash_map<uint64_t, int> indices_; 243 244 // Bitvectors holding the reachability to each instruction. The bit vector for 245 // instruction X includes ones for each instruction which X is reachable from. 246 std::vector<BitVector> bit_vectors_; 247 248 // A temporary used by SetReachabilityToUnion to avoid an allocation with each 249 // call to the method. 250 BitVector tmp_bit_vector_; 251 }; 252 253 } // namespace xla 254 255 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 256