// torch/csrc/jit/tensorexpr/bounds_inference.cpp
#include <torch/csrc/jit/tensorexpr/bounds_inference.h>

#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

#include <c10/util/irange.h>

#include <iostream>
#include <utility>

namespace torch::jit::tensorexpr {

using namespace analysis;

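// Folds the raw accesses recorded by the dependency analyzer into per-buffer
// bounds. Input and Output placeholder accesses are skipped. When
// distinctAccessKinds is true, loads and stores are kept as separate entries;
// otherwise accesses of different kinds are merged and marked kMutate.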
template <typename Container>
BoundsInfo mergeTensorAccesses(
    const Container& accesses,
    const std::unordered_map<VarPtr, BufPtr>& varToBuf,
    bool distinctAccessKinds) {
  BoundsInfo ret;
  for (auto& access : accesses) {
    if (access->type() == AccessType::Input ||
        access->type() == AccessType::Output) {
      continue;
    }

    auto vtbIt = varToBuf.find(access->var());
    TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end(), buildErrorMessage());
    BufPtr buf = vtbIt->second;
    std::vector<TensorAccessBoundsInfo>& infos = ret[buf];

    bool added = false;
    // This loop should be small, max of 2 (kLoad, kStore).
    for (auto& TABI : infos) {
      TensorAccessKind kind = access->isWrite() ? kStore : kLoad;
      if (!distinctAccessKinds || kind == TABI.kind) {
        TORCH_INTERNAL_ASSERT(
            TABI.start.size() == access->bounds().size(), buildErrorMessage());
        TORCH_INTERNAL_ASSERT(
            TABI.stop.size() == access->bounds().size(), buildErrorMessage());
        for (size_t i = 0; i < TABI.start.size(); ++i) {
          TABI.start[i] = IRSimplifier::simplify(
              alloc<Min>(TABI.start[i], access->bounds()[i].start, true));
          TABI.stop[i] = IRSimplifier::simplify(
              alloc<Max>(TABI.stop[i], access->bounds()[i].end, true));
          added = true;

          if (kind != TABI.kind) {
            TABI.kind = kMutate;
          }
        }
      }
    }

    if (!added) {
      TensorAccessBoundsInfo info;
      info.kind = access->isWrite() ? kStore : kLoad;

      for (auto& b : access->bounds()) {
        info.start.push_back(b.start);
        info.stop.push_back(b.end);
      }

      infos.push_back(info);
    }
  }

  return ret;
}

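// Collects every Buf referenced in a statement or expression, keyed by its
// base variable handle.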
static std::unordered_map<VarPtr, BufPtr> getAllBufs(const StmtPtr& s) {
  std::unordered_map<VarPtr, BufPtr> varToBuf;

  auto bufs = NodeFinder<Buf>::find(s);
  for (const auto& b : bufs) {
    varToBuf[b->base_handle()] = b;
  }
  return varToBuf;
}

static std::unordered_map<VarPtr, BufPtr> getAllBufs(const ExprPtr& e) {
  std::unordered_map<VarPtr, BufPtr> varToBuf;

  auto bufs = NodeFinder<Buf>::find(e);
  for (const auto& b : bufs) {
    varToBuf[b->base_handle()] = b;
  }
  return varToBuf;
}

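// Runs a fresh memory dependency analysis over the statement and merges its
// full access history into per-buffer bounds.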
BoundsInfo inferBounds(const StmtPtr& s, bool distinctAccessKinds) {
  auto varToBuf = getAllBufs(s);

  MemDependencyChecker checker;
  s->accept(&checker);

  return mergeTensorAccesses(
      checker.getHistory(), varToBuf, distinctAccessKinds);
}

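// Variants of inferBounds that reuse an existing analyzer and only merge the
// accesses occurring within the given statement or expression.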
BoundsInfo getInferredBounds(
    MemDependencyChecker& analyzer,
    const StmtPtr& s,
    bool distinctAccessKinds) {
  return mergeTensorAccesses(
      analyzer.accessesWithin(s), getAllBufs(s), distinctAccessKinds);
}

BoundsInfo getInferredBounds(
    MemDependencyChecker& analyzer,
    const ExprPtr& e,
    bool distinctAccessKinds) {
  return mergeTensorAccesses(
      analyzer.accessesWithin(e), getAllBufs(e), distinctAccessKinds);
}

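// Debug helper: dumps each buffer's merged bounds to stderr as
// "buf in [LOAD(starts; stops), STORE(starts; stops)]". Note that kMutate
// entries print as STORE.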
void printBoundsInfo(const BoundsInfo& v) {
  std::cerr << "Access vector {\n";
  for (auto& pair : v) {
    std::cerr << *pair.first << " in [";
    bool first = true;
    for (auto& b : pair.second) {
      if (!first) {
        std::cerr << ", ";
      }
      std::cerr << ((b.kind == kLoad) ? "LOAD" : "STORE") << "(";
      int i = 0;
      if (b.start.empty()) {
        std::cerr << "0";
      }
      for (auto& s : b.start) {
        if (i != 0) {
          std::cerr << ", ";
        }
        std::cerr << *s;
        i++;
      }
      std::cerr << "; ";
      i = 0;
      if (b.stop.empty()) {
        std::cerr << "0";
      }
      for (auto& s : b.stop) {
        if (i != 0) {
          std::cerr << ", ";
        }
        std::cerr << *s;
        i++;
      }
      std::cerr << ")";
      first = false;
    }
    std::cerr << "]\n";
  }
  std::cerr << "}\n";
}

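// Computes the per-dimension extents of the union of all given access bounds.
// Bounds are inclusive, so each extent is (stop - start) + 1.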
std::vector<ExprPtr> getBoundExtents(
    const std::vector<TensorAccessBoundsInfo>& infos) {
  std::vector<ExprPtr> starts;
  std::vector<ExprPtr> stops;

  // Find the safe size of the temporary buffer by determining the outer
  // extents of a union of all bounds.
  for (const TensorAccessBoundsInfo& p : infos) {
    for (const auto i : c10::irange(p.start.size())) {
      if (starts.size() <= i) {
        starts.push_back(p.start[i]);
      } else {
        starts[i] =
            IRSimplifier::simplify(alloc<Min>(starts[i], p.start[i], true));
      }

      if (stops.size() <= i) {
        stops.push_back(p.stop[i]);
      } else {
        stops[i] =
            IRSimplifier::simplify(alloc<Max>(stops[i], p.stop[i], true));
      }
    }
  }

  std::vector<ExprPtr> extents;
  for (size_t i = 0; i < starts.size(); ++i) {
    ExprPtr dim = IRSimplifier::simplify(
        alloc<Add>(alloc<Sub>(stops[i], starts[i]), immLike(stops[i], 1)));

    extents.push_back(dim);
  }

  return extents;
}

using BoundSet = std::unordered_set<Bound, BoundHash>;

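// Flattens per-dimension access bounds into a set of Bound pairs, keeping only
// accesses of the given kind. A kMutate filter acts as a wildcard and matches
// every access.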
static BoundSet convertBounds(
    const std::vector<TensorAccessBoundsInfo>& bounds,
    TensorAccessKind filter = kMutate) {
  BoundSet ret;
  for (auto& TABI : bounds) {
    if (filter == kMutate || TABI.kind == filter) {
      for (size_t i = 0; i < TABI.start.size(); ++i) {
        ret.insert(Bound(TABI.start[i], TABI.stop[i]));
      }
    }
  }
  return ret;
}

static BoundSet convertBounds(
    BoundsInfo& bounds,
    const BufPtr& buf,
    TensorAccessKind filter = kMutate) {
  auto it = bounds.find(buf);
  if (it == bounds.end()) {
    return BoundSet();
  }

  return convertBounds(it->second, filter);
}

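// Reports the first hazard found between statements A and B on any buffer they
// both access, checking in order: read-after-write, write-after-read,
// write-after-write. Returns NoDependency if no bounds overlap.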
HazardKind getPotentialHazards(
    MemDependencyChecker& analyzer,
    const StmtPtr& A,
    const StmtPtr& B) {
  BoundsInfo aBounds = getInferredBounds(analyzer, A, true);
  BoundsInfo bBounds = getInferredBounds(analyzer, B, true);

  for (auto& pair : bBounds) {
    BufPtr buf = pair.first;
    if (aBounds.find(buf) == aBounds.end()) {
      continue;
    }

    auto aWrites = convertBounds(aBounds, buf, kStore);
    auto aReads = convertBounds(aBounds, buf, kLoad);

    auto bWrites = convertBounds(pair.second, kStore);
    auto bReads = convertBounds(pair.second, kLoad);

    // First, RAW.
    for (auto& bR : bReads) {
      for (auto& aW : aWrites) {
        if (boundOverlap(bR, aW) != OverlapKind::NoOverlap) {
          return HazardKind::ReadAfterWrite;
        }
      }
    }

    // Then WAR.
    for (auto& bW : bWrites) {
      for (auto& aR : aReads) {
        if (boundOverlap(bW, aR) != OverlapKind::NoOverlap) {
          return HazardKind::WriteAfterRead;
        }
      }
    }

    // Then WAW.
    for (auto& bW : bWrites) {
      for (auto& aW : aWrites) {
        if (boundOverlap(bW, aW) != OverlapKind::NoOverlap) {
          return HazardKind::WriteAfterWrite;
        }
      }
    }
  }

  return HazardKind::NoDependency;
}

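// Converts access bounds into per-dimension Bound pairs; the vector overload
// optionally filters accesses by kind (kMutate matches everything).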
static IndexBounds getIndexBounds(const TensorAccessBoundsInfo& tabi) {
  TORCH_INTERNAL_ASSERT(
      tabi.start.size() == tabi.stop.size(), buildErrorMessage());
  IndexBounds ret(tabi.start.size());
  if (tabi.start.empty()) {
    return ret;
  }
  for (size_t i = 0; i < tabi.start.size(); ++i) {
    ret[i] = Bound(tabi.start[i], tabi.stop[i]);
  }
  return ret;
}

static std::vector<IndexBounds> getIndexBounds(
    const std::vector<TensorAccessBoundsInfo>& vTABI,
    TensorAccessKind filter = kMutate) {
  std::vector<IndexBounds> bounds;
  for (auto& TABI : vTABI) {
    if (filter == kMutate || TABI.kind == filter) {
      bounds.push_back(getIndexBounds(TABI));
    }
  }
  return bounds;
}

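// Returns true if any access in aBounds overlaps any access in bBounds on a
// shared buffer. Load/load pairs are ignored, since two reads never conflict.
// The public overload below infers bounds for two statements via the analyzer.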
static bool hasConflictingOverlap(
    const BoundsInfo& aBounds,
    const BoundsInfo& bBounds,
    TensorAccessKind aFilter = kMutate,
    TensorAccessKind bFilter = kMutate) {
  using IndexBoundsInfo = std::unordered_map<BufPtr, std::vector<IndexBounds>>;
  IndexBoundsInfo aIndexBoundsInfo;
  for (auto& aBound : aBounds) {
    aIndexBoundsInfo[aBound.first] = getIndexBounds(aBound.second, aFilter);
  }
  IndexBoundsInfo bIndexBoundsInfo;
  for (auto& bBound : bBounds) {
    bIndexBoundsInfo[bBound.first] = getIndexBounds(bBound.second, bFilter);
  }

  for (auto& aBound : aBounds) {
    auto bIt = bBounds.find(aBound.first);
    if (bIt == bBounds.end()) {
      continue;
    }
    auto aIndexBounds = aIndexBoundsInfo[aBound.first];
    auto bIndexBounds = bIndexBoundsInfo[bIt->first];
    auto aTABIs = aBound.second;
    auto bTABIs = bIt->second;
    for (size_t i = 0; i < aTABIs.size(); ++i) {
      for (size_t j = 0; j < bTABIs.size(); ++j) {
        auto aTABI = aTABIs[i];
        auto bTABI = bTABIs[j];
        if (aTABI.kind == kLoad && bTABI.kind == kLoad) {
          continue;
        }
        auto overlap = overlaps(aIndexBounds[i], bIndexBounds[j]);
        if (overlap != OverlapKind::NoOverlap) {
          return true;
        }
      }
    }
  }
  return false;
}

bool hasConflictingOverlap(
    analysis::MemDependencyChecker& analyzer,
    const StmtPtr& A,
    const StmtPtr& B) {
  BoundsInfo aBounds = getInferredBounds(analyzer, A, true);
  BoundsInfo bBounds = getInferredBounds(analyzer, B, true);
  return hasConflictingOverlap(aBounds, bBounds);
}

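// Convenience overlap checks: store vs. store, and store vs. load.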
bool isOverlapping(
    analysis::MemDependencyChecker& analyzer,
    const StorePtr& S1,
    const StorePtr& S2) {
  BoundsInfo s1Bounds = getInferredBounds(analyzer, S1, true);
  BoundsInfo s2Bounds = getInferredBounds(analyzer, S2, true);
  return hasConflictingOverlap(s1Bounds, s2Bounds, kStore, kStore);
}

bool isOverlapping(
    analysis::MemDependencyChecker& analyzer,
    const StorePtr& S,
    const LoadPtr& L) {
  BoundsInfo sBounds = getInferredBounds(analyzer, S, true);
  BoundsInfo lBounds = getInferredBounds(analyzer, L, true);
  return hasConflictingOverlap(sBounds, lBounds, kStore, kLoad);
}

} // namespace torch::jit::tensorexpr