xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/mem_dependency_checker.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/core/ScalarType.h>
3 #include <torch/csrc/Export.h>
4 #include <utility>
5 #include <vector>
6 
7 #include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
8 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
9 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
10 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
11 #include <torch/csrc/jit/tensorexpr/stmt.h>
12 
13 namespace torch::jit::tensorexpr::analysis {
14 
// The kind of memory access an AccessInfo node describes.
// NOTE(review): the per-enumerator notes below are inferred from how the
// names are used elsewhere in this header — confirm against the .cpp.
enum class AccessType {
  Input, // access standing in for an input buffer
  Output, // access standing in for an output buffer
  Load, // read performed by a Load expression
  Store, // write performed by a Store statement
  Call, // access performed through a Call expression
  AtomicAdd, // access performed by an AtomicAdd statement
  Alloc, // Allocate statement for a buffer
  Free, // Free statement for a buffer
};

// Returns a C-string name for the given access type. The returned pointer's
// lifetime is not visible from this header — presumably a string literal;
// confirm against the definition.
const char* AccessToString(AccessType a);
26 
27 class AccessInfo;
28 using DependencySet = std::unordered_set<std::shared_ptr<AccessInfo>>;
29 
30 /* AccessInfo
31  *
32  * Represents a single bounded memory access to a buffer, for instance a Load or
33  * a Store. Holds information relating to the specific access and links to
34  * connected accesses in the dependency graph.
35  */
36 class TORCH_API AccessInfo {
37  public:
AccessInfo(size_t id,AccessType type,StmtPtr stmt,VarPtr var,IndexBounds bounds)38   AccessInfo(
39       size_t id,
40       AccessType type,
41       StmtPtr stmt,
42       VarPtr var,
43       IndexBounds bounds)
44       : id_(id),
45         type_(type),
46         stmt_(std::move(stmt)),
47         expr_(nullptr),
48         var_(std::move(var)),
49         bounds_(std::move(bounds)) {}
50 
AccessInfo(size_t id,AccessType type,ExprPtr expr,StmtPtr stmt,VarPtr var,IndexBounds bounds)51   AccessInfo(
52       size_t id,
53       AccessType type,
54       ExprPtr expr,
55       StmtPtr stmt,
56       VarPtr var,
57       IndexBounds bounds)
58       : id_(id),
59         type_(type),
60         stmt_(std::move(stmt)),
61         expr_(std::move(expr)),
62         var_(std::move(var)),
63         bounds_(std::move(bounds)) {}
64 
65   // Id is a unique int representing the order this access occurred in the
66   // graph.
id()67   size_t id() const {
68     return id_;
69   }
70 
71   // The type of the access (Load, Store, etc).
type()72   AccessType type() const {
73     return type_;
74   }
75 
76   // The enclosing Stmt this access represents. E.g. if this is a Store then
77   // Stmt is the Store itself, while if the access is caused by an Expr, this is
78   // the most immediate parent Stmt.
stmt()79   StmtPtr stmt() const {
80     return stmt_;
81   }
82 
83   // If the access is represented by an Expr (such as Load or Call) then this is
84   // it, otherwise it's nullptr.
expr()85   ExprPtr expr() const {
86     return expr_;
87   }
88 
89   // The Var representing the underlying Buffer.
var()90   VarPtr var() const {
91     return var_;
92   }
93 
94   // A vector of Bounds representing the start and end expression for each
95   // dimension.
bounds()96   IndexBounds& bounds() {
97     return bounds_;
98   }
99 
100   // Each access that this depends upon,
101   // eg. if this is a Load, then it contains every Store that immediately
102   // contributes to a load of the bounds.
103   // or: if this is a Store, it contains all reads on the RHS of the Store.
dependencies()104   const std::map<size_t, std::shared_ptr<AccessInfo>>& dependencies() const {
105     return dependencies_;
106   }
107 
108   // Each access that depends on this one.
109   // ie. this access is present in the dependencies map of all accesses that are
110   // dependent.
dependents()111   std::map<size_t, std::shared_ptr<AccessInfo>> dependents() const {
112     std::map<size_t, std::shared_ptr<AccessInfo>> res;
113     for (const auto& kv : dependents_) {
114       res.emplace(kv.first, kv.second.lock());
115     }
116     return res;
117   }
118 
119   // Returns the symbolic expression of the indices of this access.
120   std::vector<ExprPtr> getIndices() const;
121 
122   // Establishes a dependency or dependent relationship with another access.
123   void addDependency(const std::shared_ptr<AccessInfo>& write);
124   void addDependent(const std::shared_ptr<AccessInfo>& read);
125 
126   // helper for checking dependencies.
127   bool hasDependency(const std::shared_ptr<AccessInfo>& info) const;
128 
129   // Returns the set of all nodes that are direct (immediate) dependencies of
130   // this access.
131   DependencySet getDirectDependencies();
132   // likewise, returns all nodes that directly depend on this one.
133   DependencySet getDirectDependents();
134 
135   // Returns the full list of all nodes in the graph that this access depends
136   // on, and all nodes they depend on, and so forth, back to the inputs.
137   DependencySet getIndirectDependencies();
138   // likewise, returns the full list of all nodes that depend on this node, and
139   // all nodes that depend on those nodes and so on down to the outputs.
140   DependencySet getIndirectDependents();
141 
142   // Does this access represent a read of memory (Load, ReduceOp, Call, etc).
143   bool isRead() const;
144   // Does this access represent a write of memory (Store, etc).
145   bool isWrite() const;
146 
147   // Helpers for dumping accesses in various formats.
148   void print() const;
149   void dumpDOT(std::ostream& os) const;
150   const char* AccessTypeColour() const;
151 
152  private:
153   size_t id_;
154   AccessType type_;
155   StmtPtr stmt_;
156   ExprPtr expr_;
157   VarPtr var_;
158   IndexBounds bounds_;
159 
160   // Yes these should be sorted.
161   std::map<size_t, std::shared_ptr<AccessInfo>> dependencies_;
162   std::map<size_t, std::weak_ptr<AccessInfo>> dependents_;
163 };
164 
165 using VarBoundMap = std::unordered_map<VarPtr, Bound>;
166 
167 /* MemDependencyChecker analyses a IR fragment and builds a dependency graph of
168  * accesses contained within.
169  *
170  * It's possible to retrieve the entire graph in node-object form, or can be
171  * used as an oracle for answering dependency questions. e.g:
172  *
173  *  analyzer.hasIndirectDependency(BufA, BufB); or,
174  *  analyzer.hasDirectDependency(LoadA, StoreB);
175  */
176 class TORCH_API MemDependencyChecker : public IRVisitor {
177   struct Scope;
178 
179  public:
180   MemDependencyChecker();
181   MemDependencyChecker(
182       const std::unordered_set<BufPtr>& inputs,
183       const std::unordered_set<BufPtr>& outputs);
184   MemDependencyChecker(
185       const std::vector<BufHandle>& inputs,
186       const std::vector<BufHandle>& outputs);
187 
188   ~MemDependencyChecker() override = default;
189 
190   // Whether or not to allow loop execution order to influence dependency
191   // calculation. If the loop may later be parallelized you don't want this.
192   bool allowLoopExecutionOrderAnalysis(bool allow = true);
193 
194   // Dependency Checking API.
195   // The goal is to have enough overloads here so you don't really have to think
196   // about it.
197 
198   // Returns true if any read in A has a direct dependence on a write in B.
199   bool dependsDirectly(const StmtPtr& A, const StmtPtr& B);
200   bool dependsDirectly(const ExprPtr& A, const StmtPtr& B);
201 
202   // Returns true of the output depends directly on a write contained in B.
203   bool dependsDirectly(const BufPtr& output, const StmtPtr& B);
204 
205   // Returns true if a read in A depends directly on the provided input.
206   bool dependsDirectly(const StmtPtr& A, const BufPtr& input);
207   bool dependsDirectly(const ExprPtr& A, const BufPtr& input);
208 
209   // Outputs/inputs cannot depend directly.
210 
211   // Returns true if the access A has B as an immediate dependency.
212   bool dependsDirectly(
213       const std::shared_ptr<AccessInfo>& A,
214       const std::shared_ptr<AccessInfo>& B);
215 
216   // Returns true if any read in A has an ancestor write contained in B.
217   bool dependsIndirectly(const StmtPtr& A, const StmtPtr& B);
218   bool dependsIndirectly(const ExprPtr& A, const StmtPtr& B);
219 
220   // Returns true of the output depends indirectly on a write contained in B.
221   bool dependsIndirectly(const BufPtr& output, const StmtPtr& B);
222 
223   // Returns true if a read in A depends indirectly on the provided input.
224   bool dependsIndirectly(const StmtPtr& A, const BufPtr& input);
225   bool dependsIndirectly(const ExprPtr& A, const BufPtr& input);
226 
227   // returns true if the output uses any load of the input.
228   bool dependsIndirectly(const BufPtr& output, const BufPtr& input);
229 
230   // Returns true if the access A has a dependency chain to access B.
231   bool dependsIndirectly(
232       const std::shared_ptr<AccessInfo>& A,
233       const std::shared_ptr<AccessInfo>& B);
234 
235   // Returns the AccessInfo
236   std::shared_ptr<AccessInfo> accessFor(const StmtPtr& A) const;
237   std::shared_ptr<AccessInfo> accessFor(const ExprPtr& A) const;
238 
239   // Returns all AccessInfos.
240   std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
241       const StmtPtr& A) const;
242   // TODO: this will return only the AccessInfo for A. It's included for
243   // completeness but be aware it wont return accesses used in the computation
244   // of A.
245   std::unordered_set<std::shared_ptr<AccessInfo>> accessesWithin(
246       const ExprPtr& A) const;
247 
248   // Accesses relating to input and output buffers.
249   std::shared_ptr<AccessInfo> input(const BufPtr& B) const;
250   std::shared_ptr<AccessInfo> output(const BufPtr& B) const;
251 
252   // Returns the full history of reads and writes.
253   const std::vector<std::shared_ptr<AccessInfo>>& getHistory() const;
254 
255   // Dumps the dependency graph in DOT format.
256   void dumpDAG(const std::string& filename) const;
257 
258  private:
259   // Node visitors.
260   void visit(const StorePtr& v) override;
261   void visit(const LoadPtr& v) override;
262   void visit(const ForPtr& v) override;
263   void visit(const CondPtr& v) override;
264   void visit(const IfThenElsePtr& v) override;
265   void visit(const CompareSelectPtr& v) override;
266   void visit(const BlockPtr& v) override;
267   void visit(const LetPtr& v) override;
268   void visit(const AtomicAddPtr& v) override;
269   void visit(const AllocatePtr& v) override;
270   void visit(const FreePtr& v) override;
271 
272   using BoundRelationship = std::pair<IndexBounds, std::shared_ptr<AccessInfo>>;
273 
274   // An internal struct holding the accesses found within a scope Block.
275   struct Scope {
ScopeScope276     Scope(BlockPtr b, std::shared_ptr<Scope> p)
277         : block(std::move(b)), parent(std::move(p)) {}
278 
279     BlockPtr block;
280     std::shared_ptr<Scope> parent;
281 
282     std::unordered_map<VarPtr, Bound> shadowedVarBounds;
283     std::unordered_set<VarPtr> localVars;
284 
285     std::vector<std::shared_ptr<AccessInfo>> accesses_;
286 
287     std::unordered_map<VarPtr, std::list<BoundRelationship>> openWrites_;
288   };
289   std::shared_ptr<Scope> currentScope_;
290 
291   bool allowExecutionOrderAnalysis_{false};
292 
293   std::unordered_multimap<StmtPtr, std::shared_ptr<AccessInfo>> stmtToAccess_;
294   std::unordered_multimap<ExprPtr, std::shared_ptr<AccessInfo>> exprToAccess_;
295   std::unordered_map<StmtPtr, std::vector<std::shared_ptr<AccessInfo>>>
296       scopeToAccesses_;
297 
298   VarBoundMap knownVarBounds_;
299 
300   // Finds all accesses that are reads within the scope of v.
301   template <typename StmtOrExprPtr>
getAllReadsWithin(const StmtOrExprPtr & v)302   DependencySet getAllReadsWithin(const StmtOrExprPtr& v) {
303     DependencySet reads;
304     auto insertAllReads = [&](const auto& nodes) {
305       for (const auto& l : nodes) {
306         auto bound = exprToAccess_.equal_range(l);
307         for (auto it = bound.first; it != bound.second; ++it) {
308           if (it->second->isRead()) {
309             reads.insert(it->second);
310           }
311         }
312       }
313     };
314 
315     // Look for and insert accesses belonging to all nodes that act like
316     // reads.
317     insertAllReads(NodeFinder<Load>::find(v));
318     insertAllReads(NodeFinder<ReduceOp>::find(v));
319 
320     return reads;
321   }
322 
323   // Finds all accesses that are writes within the scope of v.
324   // Writes cannot occur in Exprs, so this is a little simpler.
getAllWritesWithin(const StmtPtr & v)325   DependencySet getAllWritesWithin(const StmtPtr& v) {
326     DependencySet writes;
327 
328     // writes just Store currently.
329     auto stores = NodeFinder<Store>::find(v);
330     for (const auto& s : stores) {
331       auto bound = stmtToAccess_.equal_range(s);
332       for (auto it = bound.first; it != bound.second; ++it) {
333         if (it->second->isWrite()) {
334           writes.insert(it->second);
335         }
336       }
337     }
338     return writes;
339   }
340 
341   // Templated helpers to work on either Exprs or Stmts.
342   template <typename StmtOrExprPtr>
dependsDirectlyHelper(const StmtOrExprPtr & A,const StmtPtr & B)343   bool dependsDirectlyHelper(const StmtOrExprPtr& A, const StmtPtr& B) {
344     auto aReads = getAllReadsWithin(A);
345     auto bWrites = getAllWritesWithin(B);
346 
347     for (auto& read : aReads) {
348       for (auto& depPair : read->dependencies()) {
349         if (bWrites.count(depPair.second) != 0) {
350           return true;
351         }
352       }
353     }
354 
355     return false;
356   }
357 
358   template <typename StmtOrExprPtr>
dependsIndirectlyHelper(StmtOrExprPtr A,const StmtPtr & B)359   bool dependsIndirectlyHelper(StmtOrExprPtr A, const StmtPtr& B) {
360     auto aReads = getAllReadsWithin(A);
361     auto bWrites = getAllWritesWithin(B);
362 
363     auto aDeps = getAllWriteDependencies(aReads);
364 
365     for (auto& dependency : aDeps) {
366       if (bWrites.count(dependency) != 0) {
367         return true;
368       }
369     }
370 
371     return false;
372   }
373 
374   DependencySet getAllWriteDependencies(const DependencySet& products);
375 
376   // Maps for inputs and outputs, since they aren't present directly in the IR.
377   std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>> inputs_;
378   std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>> outputs_;
379   std::unordered_map<VarPtr, std::shared_ptr<AccessInfo>> intermediates_;
380 
381   // Inserts accesses for Buf's: specifically for inputs and outputs.
382   void insertBuffers(
383       std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>>& bufs,
384       AccessType type);
385 
386   // Update the write history with a new write, adding dependencies and closing
387   // any overlapped writes (if possible).
388   void updateWriteHistory(
389       std::list<BoundRelationship>& writeHistory,
390       const std::shared_ptr<AccessInfo>& info,
391       size_t latestAccessToClose,
392       bool closeOverlapped = true,
393       bool insert = true);
394 
395   // Merge a child scope into a parent scope, adding dependencies for open
396   // writes in the parent to accesses in the child.
397   void mergeScope(
398       const std::shared_ptr<Scope>& child,
399       const std::shared_ptr<Scope>& parent,
400       bool closeOverlapped = true);
401 
402   // Binds symbolic vars in indices with the low and high bound for those vars.
403   std::vector<Bound> getIndicesBounds(const std::vector<ExprPtr>& indices);
404 
405   size_t nextAccess_{0};
406   StmtPtr lastStmt_{nullptr};
407 };
408 
409 } // namespace torch::jit::tensorexpr::analysis
410