1 #pragma once 2 #include <c10/core/ScalarType.h> 3 #include <torch/csrc/Export.h> 4 #include <utility> 5 #include <vector> 6 7 #include <torch/csrc/jit/tensorexpr/bounds_overlap.h> 8 #include <torch/csrc/jit/tensorexpr/ir_mutator.h> 9 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> 10 #include <torch/csrc/jit/tensorexpr/ir_visitor.h> 11 #include <torch/csrc/jit/tensorexpr/stmt.h> 12 13 namespace torch::jit::tensorexpr::analysis { 14 15 enum class AccessType { 16 Input, 17 Output, 18 Load, 19 Store, 20 Call, 21 AtomicAdd, 22 Alloc, 23 Free 24 }; 25 const char* AccessToString(AccessType a); 26 27 class AccessInfo; 28 using DependencySet = std::unordered_set<std::shared_ptr<AccessInfo>>; 29 30 /* AccessInfo 31 * 32 * Represents a single bounded memory access to a buffer, for instance a Load or 33 * a Store. Holds information relating to the specific access and links to 34 * connected accesses in the dependency graph. 35 */ 36 class TORCH_API AccessInfo { 37 public: AccessInfo(size_t id,AccessType type,StmtPtr stmt,VarPtr var,IndexBounds bounds)38 AccessInfo( 39 size_t id, 40 AccessType type, 41 StmtPtr stmt, 42 VarPtr var, 43 IndexBounds bounds) 44 : id_(id), 45 type_(type), 46 stmt_(std::move(stmt)), 47 expr_(nullptr), 48 var_(std::move(var)), 49 bounds_(std::move(bounds)) {} 50 AccessInfo(size_t id,AccessType type,ExprPtr expr,StmtPtr stmt,VarPtr var,IndexBounds bounds)51 AccessInfo( 52 size_t id, 53 AccessType type, 54 ExprPtr expr, 55 StmtPtr stmt, 56 VarPtr var, 57 IndexBounds bounds) 58 : id_(id), 59 type_(type), 60 stmt_(std::move(stmt)), 61 expr_(std::move(expr)), 62 var_(std::move(var)), 63 bounds_(std::move(bounds)) {} 64 65 // Id is a unique int representing the order this access occurred in the 66 // graph. id()67 size_t id() const { 68 return id_; 69 } 70 71 // The type of the access (Load, Store, etc). type()72 AccessType type() const { 73 return type_; 74 } 75 76 // The enclosing Stmt this access represents. E.g. if this is a Store then 77 // Stmt is the Store itself, while if the access is caused by an Expr, this is 78 // the most immediate parent Stmt. stmt()79 StmtPtr stmt() const { 80 return stmt_; 81 } 82 83 // If the access is represented by an Expr (such as Load or Call) then this is 84 // it, otherwise it's nullptr. expr()85 ExprPtr expr() const { 86 return expr_; 87 } 88 89 // The Var representing the underlying Buffer. var()90 VarPtr var() const { 91 return var_; 92 } 93 94 // A vector of Bounds representing the start and end expression for each 95 // dimension. bounds()96 IndexBounds& bounds() { 97 return bounds_; 98 } 99 100 // Each access that this depends upon, 101 // eg. if this is a Load, then it contains every Store that immediately 102 // contributes to a load of the bounds. 103 // or: if this is a Store, it contains all reads on the RHS of the Store. dependencies()104 const std::map<size_t, std::shared_ptr<AccessInfo>>& dependencies() const { 105 return dependencies_; 106 } 107 108 // Each access that depends on this one. 109 // ie. this access is present in the dependencies map of all accesses that are 110 // dependent. dependents()111 std::map<size_t, std::shared_ptr<AccessInfo>> dependents() const { 112 std::map<size_t, std::shared_ptr<AccessInfo>> res; 113 for (const auto& kv : dependents_) { 114 res.emplace(kv.first, kv.second.lock()); 115 } 116 return res; 117 } 118 119 // Returns the symbolic expression of the indices of this access. 120 std::vector<ExprPtr> getIndices() const; 121 122 // Establishes a dependency or dependent relationship with another access. 123 void addDependency(const std::shared_ptr<AccessInfo>& write); 124 void addDependent(const std::shared_ptr<AccessInfo>& read); 125 126 // helper for checking dependencies. 127 bool hasDependency(const std::shared_ptr<AccessInfo>& info) const; 128 129 // Returns the set of all nodes that are direct (immediate) dependencies of 130 // this access. 131 DependencySet getDirectDependencies(); 132 // likewise, returns all nodes that directly depend on this one. 133 DependencySet getDirectDependents(); 134 135 // Returns the full list of all nodes in the graph that this access depends 136 // on, and all nodes they depend on, and so forth, back to the inputs. 137 DependencySet getIndirectDependencies(); 138 // likewise, returns the full list of all nodes that depend on this node, and 139 // all nodes that depend on those nodes and so on down to the outputs. 140 DependencySet getIndirectDependents(); 141 142 // Does this access represent a read of memory (Load, ReduceOp, Call, etc). 143 bool isRead() const; 144 // Does this access represent a write of memory (Store, etc). 145 bool isWrite() const; 146 147 // Helpers for dumping accesses in various formats. 148 void print() const; 149 void dumpDOT(std::ostream& os) const; 150 const char* AccessTypeColour() const; 151 152 private: 153 size_t id_; 154 AccessType type_; 155 StmtPtr stmt_; 156 ExprPtr expr_; 157 VarPtr var_; 158 IndexBounds bounds_; 159 160 // Yes these should be sorted. 161 std::map<size_t, std::shared_ptr<AccessInfo>> dependencies_; 162 std::map<size_t, std::weak_ptr<AccessInfo>> dependents_; 163 }; 164 165 using VarBoundMap = std::unordered_map<VarPtr, Bound>; 166 167 /* MemDependencyChecker analyses a IR fragment and builds a dependency graph of 168 * accesses contained within. 169 * 170 * It's possible to retrieve the entire graph in node-object form, or can be 171 * used as an oracle for answering dependency questions. e.g: 172 * 173 * analyzer.hasIndirectDependency(BufA, BufB); or, 174 * analyzer.hasDirectDependency(LoadA, StoreB); 175 */ 176 class TORCH_API MemDependencyChecker : public IRVisitor { 177 struct Scope; 178 179 public: 180 MemDependencyChecker(); 181 MemDependencyChecker( 182 const std::unordered_set<BufPtr>& inputs, 183 const std::unordered_set<BufPtr>& outputs); 184 MemDependencyChecker( 185 const std::vector<BufHandle>& inputs, 186 const std::vector<BufHandle>& outputs); 187 188 ~MemDependencyChecker() override = default; 189 190 // Whether or not to allow loop execution order to influence dependency 191 // calculation. If the loop may later be parallelized you don't want this. 192 bool allowLoopExecutionOrderAnalysis(bool allow = true); 193 194 // Dependency Checking API. 195 // The goal is to have enough overloads here so you don't really have to think 196 // about it. 197 198 // Returns true if any read in A has a direct dependence on a write in B. 199 bool dependsDirectly(const StmtPtr& A, const StmtPtr& B); 200 bool dependsDirectly(const ExprPtr& A, const StmtPtr& B); 201 202 // Returns true of the output depends directly on a write contained in B. 203 bool dependsDirectly(const BufPtr& output, const StmtPtr& B); 204 205 // Returns true if a read in A depends directly on the provided input. 206 bool dependsDirectly(const StmtPtr& A, const BufPtr& input); 207 bool dependsDirectly(const ExprPtr& A, const BufPtr& input); 208 209 // Outputs/inputs cannot depend directly. 210 211 // Returns true if the access A has B as an immediate dependency. 212 bool dependsDirectly( 213 const std::shared_ptr<AccessInfo>& A, 214 const std::shared_ptr<AccessInfo>& B); 215 216 // Returns true if any read in A has an ancestor write contained in B. 217 bool dependsIndirectly(const StmtPtr& A, const StmtPtr& B); 218 bool dependsIndirectly(const ExprPtr& A, const StmtPtr& B); 219 220 // Returns true of the output depends indirectly on a write contained in B. 221 bool dependsIndirectly(const BufPtr& output, const StmtPtr& B); 222 223 // Returns true if a read in A depends indirectly on the provided input. 224 bool dependsIndirectly(const StmtPtr& A, const BufPtr& input); 225 bool dependsIndirectly(const ExprPtr& A, const BufPtr& input); 226 227 // returns true if the output uses any load of the input. 228 bool dependsIndirectly(const BufPtr& output, const BufPtr& input); 229 230 // Returns true if the access A has a dependency chain to access B. 231 bool dependsIndirectly( 232 const std::shared_ptr<AccessInfo>& A, 233 const std::shared_ptr<AccessInfo>& B); 234 235 // Returns the AccessInfo 236 std::shared_ptr<AccessInfo> accessFor(const StmtPtr& A) const; 237 std::shared_ptr<AccessInfo> accessFor(const ExprPtr& A) const; 238 239 // Returns all AccessInfos. 240 std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin( 241 const StmtPtr& A) const; 242 // TODO: this will return only the AccessInfo for A. It's included for 243 // completeness but be aware it wont return accesses used in the computation 244 // of A. 245 std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin( 246 const ExprPtr& A) const; 247 248 // Accesses relating to input and output buffers. 249 std::shared_ptr<AccessInfo> input(const BufPtr& B) const; 250 std::shared_ptr<AccessInfo> output(const BufPtr& B) const; 251 252 // Returns the full history of reads and writes. 253 const std::vector<std::shared_ptr<AccessInfo>>& getHistory() const; 254 255 // Dumps the dependency graph in DOT format. 256 void dumpDAG(const std::string& filename) const; 257 258 private: 259 // Node visitors. 260 void visit(const StorePtr& v) override; 261 void visit(const LoadPtr& v) override; 262 void visit(const ForPtr& v) override; 263 void visit(const CondPtr& v) override; 264 void visit(const IfThenElsePtr& v) override; 265 void visit(const CompareSelectPtr& v) override; 266 void visit(const BlockPtr& v) override; 267 void visit(const LetPtr& v) override; 268 void visit(const AtomicAddPtr& v) override; 269 void visit(const AllocatePtr& v) override; 270 void visit(const FreePtr& v) override; 271 272 using BoundRelationship = std::pair<IndexBounds, std::shared_ptr<AccessInfo>>; 273 274 // An internal struct holding the accesses found within a scope Block. 275 struct Scope { ScopeScope276 Scope(BlockPtr b, std::shared_ptr<Scope> p) 277 : block(std::move(b)), parent(std::move(p)) {} 278 279 BlockPtr block; 280 std::shared_ptr<Scope> parent; 281 282 std::unordered_map<VarPtr, Bound> shadowedVarBounds; 283 std::unordered_set<VarPtr> localVars; 284 285 std::vector<std::shared_ptr<AccessInfo>> accesses_; 286 287 std::unordered_map<VarPtr, std::list<BoundRelationship>> openWrites_; 288 }; 289 std::shared_ptr<Scope> currentScope_; 290 291 bool allowExecutionOrderAnalysis_{false}; 292 293 std::unordered_multimap<StmtPtr, std::shared_ptr<AccessInfo>> stmtToAccess_; 294 std::unordered_multimap<ExprPtr, std::shared_ptr<AccessInfo>> exprToAccess_; 295 std::unordered_map<StmtPtr, std::vector<std::shared_ptr<AccessInfo>>> 296 scopeToAccesses_; 297 298 VarBoundMap knownVarBounds_; 299 300 // Finds all accesses that are reads within the scope of v. 301 template <typename StmtOrExprPtr> getAllReadsWithin(const StmtOrExprPtr & v)302 DependencySet getAllReadsWithin(const StmtOrExprPtr& v) { 303 DependencySet reads; 304 auto insertAllReads = [&](const auto& nodes) { 305 for (const auto& l : nodes) { 306 auto bound = exprToAccess_.equal_range(l); 307 for (auto it = bound.first; it != bound.second; ++it) { 308 if (it->second->isRead()) { 309 reads.insert(it->second); 310 } 311 } 312 } 313 }; 314 315 // Look for and insert accesses belonging to all nodes that act like 316 // reads. 317 insertAllReads(NodeFinder<Load>::find(v)); 318 insertAllReads(NodeFinder<ReduceOp>::find(v)); 319 320 return reads; 321 } 322 323 // Finds all accesses that are writes within the scope of v. 324 // Writes cannot occur in Exprs, so this is a little simpler. getAllWritesWithin(const StmtPtr & v)325 DependencySet getAllWritesWithin(const StmtPtr& v) { 326 DependencySet writes; 327 328 // writes just Store currently. 329 auto stores = NodeFinder<Store>::find(v); 330 for (const auto& s : stores) { 331 auto bound = stmtToAccess_.equal_range(s); 332 for (auto it = bound.first; it != bound.second; ++it) { 333 if (it->second->isWrite()) { 334 writes.insert(it->second); 335 } 336 } 337 } 338 return writes; 339 } 340 341 // Templated helpers to work on either Exprs or Stmts. 342 template <typename StmtOrExprPtr> dependsDirectlyHelper(const StmtOrExprPtr & A,const StmtPtr & B)343 bool dependsDirectlyHelper(const StmtOrExprPtr& A, const StmtPtr& B) { 344 auto aReads = getAllReadsWithin(A); 345 auto bWrites = getAllWritesWithin(B); 346 347 for (auto& read : aReads) { 348 for (auto& depPair : read->dependencies()) { 349 if (bWrites.count(depPair.second) != 0) { 350 return true; 351 } 352 } 353 } 354 355 return false; 356 } 357 358 template <typename StmtOrExprPtr> dependsIndirectlyHelper(StmtOrExprPtr A,const StmtPtr & B)359 bool dependsIndirectlyHelper(StmtOrExprPtr A, const StmtPtr& B) { 360 auto aReads = getAllReadsWithin(A); 361 auto bWrites = getAllWritesWithin(B); 362 363 auto aDeps = getAllWriteDependencies(aReads); 364 365 for (auto& dependency : aDeps) { 366 if (bWrites.count(dependency) != 0) { 367 return true; 368 } 369 } 370 371 return false; 372 } 373 374 DependencySet getAllWriteDependencies(const DependencySet& products); 375 376 // Maps for inputs and outputs, since they aren't present directly in the IR. 377 std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>> inputs_; 378 std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>> outputs_; 379 std::unordered_map<VarPtr, std::shared_ptr<AccessInfo>> intermediates_; 380 381 // Inserts accesses for Buf's: specifically for inputs and outputs. 382 void insertBuffers( 383 std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>>& bufs, 384 AccessType type); 385 386 // Update the write history with a new write, adding dependencies and closing 387 // any overlapped writes (if possible). 388 void updateWriteHistory( 389 std::list<BoundRelationship>& writeHistory, 390 const std::shared_ptr<AccessInfo>& info, 391 size_t latestAccessToClose, 392 bool closeOverlapped = true, 393 bool insert = true); 394 395 // Merge a child scope into a parent scope, adding dependencies for open 396 // writes in the parent to accesses in the child. 397 void mergeScope( 398 const std::shared_ptr<Scope>& child, 399 const std::shared_ptr<Scope>& parent, 400 bool closeOverlapped = true); 401 402 // Binds symbolic vars in indices with the low and high bound for those vars. 403 std::vector<Bound> getIndicesBounds(const std::vector<ExprPtr>& indices); 404 405 size_t nextAccess_{0}; 406 StmtPtr lastStmt_{nullptr}; 407 }; 408 409 } // namespace torch::jit::tensorexpr::analysis 410