#include <torch/csrc/jit/tensorexpr/bounds_inference.h>

#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

#include <c10/util/irange.h>

#include <iostream>
#include <utility>

namespace torch::jit::tensorexpr {

using namespace analysis;

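// Merges the raw accesses reported by the dependency analyzer into per-buffer
// bounds. Input/Output placeholder accesses are skipped. For each buffer the
// per-dimension bounds are widened to cover every access (Min of starts, Max
// of stops). When distinctAccessKinds is true, loads and stores are kept as
// separate entries; otherwise a merged entry that mixes kinds is downgraded to
// kMutate.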
template <typename Container>
BoundsInfo mergeTensorAccesses(
    const Container& accesses,
    const std::unordered_map<VarPtr, BufPtr>& varToBuf,
    bool distinctAccessKinds) {
  BoundsInfo ret;
  for (auto& access : accesses) {
    if (access->type() == AccessType::Input ||
        access->type() == AccessType::Output) {
      continue;
    }

    auto vtbIt = varToBuf.find(access->var());
    TORCH_INTERNAL_ASSERT(vtbIt != varToBuf.end(), buildErrorMessage());
    BufPtr buf = vtbIt->second;
    std::vector<TensorAccessBoundsInfo>& infos = ret[buf];

    bool added = false;
    // This loop should be small, max of 2 (kLoad, kStore).
    for (auto& TABI : infos) {
      TensorAccessKind kind = access->isWrite() ? kStore : kLoad;
      if (!distinctAccessKinds || kind == TABI.kind) {
        TORCH_INTERNAL_ASSERT(
            TABI.start.size() == access->bounds().size(), buildErrorMessage());
        TORCH_INTERNAL_ASSERT(
            TABI.stop.size() == access->bounds().size(), buildErrorMessage());
        for (size_t i = 0; i < TABI.start.size(); ++i) {
          TABI.start[i] = IRSimplifier::simplify(
              alloc<Min>(TABI.start[i], access->bounds()[i].start, true));
          TABI.stop[i] = IRSimplifier::simplify(
              alloc<Max>(TABI.stop[i], access->bounds()[i].end, true));
          added = true;

          if (kind != TABI.kind) {
            TABI.kind = kMutate;
          }
        }
      }
    }

    if (!added) {
      TensorAccessBoundsInfo info;
      info.kind = access->isWrite() ? kStore : kLoad;

      for (auto& b : access->bounds()) {
        info.start.push_back(b.start);
        info.stop.push_back(b.end);
      }

      infos.push_back(info);
    }
  }

  return ret;
}

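// Collects every Buf referenced by the given statement (or expression, in the
// overload below), keyed by the Var that serves as its base handle.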
static std::unordered_map<VarPtr, BufPtr> getAllBufs(const StmtPtr& s) {
  std::unordered_map<VarPtr, BufPtr> varToBuf;

  auto bufs = NodeFinder<Buf>::find(s);
  for (const auto& b : bufs) {
    varToBuf[b->base_handle()] = b;
  }
  return varToBuf;
}

static std::unordered_map<VarPtr, BufPtr> getAllBufs(const ExprPtr& e) {
  std::unordered_map<VarPtr, BufPtr> varToBuf;

  auto bufs = NodeFinder<Buf>::find(e);
  for (const auto& b : bufs) {
    varToBuf[b->base_handle()] = b;
  }
  return varToBuf;
}

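// Runs a fresh memory-dependency analysis over the statement and returns the
// merged access bounds for every buffer it touches. Typical use (sketch,
// assuming a loop nest `stmt` built elsewhere):
//   BoundsInfo info = inferBounds(stmt, /*distinctAccessKinds=*/true);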
BoundsInfo inferBounds(const StmtPtr& s, bool distinctAccessKinds) {
  auto varToBuf = getAllBufs(s);

  MemDependencyChecker checker;
  s->accept(&checker);

  return mergeTensorAccesses(
      checker.getHistory(), varToBuf, distinctAccessKinds);
}

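// As above, but reuses an existing analyzer and restricts the result to the
// accesses that occur within the given statement (or expression, in the
// overload below).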
BoundsInfo getInferredBounds(
    MemDependencyChecker& analyzer,
    const StmtPtr& s,
    bool distinctAccessKinds) {
  return mergeTensorAccesses(
      analyzer.accessesWithin(s), getAllBufs(s), distinctAccessKinds);
}

BoundsInfo getInferredBounds(
    MemDependencyChecker& analyzer,
    const ExprPtr& e,
    bool distinctAccessKinds) {
  return mergeTensorAccesses(
      analyzer.accessesWithin(e), getAllBufs(e), distinctAccessKinds);
}

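// Debugging helper: dumps the per-buffer access bounds to stderr.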
void printBoundsInfo(const BoundsInfo& v) {
  std::cerr << "Access vector {\n";
  for (auto& pair : v) {
    std::cerr << *pair.first << " in [";
    bool first = true;
    for (auto& b : pair.second) {
      if (!first) {
        std::cerr << ", ";
      }
      std::cerr << ((b.kind == kLoad) ? "LOAD" : "STORE") << "(";
      int i = 0;
      if (b.start.empty()) {
        std::cerr << "0";
      }
      for (auto& s : b.start) {
        if (i != 0) {
          std::cerr << ", ";
        }
        std::cerr << *s;
        i++;
      }
      std::cerr << "; ";
      i = 0;
      if (b.stop.empty()) {
        std::cerr << "0";
      }
      for (auto& s : b.stop) {
        if (i != 0) {
          std::cerr << ", ";
        }
        std::cerr << *s;
        i++;
      }
      std::cerr << ")";
      first = false;
    }
    std::cerr << "]\n";
  }
  std::cerr << "}\n";
}

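// Computes the per-dimension extents needed to cover every one of the given
// access bounds. Bounds are inclusive, so each extent is
// (max stop - min start + 1).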
std::vector<ExprPtr> getBoundExtents(
    const std::vector<TensorAccessBoundsInfo>& infos) {
  std::vector<ExprPtr> starts;
  std::vector<ExprPtr> stops;

  // Find the safe size of the temporary buffer by determining the outer
  // extents of a union of all bounds.
  for (const TensorAccessBoundsInfo& p : infos) {
    for (const auto i : c10::irange(p.start.size())) {
      if (starts.size() <= i) {
        starts.push_back(p.start[i]);
      } else {
        starts[i] =
            IRSimplifier::simplify(alloc<Min>(starts[i], p.start[i], true));
      }

      if (stops.size() <= i) {
        stops.push_back(p.stop[i]);
      } else {
        stops[i] =
            IRSimplifier::simplify(alloc<Max>(stops[i], p.stop[i], true));
      }
    }
  }

  std::vector<ExprPtr> extents;
  for (size_t i = 0; i < starts.size(); ++i) {
    ExprPtr dim = IRSimplifier::simplify(
        alloc<Add>(alloc<Sub>(stops[i], starts[i]), immLike(stops[i], 1)));

    extents.push_back(dim);
  }

  return extents;
}

using BoundSet = std::unordered_set<Bound, BoundHash>;

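// Flattens the per-dimension bounds of each access into a set of Bound pairs,
// optionally restricted to one access kind (kMutate acts as "no filter").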
static BoundSet convertBounds(
    const std::vector<TensorAccessBoundsInfo>& bounds,
    TensorAccessKind filter = kMutate) {
  BoundSet ret;
  for (auto& TABI : bounds) {
    if (filter == kMutate || TABI.kind == filter) {
      for (size_t i = 0; i < TABI.start.size(); ++i) {
        ret.insert(Bound(TABI.start[i], TABI.stop[i]));
      }
    }
  }
  return ret;
}

static BoundSet convertBounds(
    BoundsInfo& bounds,
    const BufPtr& buf,
    TensorAccessKind filter = kMutate) {
  auto it = bounds.find(buf);
  if (it == bounds.end()) {
    return BoundSet();
  }

  return convertBounds(it->second, filter);
}

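// Classifies the dependency between two statements, treating A as the earlier
// one. Each shared buffer is checked in RAW, then WAR, then WAW order, and the
// first overlapping pair of bounds determines the result; NoDependency is
// returned only if nothing overlaps. Illustrative use (names are hypothetical):
//   if (getPotentialHazards(analyzer, stmtA, stmtB) ==
//       HazardKind::NoDependency) {
//     // safe to reorder stmtA and stmtB
//   }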
HazardKind getPotentialHazards(
    MemDependencyChecker& analyzer,
    const StmtPtr& A,
    const StmtPtr& B) {
  BoundsInfo aBounds = getInferredBounds(analyzer, A, true);
  BoundsInfo bBounds = getInferredBounds(analyzer, B, true);

  for (auto& pair : bBounds) {
    BufPtr buf = pair.first;
    if (aBounds.find(buf) == aBounds.end()) {
      continue;
    }

    auto aWrites = convertBounds(aBounds, buf, kStore);
    auto aReads = convertBounds(aBounds, buf, kLoad);

    auto bWrites = convertBounds(pair.second, kStore);
    auto bReads = convertBounds(pair.second, kLoad);

    // First, RAW.
    for (auto& bR : bReads) {
      for (auto& aW : aWrites) {
        if (boundOverlap(bR, aW) != OverlapKind::NoOverlap) {
          return HazardKind::ReadAfterWrite;
        }
      }
    }

    // Then WAR.
    for (auto& bW : bWrites) {
      for (auto& aR : aReads) {
        if (boundOverlap(bW, aR) != OverlapKind::NoOverlap) {
          return HazardKind::WriteAfterRead;
        }
      }
    }

    // Then WAW.
    for (auto& bW : bWrites) {
      for (auto& aW : aWrites) {
        if (boundOverlap(bW, aW) != OverlapKind::NoOverlap) {
          return HazardKind::WriteAfterWrite;
        }
      }
    }
  }

  return HazardKind::NoDependency;
}

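// Converts a single access (and, in the overload below, a filtered list of
// accesses) into per-dimension Bound pairs.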
static IndexBounds getIndexBounds(const TensorAccessBoundsInfo& tabi) {
  TORCH_INTERNAL_ASSERT(
      tabi.start.size() == tabi.stop.size(), buildErrorMessage());
  IndexBounds ret(tabi.start.size());
  if (tabi.start.empty()) {
    return ret;
  }
  for (size_t i = 0; i < tabi.start.size(); ++i) {
    ret[i] = Bound(tabi.start[i], tabi.stop[i]);
  }
  return ret;
}

static std::vector<IndexBounds> getIndexBounds(
    const std::vector<TensorAccessBoundsInfo>& vTABI,
    TensorAccessKind filter = kMutate) {
  std::vector<IndexBounds> bounds;
  for (auto& TABI : vTABI) {
    if (filter == kMutate || TABI.kind == filter) {
      bounds.push_back(getIndexBounds(TABI));
    }
  }
  return bounds;
}

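// Returns true if any access in aBounds overlaps any access in bBounds on the
// same buffer, ignoring load-load pairs and accesses excluded by the filters.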
static bool hasConflictingOverlap(
    const BoundsInfo& aBounds,
    const BoundsInfo& bBounds,
    TensorAccessKind aFilter = kMutate,
    TensorAccessKind bFilter = kMutate) {
  using IndexBoundsInfo = std::unordered_map<BufPtr, std::vector<IndexBounds>>;
  // Compute index bounds for every access, unfiltered, so these vectors stay
  // index-aligned with the TensorAccessBoundsInfo lists; the kind filters are
  // applied inside the loops below.
  IndexBoundsInfo aIndexBoundsInfo;
  for (auto& aBound : aBounds) {
    aIndexBoundsInfo[aBound.first] = getIndexBounds(aBound.second);
  }
  IndexBoundsInfo bIndexBoundsInfo;
  for (auto& bBound : bBounds) {
    bIndexBoundsInfo[bBound.first] = getIndexBounds(bBound.second);
  }

  for (auto& aBound : aBounds) {
    auto bIt = bBounds.find(aBound.first);
    if (bIt == bBounds.end()) {
      continue;
    }
    auto aIndexBounds = aIndexBoundsInfo[aBound.first];
    auto bIndexBounds = bIndexBoundsInfo[bIt->first];
    auto aTABIs = aBound.second;
    auto bTABIs = bIt->second;
    for (size_t i = 0; i < aTABIs.size(); ++i) {
      for (size_t j = 0; j < bTABIs.size(); ++j) {
        auto aTABI = aTABIs[i];
        auto bTABI = bTABIs[j];
        // Skip accesses excluded by the filters (kMutate means "any kind").
        if (aFilter != kMutate && aTABI.kind != aFilter) {
          continue;
        }
        if (bFilter != kMutate && bTABI.kind != bFilter) {
          continue;
        }
        // Two loads can never conflict.
        if (aTABI.kind == kLoad && bTABI.kind == kLoad) {
          continue;
        }
        auto overlap = overlaps(aIndexBounds[i], bIndexBounds[j]);
        if (overlap != OverlapKind::NoOverlap) {
          return true;
        }
      }
    }
  }
  return false;
}

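// Returns true if statements A and B access overlapping regions of any common
// buffer with at least one of the accesses being a write.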
bool hasConflictingOverlap(
    analysis::MemDependencyChecker& analyzer,
    const StmtPtr& A,
    const StmtPtr& B) {
  BoundsInfo aBounds = getInferredBounds(analyzer, A, true);
  BoundsInfo bBounds = getInferredBounds(analyzer, B, true);
  return hasConflictingOverlap(aBounds, bBounds);
}

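// Store/Store and Store/Load variants: true if the two accesses touch
// overlapping regions of the same buffer.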
bool isOverlapping(
    analysis::MemDependencyChecker& analyzer,
    const StorePtr& S1,
    const StorePtr& S2) {
  BoundsInfo s1Bounds = getInferredBounds(analyzer, S1, true);
  BoundsInfo s2Bounds = getInferredBounds(analyzer, S2, true);
  return hasConflictingOverlap(s1Bounds, s2Bounds, kStore, kStore);
}

bool isOverlapping(
    analysis::MemDependencyChecker& analyzer,
    const StorePtr& S,
    const LoadPtr& L) {
  BoundsInfo sBounds = getInferredBounds(analyzer, S, true);
  BoundsInfo lBounds = getInferredBounds(analyzer, L, true);
  return hasConflictingOverlap(sBounds, lBounds, kStore, kLoad);
}

} // namespace torch::jit::tensorexpr