#include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>

#include <c10/util/irange.h>

#include <fstream>
#include <iostream>
#include <utility>

namespace torch::jit::tensorexpr::analysis {

const char* AccessToString(AccessType a) {
  switch (a) {
    case AccessType::Input:
      return "Input";
    case AccessType::Output:
      return "Output";
    case AccessType::Load:
      return "Load";
    case AccessType::Store:
      return "Store";
    case AccessType::Call:
      return "Call";
    case AccessType::AtomicAdd:
      return "AtomicAdd";
    case AccessType::Alloc:
      return "Alloc";
    case AccessType::Free:
      return "Free";
    default:
      break;
  }
  return "Unknown";
}
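
// Helpers that collect the transitive closure of an access's dependencies or
// dependents. The insert().second check acts as a visited guard, so shared
// nodes are expanded only once and the recursion always terminates.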
static void getDependencyChain(
    const std::shared_ptr<AccessInfo>& info,
    DependencySet& dependencies) {
  if (!dependencies.insert(info).second) {
    return;
  }

  for (auto& dep : info->dependencies()) {
    getDependencyChain(dep.second, dependencies);
  }
}

static void getDependentsChain(
    const std::shared_ptr<AccessInfo>& info,
    DependencySet& dependents) {
  if (!dependents.insert(info).second) {
    return;
  }

  for (auto& dep : info->dependents()) {
    getDependentsChain(dep.second, dependents);
  }
}

// AccessInfo

std::vector<ExprPtr> AccessInfo::getIndices() const {
  std::vector<ExprPtr> indices;

  if (expr_) {
    if (auto load = to<Load>(expr_)) {
      indices = load->indices();
    }
  } else {
    if (auto store = to<Store>(stmt_)) {
      indices = store->indices();
    }
  }
  return indices;
}

void AccessInfo::addDependency(const std::shared_ptr<AccessInfo>& write) {
  auto res = dependencies_.emplace(write->id(), write);
  TORCH_INTERNAL_ASSERT(
      res.second,
      buildErrorMessage("Duplicate entry in mem dep checker in the fuser."));
}

void AccessInfo::addDependent(const std::shared_ptr<AccessInfo>& read) {
  auto res = dependents_.emplace(read->id(), read);
  TORCH_INTERNAL_ASSERT(
      res.second,
      buildErrorMessage("Duplicate entry in mem dep checker in the fuser."));
}

bool AccessInfo::hasDependency(const std::shared_ptr<AccessInfo>& info) const {
  return dependencies_.count(info->id()) != 0;
}

DependencySet AccessInfo::getDirectDependencies() {
  DependencySet res;
  for (auto& depPair : dependencies_) {
    res.insert(depPair.second);
  }
  return res;
}

DependencySet AccessInfo::getIndirectDependencies() {
  DependencySet res;
  for (auto& depPair : dependencies_) {
    getDependencyChain(depPair.second, res);
  }
  return res;
}

DependencySet AccessInfo::getDirectDependents() {
  DependencySet res;
  for (auto& depPair : dependents_) {
    res.insert(depPair.second.lock());
  }
  return res;
}

DependencySet AccessInfo::getIndirectDependents() {
  DependencySet res;
  for (auto& depPair : dependents_) {
    getDependentsChain(depPair.second.lock(), res);
  }
  return res;
}
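
// Note the asymmetry below: an Input access is modeled as a write (it defines
// the buffer's initial contents) and an Output access as a read (it consumes
// the buffer's final contents), so dependency chains can be traced across the
// kernel boundary.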
bool AccessInfo::isRead() const {
  switch (type_) {
    case AccessType::Output:
    case AccessType::Load:
    case AccessType::Call:
    case AccessType::AtomicAdd:
      return true;
    default:
      break;
  }
  return false;
}

bool AccessInfo::isWrite() const {
  switch (type_) {
    case AccessType::Input:
    case AccessType::Store:
    case AccessType::AtomicAdd:
    case AccessType::Alloc:
    case AccessType::Free:
      return true;
    default:
      break;
  }
  return false;
}

void AccessInfo::print() const {
  std::cout << id_ << ". " << AccessToString(type_) << ": " << *var_ << "[";
  if (!bounds_.empty()) {
    for (size_t i = 0; i < bounds_.size() - 1; ++i) {
      bounds_[i].print();
      std::cout << ", ";
    }

    size_t i = bounds_.size() - 1;
    bounds_[i].print();
  }
  std::cout << "]";

  if (!dependencies_.empty()) {
    std::cout << " - depends on: ";
    for (auto& pair : dependencies_) {
      std::cout << pair.second->id() << " ";
    }
  }

  if (!dependents_.empty()) {
    std::cout << " - dependents: ";
    for (auto& pair : dependents_) {
      std::cout << pair.second.lock()->id() << " ";
    }
  }

  std::cout << "\n";
}

void AccessInfo::dumpDOT(std::ostream& os) const {
  if (type_ == AccessType::Input || type_ == AccessType::Output ||
      type_ == AccessType::Alloc) {
    os << "n" << id_ << " [\n";
    os << "label = \"" << AccessToString(type_) << "\\n " << *var_ << "[";
    if (!bounds_.empty()) {
      for (size_t i = 0; i < bounds_.size() - 1; ++i) {
        os << *IRSimplifier::simplify(
                  alloc<Add>(bounds_[i].end, immLike(bounds_[i].end, 1)))
           << ", ";
      }

      size_t i = bounds_.size() - 1;
      os << *IRSimplifier::simplify(
          alloc<Add>(bounds_[i].end, immLike(bounds_[i].end, 1)));
      os << "]\"\n ";
    }
    if (isWrite()) {
      os << "\tshape = \"invhouse\"\n";
    } else {
      os << "\tshape = \"house\"\n";
    }
  } else {
    os << "n" << id_ << " [\n";
    os << "label = \"" << AccessToString(type_) << " (#" << id_ << ")\\n";
    os << "buf : " << *var_ << "\\n";
    os << "bounds : [";
    if (!bounds_.empty()) {
      for (size_t i = 0; i < bounds_.size() - 1; ++i) {
        os << "(" << *bounds_[i].start << ", " << *bounds_[i].end << "), ";
      }

      size_t i = bounds_.size() - 1;
      os << "(" << *bounds_[i].start << ", " << *bounds_[i].end << ")]";
    }
    os << "\"\n";
    os << "\tshape = \"box\"\n";
  }
  os << "\tstyle=\"filled\"\n";
  os << "\tcolor=\"" << AccessTypeColour() << "\"\n";
  std::string edgeColour;
  if (isWrite()) {
    edgeColour = "cornflowerblue";
  } else {
    edgeColour = "goldenrod";
  }
  os << "]\n";
  for (auto& pair : dependencies_) {
    os << "n" << pair.second->id() << " -> "
       << "n" << id_ << " [color=\"" << edgeColour << "\"]\n";
  }
}

const char* AccessInfo::AccessTypeColour() const {
  switch (type_) {
    case AccessType::Input:
    case AccessType::Output:
      return "palegreen";
    case AccessType::Load:
      return "peachpuff";
    case AccessType::Store:
      return "dodgerblue";
    case AccessType::Call:
      return "violet";
    case AccessType::Alloc:
    case AccessType::Free:
      return "sandybrown";
    default:
      break;
  }
  return "white";
}

// MemDependencyChecker
//
MemDependencyChecker::MemDependencyChecker() {
  currentScope_ = std::make_shared<Scope>(nullptr, nullptr);
}

MemDependencyChecker::MemDependencyChecker(
    const std::unordered_set<BufPtr>& inputs,
    const std::unordered_set<BufPtr>& outputs) {
  for (const auto& s : inputs) {
    inputs_[s] = nullptr;
  }
  for (const auto& s : outputs) {
    outputs_[s] = nullptr;
  }

  currentScope_ = std::make_shared<Scope>(nullptr, nullptr);
}

MemDependencyChecker::MemDependencyChecker(
    const std::vector<BufHandle>& inputs,
    const std::vector<BufHandle>& outputs) {
  for (auto& s : inputs) {
    inputs_[s.node()] = nullptr;
  }
  for (auto& s : outputs) {
    outputs_[s.node()] = nullptr;
  }

  currentScope_ = std::make_shared<Scope>(nullptr, nullptr);
}
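
// Enables or disables using sequential loop execution order to prove that
// accesses in different iterations cannot interfere. Returns the previous
// setting.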
bool MemDependencyChecker::allowLoopExecutionOrderAnalysis(bool allow) {
  std::swap(allowExecutionOrderAnalysis_, allow);
  return allow;
}

const std::vector<std::shared_ptr<AccessInfo>>& MemDependencyChecker::
    getHistory() const {
  return currentScope_->accesses_;
}

void MemDependencyChecker::dumpDAG(const std::string& filename) const {
  std::ofstream dotfile(filename);

  dotfile << "digraph {\n";
  for (auto& wi : getHistory()) {
    wi->dumpDOT(dotfile);
  }
  dotfile << "}\n";
  dotfile.close();
}
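
// A minimal usage sketch (the identifiers are illustrative placeholders, not
// part of this file): run the checker over a root statement, then query or
// dump the access graph.
//
//   MemDependencyChecker checker(inputBufs, outputBufs);
//   rootStmt->accept(&checker);
//   bool dep = checker.dependsIndirectly(outputBuf, inputBuf);
//   checker.dumpDAG("/tmp/mem_deps.dot");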

// dependsDirectly, dependsIndirectly and friends:

DependencySet MemDependencyChecker::getAllWriteDependencies(
    const DependencySet& products) {
  DependencySet writes;

  for (auto& info : products) {
    DependencySet dependencies;
    getDependencyChain(info, dependencies);
    for (auto& other : dependencies) {
      if (other->isWrite()) {
        writes.insert(other);
      }
    }
  }

  return writes;
}

bool MemDependencyChecker::dependsDirectly(const ExprPtr& A, const StmtPtr& B) {
  return dependsDirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsDirectly(const StmtPtr& A, const StmtPtr& B) {
  return dependsDirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsDirectly(const BufPtr& O, const StmtPtr& B) {
  auto outputAccess = output(O);
  auto bWrites = getAllWritesWithin(B);

  for (auto& depPair : outputAccess->dependencies()) {
    if (bWrites.count(depPair.second) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsDirectly(const StmtPtr& A, const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  for (auto& depPair : inputAccess->dependents()) {
    if (aReads.count(depPair.second) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsDirectly(const ExprPtr& A, const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  for (auto& depPair : inputAccess->dependents()) {
    if (aReads.count(depPair.second) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsDirectly(
    const std::shared_ptr<AccessInfo>& A,
    const std::shared_ptr<AccessInfo>& B) {
  return A->hasDependency(B) && B->isWrite();
}

bool MemDependencyChecker::dependsIndirectly(
    const ExprPtr& A,
    const StmtPtr& B) {
  return dependsIndirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsIndirectly(
    const StmtPtr& A,
    const StmtPtr& B) {
  return dependsIndirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsIndirectly(
    const BufPtr& O,
    const StmtPtr& B) {
  auto outputAccess = output(O);

  DependencySet dependencies;
  getDependencyChain(outputAccess, dependencies);

  auto bWrites = getAllWritesWithin(B);
  for (auto& dep : dependencies) {
    if (bWrites.count(dep) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsIndirectly(
    const StmtPtr& A,
    const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  auto aDeps = getAllWriteDependencies(aReads);

  return aDeps.count(inputAccess) != 0;
}

bool MemDependencyChecker::dependsIndirectly(
    const ExprPtr& A,
    const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  auto aDeps = getAllWriteDependencies(aReads);

  return aDeps.count(inputAccess) != 0;
}

bool MemDependencyChecker::dependsIndirectly(const BufPtr& O, const BufPtr& I) {
  auto outputAccess = output(O);
  auto inputAccess = input(I);

  return dependsIndirectly(outputAccess, inputAccess);
}

bool MemDependencyChecker::dependsIndirectly(
    const std::shared_ptr<AccessInfo>& A,
    const std::shared_ptr<AccessInfo>& B) {
  if (!B->isWrite()) {
    return false;
  }

  DependencySet dependencies;
  getDependencyChain(A, dependencies);
  if (dependencies.count(B) == 0) {
    return false;
  }

  return true;
}

std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
    const StmtPtr& A) const {
  auto bound = stmtToAccess_.equal_range(A);
  for (auto it = bound.first; it != bound.second; ++it) {
    if (it->second->expr() == nullptr) {
      return it->second;
    }
  }
  return nullptr;
}

std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
    const ExprPtr& A) const {
  // TODO: exprs can have multiple accesses... we're returning the first, but
  // that isn't great. Can't do much here.
  auto bound = exprToAccess_.equal_range(A);
  if (bound.first != bound.second) {
    return bound.first->second;
  }

  return nullptr;
}

std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
    accessesWithin(const StmtPtr& A) const {
  auto it = scopeToAccesses_.find(A);
  if (it != scopeToAccesses_.end()) {
    return std::unordered_set<std::shared_ptr<AccessInfo>>(
        it->second.begin(), it->second.end());
  }

  std::unordered_set<std::shared_ptr<AccessInfo>> ret;
  auto bound = stmtToAccess_.equal_range(A);
  for (auto it = bound.first; it != bound.second; ++it) {
    ret.insert(it->second);
  }
  return ret;
}

std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
    accessesWithin(const ExprPtr& A) const {
  return {accessFor(A)};
}

std::shared_ptr<AccessInfo> MemDependencyChecker::input(const BufPtr& b) const {
  auto it = inputs_.find(b);
  if (it == inputs_.end()) {
    return nullptr;
  }
  return it->second;
}

std::shared_ptr<AccessInfo> MemDependencyChecker::output(
    const BufPtr& b) const {
  auto it = outputs_.find(b);
  if (it == outputs_.end()) {
    return nullptr;
  }
  return it->second;
}

// Node visitors:

void MemDependencyChecker::visit(const StorePtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;
  v->value()->accept(this);

  for (const ExprPtr& ind : v->indices()) {
    ind->accept(this);
  }
  lastStmt_ = last;

  // Create a new AccessInfo for the store.
  VarPtr var = v->buf()->base_handle();
  auto info = std::make_shared<AccessInfo>(
      nextAccess_++, AccessType::Store, v, var, getIndicesBounds(v->indices()));

  // Add a dependency to any accesses that are within the scope of this store
  // (ie. the RHS).
  auto bound = stmtToAccess_.equal_range(v);
  for (auto it = bound.first; it != bound.second; ++it) {
    info->addDependency(it->second);
    it->second->addDependent(info);
  }

  stmtToAccess_.emplace(v, info);

  // This write is open, and will close any open writes that it totally
  // overlaps.
  auto& history = currentScope_->openWrites_[var];
  updateWriteHistory(history, info, info->id());
  currentScope_->accesses_.push_back(info);
}

void MemDependencyChecker::visit(const LoadPtr& v) {
  // Create a temporary scope to hold any loads that occur within the indices
  // of this load.
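  // For a compound access such as A[B[i]], the inner load of B[i] is visited
  // in this scope and recorded below as a dependency of the outer load.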
  auto indicesScope =
      std::make_shared<Scope>(currentScope_->block, currentScope_);
  currentScope_ = indicesScope;

  for (const ExprPtr& ind : v->indices()) {
    ind->accept(this);
  }

  // Create a new AccessInfo for the load.
  VarPtr var = v->buf()->base_handle();
  auto load = std::make_shared<AccessInfo>(
      nextAccess_++,
      AccessType::Load,
      v,
      lastStmt_,
      var,
      getIndicesBounds(v->indices()));

  // If there were loads in the indices, this load depends on them; merge them
  // in.
  if (!indicesScope->accesses_.empty()) {
    for (auto& access : indicesScope->accesses_) {
      load->addDependency(access);
      access->addDependent(load);
    }
    mergeScope(indicesScope, indicesScope->parent, false);
  }

  currentScope_ = indicesScope->parent;

  stmtToAccess_.emplace(lastStmt_, load);
  exprToAccess_.emplace(v, load);

  // This is a read, and does not close any accesses - but we need to establish
  // dependencies on accesses in the same scope.
  // Intentionally using operator[], we want it to be created if it does not
  // exist.
  auto& writeHistory = currentScope_->openWrites_[var];
  updateWriteHistory(writeHistory, load, load->id());
  currentScope_->accesses_.push_back(load);
}

// This check determines if two accesses within a loop are "safe" from
// loop-self dependence. This function does not consider overlap in bound
// range, but rather the stride of the bound relative to the loop variable.
// This is the section of the code which considers iteration order, if
// allowed.
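//
// A worked example (hypothetical values): strides 2 and 4 with start offsets
// 0 and 1 give maxStride % minStride == 0 and an inverted startDiff of 1, so
// startDiff % minStride == 1 and the accesses touch disjoint elements on that
// axis; the check reports them as safe.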
static bool executionSafetyCheck(
    const std::shared_ptr<AccessInfo>& info,
    const std::shared_ptr<AccessInfo>& other,
    const std::vector<ExprPtr>& aStrides,
    const std::vector<ExprPtr>& oStrides,
    bool parallelized) {
  if (aStrides.empty() || oStrides.empty()) {
    return false;
  }
  TORCH_INTERNAL_ASSERT(
      info->bounds().size() == other->bounds().size(),
      buildErrorMessage(
          "Dimension mismatch for two accesses in mem dep checker in the fuser."));
  for (size_t b = 0; b < info->bounds().size(); ++b) {
    ExprPtr aIndexStride = aStrides[b];
    ExprPtr oIndexStride = oStrides[b];
    // can't be safe on this index if we can't determine stride.
    if (!aIndexStride->isConstant() || !oIndexStride->isConstant()) {
      continue;
    }

    ExprPtr minStride =
        IRSimplifier::simplify(alloc<Min>(aIndexStride, oIndexStride, true));
    ExprPtr maxStride =
        IRSimplifier::simplify(alloc<Max>(aIndexStride, oIndexStride, true));

    // If the first access has no stride, don't apply safety.
    if (immediateEquals(minStride, 0)) {
      continue;
    }

    ExprPtr modCheck = IRSimplifier::simplify(alloc<Mod>(maxStride, minStride));

    // If the strides can't have easily inferable distinct offsets, they're
    // not safe.
    if (!immediateEquals(modCheck, 0)) {
      continue;
    }

    // If the loop has a defined execution order (ie. sequential for) then
    // the order of execution can provide safety from overlaps.
    // Specifically if the difference in first access position for any
    // axis is the same sign as the common stride, then they will not
    // overlap.

    ExprPtr startDiff = IRSimplifier::simplify(
        alloc<Sub>(info->bounds()[b].start, other->bounds()[b].start));

    bool diffNegative = immediateIsNegative(startDiff);
    bool strideNegative = immediateIsNegative(minStride);

    // Invert the startDiff so mod works.
    if (diffNegative != strideNegative) {
      startDiff =
          IRSimplifier::simplify(alloc<Sub>(immLike(startDiff, 0), startDiff));
    }

    // If both accesses have the same stride, and the difference in start
    // element is smaller than this stride, then the entire range is distinct.
    if (exprEquals(minStride, maxStride)) {
      ExprPtr check1 = IRSimplifier::simplify(
          alloc<CompareSelect>(startDiff, minStride, kLT));
      if (check1->isConstant() && immediateEquals(check1, 1)) {
        return true;
      }
    }

    startDiff = IRSimplifier::simplify(alloc<Mod>(startDiff, minStride));

    CompareSelectOperation op = strideNegative ? kLT : kGT;

    ExprPtr check = IRSimplifier::simplify(
        alloc<CompareSelect>(startDiff, immLike(startDiff, 0), op));

    // If the start difference modulo the minimum stride is offset from that
    // stride, then the ranges have distinct strides.
    if (check->isConstant() && immediateEquals(check, 1)) {
      return true;
    }

    // If we can consider execution order, and the difference in offset is
    // opposite signed to the stride, then the read occurs in the past and we
    // can infer safety.
    if (!parallelized && diffNegative == strideNegative &&
        immediateEquals(startDiff, 0)) {
      return true;
    }
  }

  return false;
}

void MemDependencyChecker::visit(const ForPtr& v) {
  VarPtr var = v->var();

  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  v->var()->accept(this);

  // Loads inside the For's start and stop expression are special.
  // They exist in the enclosing scope, but accesses within the loop body may
  // depend on them via usage of the loop variable.
  // The way we handle this is to create a new scope so we have an easily
  // accessible list of the accesses within the extents.
  auto extentsScope =
      std::make_shared<Scope>(currentScope_->block, currentScope_);
  currentScope_ = extentsScope;

  v->start()->accept(this);
  v->stop()->accept(this);

  currentScope_ = currentScope_->parent;

  auto newScope = std::make_shared<Scope>(v->body(), currentScope_);
  currentScope_ = newScope;

  v->body()->accept(this);

  lastStmt_ = last;

  // Now we need to determine whether accesses in the loop depend on
  // other loop iterations.
  //
  // This is the real challenge here: it depends on both the fully expanded
  // bounds and the symbolic bounds.

  // The indices must change monotonically to avoid intersection. This is
  // hard to determine, so here's our heuristic; hopefully it's conservative
  // enough.

  // The size of at least one dependent index must be >= the size of the
  // loop.

  // The first step is to infer the stride relative to each dimension of each
  // access, which we do by substituting the loop var with (var+1) into the
  // indices expr.
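  //
  // For example, an index expression 2 * i + 1 becomes 2 * (i + 1) + 1 after
  // the substitution, and the simplified difference yields a stride of 2.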

  std::vector<std::vector<ExprPtr>> loopStrides;
  loopStrides.resize(currentScope_->accesses_.size());

  for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) {
    auto& info = currentScope_->accesses_[a];

    std::vector<ExprPtr> indices = info->getIndices();

    std::vector<ExprPtr>& loopIndicesStride = loopStrides[a];
    loopIndicesStride.resize(indices.size());

    // index expr must depend on the loop var in some way to have a stride.
    for (const auto i : c10::irange(indices.size())) {
      VarFinder vf;
      if (vf.find(indices[i]).count(var) == 0) {
        loopIndicesStride[i] = immLike(indices[i], 0);
      } else {
        // If we've previously swapped the start and end of this bound, we
        // should apply the substitution to the reverse of the bounds.
        if (info->bounds()[i].swapped) {
          info->bounds()[i].end = IRSimplifier::simplify(
              SubstituteInClone(info->bounds()[i].end, {{var, v->start()}}));
          info->bounds()[i].start = IRSimplifier::simplify(SubstituteInClone(
              info->bounds()[i].start,
              {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));

        } else {
          info->bounds()[i].start = IRSimplifier::simplify(
              SubstituteInClone(info->bounds()[i].start, {{var, v->start()}}));
          info->bounds()[i].end = IRSimplifier::simplify(SubstituteInClone(
              info->bounds()[i].end,
              {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));
        }

        ExprPtr zeroStep = indices[i];
        ExprPtr oneStep = SubstituteInClone(
            indices[i], {{var, alloc<Add>(var, immLike(var, 1))}});
        loopIndicesStride[i] =
            IRSimplifier::simplify(alloc<Sub>(oneStep, zeroStep));

        // If end < start, swap the order of the bound.
        ExprPtr diff = IRSimplifier::simplify(
            alloc<Sub>(info->bounds()[i].end, info->bounds()[i].start));
        if (diff->isConstant() && immediateIsNegative(diff)) {
          info->bounds()[i].swap();
        }

        // If this access uses the loop var, it depends on loads used to
        // compute the loop extents.
        for (auto& extentLoad : extentsScope->accesses_) {
          info->addDependency(extentLoad);
          extentLoad->addDependent(info);
        }
      }
    }
  }

  // Now we need to update the bounds in openWrites, since that is what we use
  // to merge.
  for (auto& openWritePair : currentScope_->openWrites_) {
    for (auto& pair : openWritePair.second) {
      IndexBounds& bounds = pair.first;

      // The bounds may not contain the loop var, but in that case Substitute
      // does nothing.
      for (auto& bound : bounds) {
        bound.start = IRSimplifier::simplify(
            SubstituteInClone(bound.start, {{var, v->start()}}));
        bound.end = IRSimplifier::simplify(SubstituteInClone(
            bound.end, {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));

        // If end < start, swap the order of the bound.
        ExprPtr diff =
            IRSimplifier::simplify(alloc<Sub>(bound.end, bound.start));
        if (diff->isConstant() && immediateIsNegative(diff)) {
          bound.swap();
        }
      }
    }
  }

  // TODO: this isn't a scalable way to determine parallelism.
  bool parallelized = v->loop_options().is_gpu_block_index() ||
      v->loop_options().is_gpu_thread_index();

  // Store buffers allocated at this scope.
  std::unordered_set<VarPtr> local_intermediates;

  // Scanning from the top of the loop, we look for accesses which may depend
  // on a previous or parallel loop iteration.
  for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) {
    auto& info = currentScope_->accesses_[a];
    if (info->type() == AccessType::Alloc) {
      local_intermediates.insert(info->var());
      continue;
    }

    if (!info->isRead()) {
      continue;
    }

    // Vars that don't escape this scope can't have loop self-dependence.
    if (local_intermediates.count(info->var())) {
      continue;
    }

    // Copy the bounds so we can keep track of open bounds internally without
    // affecting the merge into the enclosing scope. The open portion of the
    // bounds may be cut into multiple independent slices.
    std::vector<IndexBounds> openBounds({info->bounds()});

    // Scan from the bottom of the loop.
    for (size_t j = currentScope_->accesses_.size() - 1; j > a; --j) {
      std::shared_ptr<AccessInfo> other = currentScope_->accesses_[j];
      if (!other->isWrite()) {
        continue;
      }

      if (info->var() != other->var()) {
        continue;
      }

      if (info->hasDependency(other)) {
        continue;
      }

      // Whether accesses within the loop depend on other iterations depends
      // on whether the loop could be parallelized, the difference in their
      // strides, and their start offsets.
      bool iterationsDistinct = executionSafetyCheck(
          info,
          other,
          loopStrides[a],
          loopStrides[j],
          !allowExecutionOrderAnalysis_ || parallelized);

      if (iterationsDistinct) {
        continue;
      }

      std::vector<IndexBounds> newBoundSlices;
      for (auto& b : openBounds) {
        OverlapKind overlap = overlaps(b, other->bounds());
        if (overlap == OverlapKind::NoOverlap) {
          newBoundSlices.push_back(b);
          continue;
        }

        // It's dependent, link it to other.
        info->addDependency(other);
        other->addDependent(info);

        if (overlap == OverlapKind::Contains) {
          continue;
        }

        // Otherwise update openBounds.
        auto slices = subtractIndicesBounds(b, other->bounds(), overlap);
        std::move(
            slices.begin(), slices.end(), std::back_inserter(newBoundSlices));
      }

      if (newBoundSlices.empty()) {
        break;
      }
      openBounds.swap(newBoundSlices);
    }
  }

  std::vector<std::shared_ptr<AccessInfo>> mergedAccesses;
  mergedAccesses.reserve(
      extentsScope->accesses_.size() + currentScope_->accesses_.size());
  std::copy(
      extentsScope->accesses_.begin(),
      extentsScope->accesses_.end(),
      std::back_inserter(mergedAccesses));
  std::copy(
      currentScope_->accesses_.begin(),
      currentScope_->accesses_.end(),
      std::back_inserter(mergedAccesses));
  scopeToAccesses_.emplace(v, mergedAccesses);

  // It's a little faster to merge without closing, and since no writes can
  // occur within the start and stop exprs, we do that here.
  mergeScope(extentsScope, extentsScope->parent, false);
  mergeScope(currentScope_, currentScope_->parent, true);
  currentScope_ = currentScope_->parent;
}

void MemDependencyChecker::visit(const CondPtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  auto enclosingScope =
      std::make_shared<Scope>(currentScope_->block, currentScope_);

  // condition is in enclosing scope.
  v->condition()->accept(this);

  BlockPtr true_stmt = v->true_stmt();
  BlockPtr false_stmt = v->false_stmt();

  // Create scopes so the Block visitor doesn't create and merge a new scope.
  auto trueScope = std::make_shared<Scope>(true_stmt, enclosingScope);
  auto falseScope = std::make_shared<Scope>(false_stmt, enclosingScope);

  if (true_stmt) {
    currentScope_ = trueScope;
    true_stmt->accept(this);
  }

  if (false_stmt) {
    currentScope_ = falseScope;
    false_stmt->accept(this);
  }

  // TODO(nickg): this logic isn't quite correct. If a write's Bound range is
  // present in both the true and false branches, then we can close
  // overlapping accesses in the enclosing scope. Without that analysis,
  // future accesses may be dependent on a write of a common range in all
  // three of the enclosing, true and false scopes. This is a false positive,
  // so not too bad in the short term, I think.

  // Merge both true and false branches into the parent, but don't close any
  // accesses.
  mergeScope(trueScope, enclosingScope, false);
  mergeScope(falseScope, enclosingScope, false);

  // Merge the enclosing scope into its parent.
  mergeScope(enclosingScope, enclosingScope->parent, false);

  currentScope_ = enclosingScope;
  scopeToAccesses_.emplace(v, enclosingScope->accesses_);

  currentScope_ = enclosingScope->parent;
  lastStmt_ = last;
}

void MemDependencyChecker::visit(const IfThenElsePtr& v) {
  // condition is in enclosing scope.
  v->condition()->accept(this);

  ExprPtr true_value = v->true_value();
  ExprPtr false_value = v->false_value();

  auto enclosingScope = currentScope_;

  // Create scopes to hold downstream Loads. It's safe to put nullptr for the
  // Scope's Block as it is only used by Stmts, not Exprs.
  auto trueScope = std::make_shared<Scope>(nullptr, enclosingScope);
  auto falseScope = std::make_shared<Scope>(nullptr, enclosingScope);

  if (true_value) {
    currentScope_ = trueScope;
    true_value->accept(this);
  }

  if (false_value) {
    currentScope_ = falseScope;
    false_value->accept(this);
  }

  // This doesn't have the same issue as Cond, where there could be false
  // positives from the enclosing scope, since no Exprs are writes.

  // Merge both true and false branches into the parent, but don't close any
  // accesses.
  mergeScope(trueScope, enclosingScope, false);
  mergeScope(falseScope, enclosingScope, false);

  currentScope_ = enclosingScope;
}

void MemDependencyChecker::visit(const CompareSelectPtr& v) {
  // condition is in enclosing scope.
  v->lhs()->accept(this);
  v->rhs()->accept(this);

  ExprPtr true_value = v->ret_val1();
  ExprPtr false_value = v->ret_val2();

  auto enclosingScope = currentScope_;

  // Create scopes to hold downstream Loads. It's safe to put nullptr for the
  // Scope's Block as it is only used by Stmts, not Exprs.
  auto trueScope = std::make_shared<Scope>(nullptr, enclosingScope);
  auto falseScope = std::make_shared<Scope>(nullptr, enclosingScope);

  if (true_value) {
    currentScope_ = trueScope;
    true_value->accept(this);
  }

  if (false_value) {
    currentScope_ = falseScope;
    false_value->accept(this);
  }

  // This doesn't have the same issue as Cond, where there could be false
  // positives from the enclosing scope, since no Exprs are writes.

  // Merge both true and false branches into the parent, but don't close any
  // accesses.
  mergeScope(trueScope, enclosingScope, false);
  mergeScope(falseScope, enclosingScope, false);

  currentScope_ = enclosingScope;
}

// Inserts accesses for a map of buffers (i.e. inputs and outputs).
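// Each buffer's access covers its full extent: a buffer with dims {M, N} gets
// the index bounds [(0, M - 1), (0, N - 1)].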
void MemDependencyChecker::insertBuffers(
    std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>>& bufs,
    AccessType type) {
  for (auto& pair : bufs) {
    const BufPtr& b = pair.first;
    VarPtr var = b->base_handle();
    IndexBounds bounds;
    for (const auto& d : b->dims()) {
      bounds.emplace_back(
          immLike(d, 0), IRSimplifier::simplify(alloc<Sub>(d, immLike(d, 1))));
    }
    auto info =
        std::make_shared<AccessInfo>(nextAccess_++, type, nullptr, var, bounds);

    bufs[b] = info;

    auto& history = currentScope_->openWrites_[var];
    updateWriteHistory(history, info, info->id());
    currentScope_->accesses_.push_back(info);
  }
}

void MemDependencyChecker::visit(const BlockPtr& v) {
  auto prev_scope = currentScope_;

  // handle kernel inputs.
  if (prev_scope->block == nullptr) {
    insertBuffers(inputs_, AccessType::Input);
  }

  if (currentScope_->block != v) {
    currentScope_ = std::make_shared<Scope>((BlockPtr)v, prev_scope);
  }

  for (const auto& s : *v) {
    s->accept(this);
  }

  for (const auto& v : currentScope_->localVars) {
    knownVarBounds_.erase(v);
  }
  for (auto& pair : currentScope_->shadowedVarBounds) {
    knownVarBounds_[pair.first] = pair.second;
  }

  scopeToAccesses_.emplace(v, currentScope_->accesses_);

  if (currentScope_ != prev_scope) {
    mergeScope(currentScope_, prev_scope, true);
    currentScope_ = prev_scope;
  }

  // handle kernel outputs.
  if (prev_scope->block == nullptr) {
    insertBuffers(outputs_, AccessType::Output);
  }
}

void MemDependencyChecker::visit(const LetPtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  IRVisitor::visit(v);

  lastStmt_ = last;

  VarPtr var = v->var();
  if (knownVarBounds_.count(var) != 0) {
    currentScope_->shadowedVarBounds[var] = knownVarBounds_[var];
  }

  currentScope_->localVars.insert(var);
  knownVarBounds_[var] = {v->value(), v->value()};
}

// Don't support AtomicAdd yet, it's a bit more complex since it's both a read
// and a write. It's only inserted during Cuda codegen so this should be okay.
void MemDependencyChecker::visit(const AtomicAddPtr& v) {
  throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented");
}

void MemDependencyChecker::visit(const AllocatePtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  IRVisitor::visit(v);

  VarPtr var = v->buffer_var();
  IndexBounds bounds;
  // TODO: remove the "buf_flat_size" process below and extend the buf bound
  // check to support N-d indices access and 1-d index access.
  // The "Allocate" stmt is based on "Buf", which supports both N-d and 1-d
  // index access. Currently the write bound check in memory analysis cannot
  // identify 1-d index access for N-d bufs. Thus we flatten N-d bufs here to
  // avoid failing the bound check. But this is not the correct approach and
  // should be fixed.
  ExprPtr flat_size = buf_flat_size(v->buf());
  flat_size =
      IRSimplifier::simplify(alloc<Sub>(flat_size, immLike(flat_size, 1)));
  bounds.emplace_back(immLike(flat_size, 0), flat_size);

  auto info = std::make_shared<AccessInfo>(
      nextAccess_++, AccessType::Alloc, nullptr, var, bounds);

  intermediates_[var] = info;

  auto& history = currentScope_->openWrites_[var];
  history.emplace_back(std::make_pair(info->bounds(), info));
  currentScope_->accesses_.push_back(info);

  lastStmt_ = last;
}

void MemDependencyChecker::visit(const FreePtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  IRVisitor::visit(v);

  VarPtr var = v->buffer_var();
  auto it = intermediates_.find(var);
  TORCH_INTERNAL_ASSERT(
      it != intermediates_.end(),
      buildErrorMessage(
          "Expected to find '" + var->name_hint() +
          "' in intermediate vars in mem dep checker in the fuser."));

  IndexBounds bounds = it->second->bounds();
  auto info = std::make_shared<AccessInfo>(
      nextAccess_++, AccessType::Free, nullptr, var, bounds);

  auto& history = currentScope_->openWrites_[var];
  updateWriteHistory(history, info, info->id());
  currentScope_->accesses_.push_back(info);

  lastStmt_ = last;
}
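
// Maintains the list of "open" writes to a variable: writes whose bounds may
// still be observed by later accesses. A read takes a dependency on every
// overlapping open write; a new write closes the parts of older writes it
// fully covers. For example (a hypothetical history): after writing A[0:9]
// and then A[2:5], the open list holds the slices A[0:1] and A[6:9] of the
// first write plus the whole of the second.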
void MemDependencyChecker::updateWriteHistory(
    std::list<BoundRelationship>& writeHistory,
    const std::shared_ptr<AccessInfo>& info,
    size_t latestAccessToClose,
    bool closeOverlapped,
    bool insert) {
  bool isWrite = info->isWrite();

  for (auto it = writeHistory.begin(); it != writeHistory.end();) {
    auto& indexBounds = it->first;
    std::shared_ptr<AccessInfo> other = it->second;
    if (info->hasDependency(other)) {
      ++it;
      continue;
    }

    OverlapKind overlap = overlaps(indexBounds, info->bounds());

    if (overlap == OverlapKind::NoOverlap) {
      ++it;
      continue;
    }

    // Only writes can close open accesses.
    if (!isWrite) {
      info->addDependency(other);
      other->addDependent(info);
      ++it;
      continue;
    }

    // If we're not closing accesses, or this access is too new to close,
    // skip it.
    if (!closeOverlapped || other->id() > latestAccessToClose) {
      ++it;
      continue;
    }

    if (overlap == OverlapKind::ContainedOrEqual) {
      // Total overlap is easy - the new access totally replaces the old.
      it = writeHistory.erase(it);
    } else {
      // The new write partially overlaps a previous write. We want to keep
      // both, but only track the uncovered part of the earlier write.

      // Determine the slices of the earlier bound not covered by info.
      auto newBounds =
          subtractIndicesBounds(indexBounds, info->bounds(), overlap);

      // Erase the old slice.
      it = writeHistory.erase(it);

      // Add all new slices.
      for (auto& b : newBounds) {
        writeHistory.insert(it, std::make_pair(b, other));
      }
      // No need to increment the iterator since it has been updated after
      // `erase` above.
    }
  }

  if (insert && isWrite) {
    writeHistory.emplace_back(info->bounds(), info);
  }
}
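
// Hoists a child scope's accesses and open writes into its parent. Each child
// access is first replayed against the parent's open-write history so that
// cross-scope dependencies are established; closeOverlapped controls whether
// child writes may close overlapping writes already open in the parent.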
void MemDependencyChecker::mergeScope(
    const std::shared_ptr<Scope>& child,
    const std::shared_ptr<Scope>& parent,
    bool closeOverlapped) {
  if (child->accesses_.empty()) {
    return;
  }

  // Update dependencies, but don't add new open writes yet.
  for (auto& info : child->accesses_) {
    // Intentionally using operator[], we want it to be created if it does not
    // exist.
    auto& writeHistory = parent->openWrites_[info->var()];

    size_t latestAccessToClose = child->accesses_.front()->id();
    updateWriteHistory(
        writeHistory, info, latestAccessToClose, closeOverlapped, false);
  }

  // Copy open writes up.
  for (auto& pair : child->openWrites_) {
    VarPtr var = pair.first;

    // Intentionally using operator[], we want it to be created if it does not
    // exist.
    auto& writeHistory = parent->openWrites_[var];

    for (auto& rel : pair.second) {
      writeHistory.push_back(rel);
    }
  }

  // the parent scope is responsible for holding all accesses now.
  parent->accesses_.insert(
      parent->accesses_.end(),
      std::make_move_iterator(child->accesses_.begin()),
      std::make_move_iterator(child->accesses_.end()));
}

// A visitor which applies known Bounds to symbolic expressions.
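// For a known bound x -> [0, 9], getBounds(2 * x + 1) yields [1, 19] after
// simplification. Note the substitution assumes expressions are monotonically
// increasing in each bound variable.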
class VarBoundBinder : public IRVisitor {
 public:
  VarBoundBinder(const VarBoundMap& vars) : vars_(vars) {}

  Bound getBounds(const ExprPtr& e) {
    min_ = e;
    max_ = e;
    e->accept(this);
    min_ = IRSimplifier::simplify(min_);
    max_ = IRSimplifier::simplify(max_);
    return {min_, max_};
  }

 private:
  void visit(const VarPtr& v) override {
    auto it = vars_.find(v);
    if (it == vars_.end()) {
      return;
    }

    min_ = SubstituteInClone(min_, {{v, it->second.start}});
    max_ = SubstituteInClone(max_, {{v, it->second.end}});
  }

  ExprPtr min_{nullptr};
  ExprPtr max_{nullptr};
  const VarBoundMap& vars_;
};

std::vector<Bound> MemDependencyChecker::getIndicesBounds(
    const std::vector<ExprPtr>& indices) {
  std::vector<Bound> bounds;
  bounds.reserve(indices.size());
  VarBoundBinder binder(knownVarBounds_);
  for (const auto& s : indices) {
    bounds.push_back(binder.getBounds(s));
  }
  return bounds;
}

} // namespace torch::jit::tensorexpr::analysis