#include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>

#include <c10/util/irange.h>

#include <fstream>
#include <iostream>
#include <utility>

namespace torch::jit::tensorexpr::analysis {

const char* AccessToString(AccessType a) {
  switch (a) {
    case AccessType::Input:
      return "Input";
    case AccessType::Output:
      return "Output";
    case AccessType::Load:
      return "Load";
    case AccessType::Store:
      return "Store";
    case AccessType::Call:
      return "Call";
    case AccessType::AtomicAdd:
      return "AtomicAdd";
    case AccessType::Alloc:
      return "Alloc";
    case AccessType::Free:
      return "Free";
    default:
      break;
  }
  return "Unknown";
}

static void getDependencyChain(
    const std::shared_ptr<AccessInfo>& info,
    DependencySet& dependencies) {
  if (!dependencies.insert(info).second) {
    return;
  }

  for (auto& dep : info->dependencies()) {
    getDependencyChain(dep.second, dependencies);
  }
}

static void getDependentsChain(
    const std::shared_ptr<AccessInfo>& info,
    DependencySet& dependents) {
  if (!dependents.insert(info).second) {
    return;
  }

  for (auto& dep : info->dependents()) {
    getDependentsChain(dep.second, dependents);
  }
}
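
// For example, if access A reads a value written by B, and B in turn read a
// value written by C, then getDependencyChain(A, deps) produces {A, B, C}:
// the full transitive closure, not just A's direct dependencies.
// getDependentsChain walks the same relation in the opposite direction.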

// AccessInfo

std::vector<ExprPtr> AccessInfo::getIndices() const {
  std::vector<ExprPtr> indices;

  if (expr_) {
    if (auto load = to<Load>(expr_)) {
      indices = load->indices();
    }
  } else {
    if (auto store = to<Store>(stmt_)) {
      indices = store->indices();
    }
  }
  return indices;
}

void AccessInfo::addDependency(const std::shared_ptr<AccessInfo>& write) {
  auto res = dependencies_.emplace(write->id(), write);
  TORCH_INTERNAL_ASSERT(
      res.second,
      buildErrorMessage("Duplicate entry in mem dep checker in the fuser."));
}

void AccessInfo::addDependent(const std::shared_ptr<AccessInfo>& read) {
  auto res = dependents_.emplace(read->id(), read);
  TORCH_INTERNAL_ASSERT(
      res.second,
      buildErrorMessage("Duplicate entry in mem dep checker in the fuser."));
}

bool AccessInfo::hasDependency(const std::shared_ptr<AccessInfo>& info) const {
  return dependencies_.count(info->id()) != 0;
}

DependencySet AccessInfo::getDirectDependencies() {
  DependencySet res;
  for (auto& depPair : dependencies_) {
    res.insert(depPair.second);
  }
  return res;
}

DependencySet AccessInfo::getIndirectDependencies() {
  DependencySet res;
  for (auto& depPair : dependencies_) {
    getDependencyChain(depPair.second, res);
  }
  return res;
}

DependencySet AccessInfo::getDirectDependents() {
  DependencySet res;
  for (auto& depPair : dependents_) {
    res.insert(depPair.second.lock());
  }
  return res;
}

DependencySet AccessInfo::getIndirectDependents() {
  DependencySet res;
  for (auto& depPair : dependents_) {
    getDependentsChain(depPair.second.lock(), res);
  }
  return res;
}

bool AccessInfo::isRead() const {
  switch (type_) {
    case AccessType::Output:
    case AccessType::Load:
    case AccessType::Call:
    case AccessType::AtomicAdd:
      return true;
    default:
      break;
  }
  return false;
}

bool AccessInfo::isWrite() const {
  switch (type_) {
    case AccessType::Input:
    case AccessType::Store:
    case AccessType::AtomicAdd:
    case AccessType::Alloc:
    case AccessType::Free:
      return true;
    default:
      break;
  }
  return false;
}
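
// Note that Input counts as a write and Output as a read: the checker models
// a kernel input as an initial write that populates the buffer, and a kernel
// output as a final read that consumes it (see insertBuffers below).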

void AccessInfo::print() const {
  std::cout << id_ << ". " << AccessToString(type_) << ": " << *var_ << "[";
  if (!bounds_.empty()) {
    for (size_t i = 0; i < bounds_.size() - 1; ++i) {
      bounds_[i].print();
      std::cout << ", ";
    }

    size_t i = bounds_.size() - 1;
    bounds_[i].print();
  }
  std::cout << "]";

  if (!dependencies_.empty()) {
    std::cout << " - depends on: ";
    for (auto& pair : dependencies_) {
      std::cout << pair.second->id() << " ";
    }
  }

  if (!dependents_.empty()) {
    std::cout << " - dependents: ";
    for (auto& pair : dependents_) {
      std::cout << pair.second.lock()->id() << " ";
    }
  }

  std::cout << "\n";
}

void AccessInfo::dumpDOT(std::ostream& os) const {
  if (type_ == AccessType::Input || type_ == AccessType::Output ||
      type_ == AccessType::Alloc) {
    os << "n" << id_ << " [\n";
    os << "label = \"" << AccessToString(type_) << "\\n " << *var_ << "[";
    if (!bounds_.empty()) {
      for (size_t i = 0; i < bounds_.size() - 1; ++i) {
        os << *IRSimplifier::simplify(
                  alloc<Add>(bounds_[i].end, immLike(bounds_[i].end, 1)))
           << ", ";
      }

      size_t i = bounds_.size() - 1;
      os << *IRSimplifier::simplify(
          alloc<Add>(bounds_[i].end, immLike(bounds_[i].end, 1)));
      os << "]\"\n ";
    }
    if (isWrite()) {
      os << "\tshape = \"invhouse\"\n";
    } else {
      os << "\tshape = \"house\"\n";
    }
  } else {
    os << "n" << id_ << " [\n";
    os << "label = \"" << AccessToString(type_) << " (#" << id_ << ")\\n";
    os << "buf : " << *var_ << "\\n";
    os << "bounds : [";
    if (!bounds_.empty()) {
      for (size_t i = 0; i < bounds_.size() - 1; ++i) {
        os << "(" << *bounds_[i].start << ", " << *bounds_[i].end << "), ";
      }

      size_t i = bounds_.size() - 1;
      os << "(" << *bounds_[i].start << ", " << *bounds_[i].end << ")]";
    }
    os << "\"\n";
    os << "\tshape = \"box\"\n";
  }
  os << "\tstyle=\"filled\"\n";
  os << "\tcolor=\"" << AccessTypeColour() << "\"\n";
  std::string edgeColour;
  if (isWrite()) {
    edgeColour = "cornflowerblue";
  } else {
    edgeColour = "goldenrod";
  }
  os << "]\n";
  for (auto& pair : dependencies_) {
    os << "n" << pair.second->id() << " -> "
       << "n" << id_ << " [color=\"" << edgeColour << "\"]\n";
  }
}

const char* AccessInfo::AccessTypeColour() const {
  switch (type_) {
    case AccessType::Input:
    case AccessType::Output:
      return "palegreen";
    case AccessType::Load:
      return "peachpuff";
    case AccessType::Store:
      return "dodgerblue";
    case AccessType::Call:
      return "violet";
    case AccessType::Alloc:
    case AccessType::Free:
      return "sandybrown";
    default:
      break;
  }
  return "white";
}

// MemDependencyChecker
//
MemDependencyChecker::MemDependencyChecker() {
  currentScope_ = std::make_shared<Scope>(nullptr, nullptr);
}

MemDependencyChecker::MemDependencyChecker(
    const std::unordered_set<BufPtr>& inputs,
    const std::unordered_set<BufPtr>& outputs) {
  for (const auto& s : inputs) {
    inputs_[s] = nullptr;
  }
  for (const auto& s : outputs) {
    outputs_[s] = nullptr;
  }

  currentScope_ = std::make_shared<Scope>(nullptr, nullptr);
}

MemDependencyChecker::MemDependencyChecker(
    const std::vector<BufHandle>& inputs,
    const std::vector<BufHandle>& outputs) {
  for (auto& s : inputs) {
    inputs_[s.node()] = nullptr;
  }
  for (auto& s : outputs) {
    outputs_[s.node()] = nullptr;
  }

  currentScope_ = std::make_shared<Scope>(nullptr, nullptr);
}

bool MemDependencyChecker::allowLoopExecutionOrderAnalysis(bool allow) {
  std::swap(allowExecutionOrderAnalysis_, allow);
  return allow;
}
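
// A minimal usage sketch (names are illustrative): construct the checker with
// the kernel's inputs and outputs, run it over the root statement, then query
// the dependence relations it recorded.
//
//   MemDependencyChecker checker(inputBufs, outputBufs);
//   checker.allowLoopExecutionOrderAnalysis(true);
//   rootStmt->accept(&checker);
//   bool dep = checker.dependsIndirectly(someStmt, someEarlierStmt);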

const std::vector<std::shared_ptr<AccessInfo>>& MemDependencyChecker::
    getHistory() const {
  return currentScope_->accesses_;
}

void MemDependencyChecker::dumpDAG(const std::string& filename) const {
  std::ofstream dotfile(filename);

  dotfile << "digraph {\n";
  for (auto& wi : getHistory()) {
    wi->dumpDOT(dotfile);
  }
  dotfile << "}\n";
  dotfile.close();
}
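
// The dump is a Graphviz digraph; e.g. a file written via dumpDAG("deps.dot")
// can be rendered with `dot -Tsvg deps.dot -o deps.svg`.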

// dependsDirectly, dependsIndirectly and friends:

DependencySet MemDependencyChecker::getAllWriteDependencies(
    const DependencySet& products) {
  DependencySet writes;

  for (auto& info : products) {
    DependencySet dependencies;
    getDependencyChain(info, dependencies);
    for (auto& other : dependencies) {
      if (other->isWrite()) {
        writes.insert(other);
      }
    }
  }

  return writes;
}

bool MemDependencyChecker::dependsDirectly(const ExprPtr& A, const StmtPtr& B) {
  return dependsDirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsDirectly(const StmtPtr& A, const StmtPtr& B) {
  return dependsDirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsDirectly(const BufPtr& O, const StmtPtr& B) {
  auto outputAccess = output(O);
  auto bWrites = getAllWritesWithin(B);

  for (auto& depPair : outputAccess->dependencies()) {
    if (bWrites.count(depPair.second) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsDirectly(const StmtPtr& A, const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  for (auto& depPair : inputAccess->dependents()) {
    if (aReads.count(depPair.second) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsDirectly(const ExprPtr& A, const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  for (auto& depPair : inputAccess->dependents()) {
    if (aReads.count(depPair.second) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsDirectly(
    const std::shared_ptr<AccessInfo>& A,
    const std::shared_ptr<AccessInfo>& B) {
  return A->hasDependency(B) && B->isWrite();
}

bool MemDependencyChecker::dependsIndirectly(
    const ExprPtr& A,
    const StmtPtr& B) {
  return dependsIndirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsIndirectly(
    const StmtPtr& A,
    const StmtPtr& B) {
  return dependsIndirectlyHelper(A, B);
}

bool MemDependencyChecker::dependsIndirectly(
    const BufPtr& O,
    const StmtPtr& B) {
  auto outputAccess = output(O);

  DependencySet dependencies;
  getDependencyChain(outputAccess, dependencies);

  auto bWrites = getAllWritesWithin(B);
  for (auto& dep : dependencies) {
    if (bWrites.count(dep) != 0) {
      return true;
    }
  }

  return false;
}

bool MemDependencyChecker::dependsIndirectly(
    const StmtPtr& A,
    const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  auto aDeps = getAllWriteDependencies(aReads);

  return aDeps.count(inputAccess) != 0;
}

bool MemDependencyChecker::dependsIndirectly(
    const ExprPtr& A,
    const BufPtr& I) {
  auto aReads = getAllReadsWithin(A);
  auto inputAccess = input(I);

  auto aDeps = getAllWriteDependencies(aReads);

  return aDeps.count(inputAccess) != 0;
}

bool MemDependencyChecker::dependsIndirectly(const BufPtr& O, const BufPtr& I) {
  auto outputAccess = output(O);
  auto inputAccess = input(I);

  return dependsIndirectly(outputAccess, inputAccess);
}

bool MemDependencyChecker::dependsIndirectly(
    const std::shared_ptr<AccessInfo>& A,
    const std::shared_ptr<AccessInfo>& B) {
  if (!B->isWrite()) {
    return false;
  }

  DependencySet dependencies;
  getDependencyChain(A, dependencies);
  if (dependencies.count(B) == 0) {
    return false;
  }

  return true;
}

std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
    const StmtPtr& A) const {
  auto bound = stmtToAccess_.equal_range(A);
  for (auto it = bound.first; it != bound.second; ++it) {
    if (it->second->expr() == nullptr) {
      return it->second;
    }
  }
  return nullptr;
}

std::shared_ptr<AccessInfo> MemDependencyChecker::accessFor(
    const ExprPtr& A) const {
  // TODO exprs can have multiple accesses... we're returning the first but that
  // isn't great. Can't do much here.
  auto bound = exprToAccess_.equal_range(A);
  if (bound.first != exprToAccess_.end()) {
    return bound.first->second;
  }

  return nullptr;
}

std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
    accessesWithin(const StmtPtr& A) const {
  auto it = scopeToAccesses_.find(A);
  if (it != scopeToAccesses_.end()) {
    return std::unordered_set<std::shared_ptr<AccessInfo>>(
        it->second.begin(), it->second.end());
  }

  std::unordered_set<std::shared_ptr<AccessInfo>> ret;
  auto bound = stmtToAccess_.equal_range(A);
  for (auto it = bound.first; it != bound.second; ++it) {
    ret.insert(it->second);
  }
  return ret;
}

std::unordered_set<std::shared_ptr<AccessInfo>> MemDependencyChecker::
    accessesWithin(const ExprPtr& A) const {
  return {accessFor(A)};
}

std::shared_ptr<AccessInfo> MemDependencyChecker::input(const BufPtr& b) const {
  auto it = inputs_.find(b);
  if (it == inputs_.end()) {
    return nullptr;
  }
  return it->second;
}

std::shared_ptr<AccessInfo> MemDependencyChecker::output(
    const BufPtr& b) const {
  auto it = outputs_.find(b);
  if (it == outputs_.end()) {
    return nullptr;
  }
  return it->second;
}

// Node visitors:

void MemDependencyChecker::visit(const StorePtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;
  v->value()->accept(this);

  for (const ExprPtr& ind : v->indices()) {
    ind->accept(this);
  }
  lastStmt_ = last;

  // Create a new AccessInfo for the store.
  VarPtr var = v->buf()->base_handle();
  auto info = std::make_shared<AccessInfo>(
      nextAccess_++, AccessType::Store, v, var, getIndicesBounds(v->indices()));

  // Add a dependency to any accesses that are within the scope of this store
  // (ie. the RHS).
  auto bound = stmtToAccess_.equal_range(v);
  for (auto it = bound.first; it != bound.second; ++it) {
    info->addDependency(it->second);
    it->second->addDependent(info);
  }

  stmtToAccess_.emplace(v, info);

  // This write is open, and will close any open writes that it totally
  // overlaps.
  auto& history = currentScope_->openWrites_[var];
  updateWriteHistory(history, info, info->id());
  currentScope_->accesses_.push_back(info);
}
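
// For example, for the store A[i] = B[i] + 1 the Load of B[i] is visited
// first and recorded against the Store statement, so the Store access created
// above picks it up as a direct dependency.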

void MemDependencyChecker::visit(const LoadPtr& v) {
  // Create a temporary scope to hold any loads that occur within the indices of
  // this load.
  auto indicesScope =
      std::make_shared<Scope>(currentScope_->block, currentScope_);
  currentScope_ = indicesScope;

  for (const ExprPtr& ind : v->indices()) {
    ind->accept(this);
  }

  // Create a new AccessInfo for the load.
  VarPtr var = v->buf()->base_handle();
  auto load = std::make_shared<AccessInfo>(
      nextAccess_++,
      AccessType::Load,
      v,
      lastStmt_,
      var,
      getIndicesBounds(v->indices()));

  // If there were loads in the indices, this load depends on them; merge them
  // in.
  if (!indicesScope->accesses_.empty()) {
    for (auto& access : indicesScope->accesses_) {
      load->addDependency(access);
      access->addDependent(load);
    }
    mergeScope(indicesScope, indicesScope->parent, false);
  }

  currentScope_ = indicesScope->parent;

  stmtToAccess_.emplace(lastStmt_, load);
  exprToAccess_.emplace(v, load);

  // This is a read, and does not close any accesses - but we need to establish
  // dependencies on accesses in the same scope.
  // Intentionally using operator[], we want it to be created if it does not
  // exist.
  auto& writeHistory = currentScope_->openWrites_[var];
  updateWriteHistory(writeHistory, load, load->id());
  currentScope_->accesses_.push_back(load);
}

// This check determines if two accesses within a loop are "safe" from loop-self
// dependence. This function does not consider overlap in bound range, but
// rather the stride of the bound relative to the loop variable. This is the
// section of the code which considers iteration order, if allowed.
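//
// For example (illustrative): accesses A[x * 2] and A[x * 2 + 1] in the same
// loop over x both have stride 2 and start offsets that differ by 1. The
// strides are equal and the start difference (1) is smaller than the common
// stride, so the two ranges interleave without ever colliding and the check
// reports them as safe.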
static bool executionSafetyCheck(
    const std::shared_ptr<AccessInfo>& info,
    const std::shared_ptr<AccessInfo>& other,
    const std::vector<ExprPtr>& aStrides,
    const std::vector<ExprPtr>& oStrides,
    bool parallelized) {
  if (aStrides.empty() || oStrides.empty()) {
    return false;
  }
  TORCH_INTERNAL_ASSERT(
      info->bounds().size() == other->bounds().size(),
      buildErrorMessage(
          "Dimension mismatch for two accesses in mem dep checker in the fuser."));
  for (size_t b = 0; b < info->bounds().size(); ++b) {
    ExprPtr aIndexStride = aStrides[b];
    ExprPtr oIndexStride = oStrides[b];
    // Can't be safe on this index if we can't determine stride.
    if (!aIndexStride->isConstant() || !oIndexStride->isConstant()) {
      continue;
    }

    ExprPtr minStride =
        IRSimplifier::simplify(alloc<Min>(aIndexStride, oIndexStride, true));
    ExprPtr maxStride =
        IRSimplifier::simplify(alloc<Max>(aIndexStride, oIndexStride, true));

    // If the first access has no stride, don't apply safety.
    if (immediateEquals(minStride, 0)) {
      continue;
    }

    ExprPtr modCheck = IRSimplifier::simplify(alloc<Mod>(maxStride, minStride));

    // If the strides can't have easily inferable distinct offsets, they're not
    // safe.
    if (!immediateEquals(modCheck, 0)) {
      continue;
    }

    // If the loop has a defined execution order (ie. sequential for) then
    // the order of execution can provide safety from overlaps.
    // Specifically if the difference in first access position for any
    // axis is the same sign as the common stride, then they will not
    // overlap.

    ExprPtr startDiff = IRSimplifier::simplify(
        alloc<Sub>(info->bounds()[b].start, other->bounds()[b].start));

    bool diffNegative = immediateIsNegative(startDiff);
    bool strideNegative = immediateIsNegative(minStride);

    // Invert the startDiff so mod works.
    if (diffNegative != strideNegative) {
      startDiff =
          IRSimplifier::simplify(alloc<Sub>(immLike(startDiff, 0), startDiff));
    }

    // If both accesses have the same stride, and the difference in start
    // element is smaller than this stride, then the entire range is distinct.
    if (exprEquals(minStride, maxStride)) {
      ExprPtr check1 = IRSimplifier::simplify(
          alloc<CompareSelect>(startDiff, minStride, kLT));
      if (check1->isConstant() && immediateEquals(check1, 1)) {
        return true;
      }
    }

    startDiff = IRSimplifier::simplify(alloc<Mod>(startDiff, minStride));

    CompareSelectOperation op = strideNegative ? kLT : kGT;

    ExprPtr check = IRSimplifier::simplify(
        alloc<CompareSelect>(startDiff, immLike(startDiff, 0), op));

    // If the start difference modulo the minimum stride is offset from that
    // stride, then the ranges have distinct offsets.
    if (check->isConstant() && immediateEquals(check, 1)) {
      return true;
    }

    // If we can consider execution order and the difference in offset is
    // opposite signed to the stride, then the read occurs in the past and we
    // can infer safety.
    if (!parallelized && diffNegative == strideNegative &&
        immediateEquals(startDiff, 0)) {
      return true;
    }
  }

  return false;
}

void MemDependencyChecker::visit(const ForPtr& v) {
  VarPtr var = v->var();

  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  v->var()->accept(this);

  // Loads inside the For's start and stop expression are special.
  // They exist in the enclosing scope, but accesses within the loop body may
  // depend on them via usage of the loop variable.
  // The way we handle this is to create a new scope so we have an easily
  // accessible list of the accesses within the extents.
  auto extentsScope =
      std::make_shared<Scope>(currentScope_->block, currentScope_);
  currentScope_ = extentsScope;

  v->start()->accept(this);
  v->stop()->accept(this);

  currentScope_ = currentScope_->parent;

  auto newScope = std::make_shared<Scope>(v->body(), currentScope_);
  currentScope_ = newScope;

  v->body()->accept(this);

  lastStmt_ = last;

  // Ok now we need to determine whether accesses in the loop depend on
  // other loop iterations.
  //
  // This is the real challenge here, it depends on both the fully expanded
  // bounds and the symbolic bounds.

  // The indices must change monotonically to avoid intersection. This is
  // hard to determine, so here's our heuristic; I hope it's conservative
  // enough.

  // The size of at least one dependent index must be >= the size of the
  // loop.

  // First step is to infer the stride relative to each dimension of each
  // access, which we do via substituting the loop var with (var+1) into the
  // indices expr.
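
  // For example, for an index expression x * 2 + 1 over loop var x, the
  // stride is ((x + 1) * 2 + 1) - (x * 2 + 1), which simplifies to 2.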

  std::vector<std::vector<ExprPtr>> loopStrides;
  loopStrides.resize(currentScope_->accesses_.size());

  for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) {
    auto& info = currentScope_->accesses_[a];

    std::vector<ExprPtr> indices = info->getIndices();

    std::vector<ExprPtr>& loopIndicesStride = loopStrides[a];
    loopIndicesStride.resize(indices.size());

    // Index expr must depend on the loop var in some way to have a stride.
    for (const auto i : c10::irange(indices.size())) {
      VarFinder vf;
      if (vf.find(indices[i]).count(var) == 0) {
        loopIndicesStride[i] = immLike(indices[i], 0);
      } else {
        // If we've previously swapped the start and end of this bound, we
        // should apply the substitution to the reverse of the bounds.
        if (info->bounds()[i].swapped) {
          info->bounds()[i].end = IRSimplifier::simplify(
              SubstituteInClone(info->bounds()[i].end, {{var, v->start()}}));
          info->bounds()[i].start = IRSimplifier::simplify(SubstituteInClone(
              info->bounds()[i].start,
              {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));

        } else {
          info->bounds()[i].start = IRSimplifier::simplify(
              SubstituteInClone(info->bounds()[i].start, {{var, v->start()}}));
          info->bounds()[i].end = IRSimplifier::simplify(SubstituteInClone(
              info->bounds()[i].end,
              {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));
        }

        ExprPtr zeroStep = indices[i];
        ExprPtr oneStep = SubstituteInClone(
            indices[i], {{var, alloc<Add>(var, immLike(var, 1))}});
        loopIndicesStride[i] =
            IRSimplifier::simplify(alloc<Sub>(oneStep, zeroStep));

        // If end < start, swap the order of the bound.
        ExprPtr diff = IRSimplifier::simplify(
            alloc<Sub>(info->bounds()[i].end, info->bounds()[i].start));
        if (diff->isConstant() && immediateIsNegative(diff)) {
          info->bounds()[i].swap();
        }

        // If this access uses the loop var, it depends on loads used to
        // compute the loop extents.
        for (auto& extentLoad : extentsScope->accesses_) {
          info->addDependency(extentLoad);
          extentLoad->addDependent(info);
        }
      }
    }
  }

  // Now we need to update the bounds in openWrites since that is what we use
  // to merge.
  for (auto& openWritePair : currentScope_->openWrites_) {
    for (auto& pair : openWritePair.second) {
      IndexBounds& bounds = pair.first;

      // The bounds may not contain the loop var, but in that case Substitute
      // does nothing.
      for (auto& bound : bounds) {
        bound.start = IRSimplifier::simplify(
            SubstituteInClone(bound.start, {{var, v->start()}}));
        bound.end = IRSimplifier::simplify(SubstituteInClone(
            bound.end, {{var, alloc<Sub>(v->stop(), immLike(v->stop(), 1))}}));

        // If end < start, swap the order of the bound.
        ExprPtr diff =
            IRSimplifier::simplify(alloc<Sub>(bound.end, bound.start));
        if (diff->isConstant() && immediateIsNegative(diff)) {
          bound.swap();
        }
      }
    }
  }

  // TODO this isn't a scalable way to determine parallelism.
  bool parallelized = v->loop_options().is_gpu_block_index() ||
      v->loop_options().is_gpu_thread_index();

  // Store buffers allocated at this scope.
  std::unordered_set<VarPtr> local_intermediates;

  // Scanning from the top of the loop, we look for accesses which may depend
  // on a previous or parallel loop iteration.
  for (size_t a = 0; a < currentScope_->accesses_.size(); ++a) {
    auto& info = currentScope_->accesses_[a];
    if (info->type() == AccessType::Alloc) {
      local_intermediates.insert(info->var());
      continue;
    }

    if (!info->isRead()) {
      continue;
    }

    // Vars that don't carry outside this scope can't have loop self dependence.
    if (local_intermediates.count(info->var())) {
      continue;
    }

    // Copy the bounds so we can keep track of open bounds internally without
    // affecting the merge into the enclosing scope. The open portion of the
    // bounds may be cut into multiple independent slices.
    std::vector<IndexBounds> openBounds({info->bounds()});

    // Scan from the bottom of the loop.
    for (size_t j = currentScope_->accesses_.size() - 1; j > a; --j) {
      std::shared_ptr<AccessInfo> other = currentScope_->accesses_[j];
      if (!other->isWrite()) {
        continue;
      }

      if (info->var() != other->var()) {
        continue;
      }

      if (info->hasDependency(other)) {
        continue;
      }

      // Whether or not the accesses within the loop are dependent on other
      // iterations depends on whether the loop could be parallelized, the
      // difference in their strides and their start offset.
      bool iterationsDistinct = executionSafetyCheck(
          info,
          other,
          loopStrides[a],
          loopStrides[j],
          !allowExecutionOrderAnalysis_ || parallelized);

      if (iterationsDistinct) {
        continue;
      }

      std::vector<IndexBounds> newBoundSlices;
      for (auto& b : openBounds) {
        OverlapKind overlap = overlaps(b, other->bounds());
        if (overlap == OverlapKind::NoOverlap) {
          newBoundSlices.push_back(b);
          continue;
        }

        // It's dependent, link it to other.
        info->addDependency(other);
        other->addDependent(info);

        if (overlap == OverlapKind::Contains) {
          continue;
        }

        // Otherwise update openBounds.
        auto slices = subtractIndicesBounds(b, other->bounds(), overlap);
        std::move(
            slices.begin(), slices.end(), std::back_inserter(newBoundSlices));
      }

      if (newBoundSlices.empty()) {
        break;
      }
      openBounds.swap(newBoundSlices);
    }
  }

  std::vector<std::shared_ptr<AccessInfo>> mergedAccesses;
  mergedAccesses.reserve(
      extentsScope->accesses_.size() + currentScope_->accesses_.size());
  std::copy(
      extentsScope->accesses_.begin(),
      extentsScope->accesses_.end(),
      std::back_inserter(mergedAccesses));
  std::copy(
      currentScope_->accesses_.begin(),
      currentScope_->accesses_.end(),
      std::back_inserter(mergedAccesses));
  scopeToAccesses_.emplace(v, mergedAccesses);

  // It's a little faster to merge without closing, and since no writes can
  // occur within the start and stop exprs we'll do that.
  mergeScope(extentsScope, extentsScope->parent, false);
  mergeScope(currentScope_, currentScope_->parent, true);
  currentScope_ = currentScope_->parent;
}

void MemDependencyChecker::visit(const CondPtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  auto enclosingScope =
      std::make_shared<Scope>(currentScope_->block, currentScope_);

  // condition is in enclosing scope.
  v->condition()->accept(this);

  BlockPtr true_stmt = v->true_stmt();
  BlockPtr false_stmt = v->false_stmt();

  // Create scopes so the Block visitor doesn't create and merge a new scope.
  auto trueScope = std::make_shared<Scope>(true_stmt, enclosingScope);
  auto falseScope = std::make_shared<Scope>(false_stmt, enclosingScope);

  if (true_stmt) {
    currentScope_ = trueScope;
    true_stmt->accept(this);
  }

  if (false_stmt) {
    currentScope_ = falseScope;
    false_stmt->accept(this);
  }

  // TODO(nickg): this logic isn't quite correct. If a write's Bound range is
  // present in both the true and false branches then we can close overlapping
  // accesses in the enclosing scope. Without that analysis future accesses
  // may be dependent on a write of a common range in all three of the
  // enclosing, true and false scopes. This is a false positive so not too bad
  // in the short term, I think.

  // Merge both true and false branches into the parent, but don't close any
  // accesses.
  mergeScope(trueScope, enclosingScope, false);
  mergeScope(falseScope, enclosingScope, false);

  // Merge the enclosing scope into its parent.
  mergeScope(enclosingScope, enclosingScope->parent, false);

  currentScope_ = enclosingScope;
  scopeToAccesses_.emplace(v, enclosingScope->accesses_);

  currentScope_ = enclosingScope->parent;
  lastStmt_ = last;
}

void MemDependencyChecker::visit(const IfThenElsePtr& v) {
  // condition is in enclosing scope.
  v->condition()->accept(this);

  ExprPtr true_value = v->true_value();
  ExprPtr false_value = v->false_value();

  auto enclosingScope = currentScope_;

  // Create scopes to hold downstream Loads. It's safe to put nullptr for the
  // Scope's Block as it is only used by Stmts, not Exprs.
  auto trueScope = std::make_shared<Scope>(nullptr, enclosingScope);
  auto falseScope = std::make_shared<Scope>(nullptr, enclosingScope);

  if (true_value) {
    currentScope_ = trueScope;
    true_value->accept(this);
  }

  if (false_value) {
    currentScope_ = falseScope;
    false_value->accept(this);
  }

  // This doesn't have the same issue as Cond where there could be false
  // positives from the enclosing scope since there are no Exprs which are
  // writes.

  // Merge both true and false branches into the parent, but don't close any
  // accesses.
  mergeScope(trueScope, enclosingScope, false);
  mergeScope(falseScope, enclosingScope, false);

  currentScope_ = enclosingScope;
}

void MemDependencyChecker::visit(const CompareSelectPtr& v) {
  // condition is in enclosing scope.
  v->lhs()->accept(this);
  v->rhs()->accept(this);

  ExprPtr true_value = v->ret_val1();
  ExprPtr false_value = v->ret_val2();

  auto enclosingScope = currentScope_;

  // Create scopes to hold downstream Loads. It's safe to put nullptr for the
  // Scope's Block as it is only used by Stmts, not Exprs.
  auto trueScope = std::make_shared<Scope>(nullptr, enclosingScope);
  auto falseScope = std::make_shared<Scope>(nullptr, enclosingScope);

  if (true_value) {
    currentScope_ = trueScope;
    true_value->accept(this);
  }

  if (false_value) {
    currentScope_ = falseScope;
    false_value->accept(this);
  }

  // This doesn't have the same issue as Cond where there could be false
  // positives from the enclosing scope since there are no Exprs which are
  // writes.

  // Merge both true and false branches into the parent, but don't close any
  // accesses.
  mergeScope(trueScope, enclosingScope, false);
  mergeScope(falseScope, enclosingScope, false);

  currentScope_ = enclosingScope;
}

// Inserts accesses for a map of buffers (ie. for inputs and outputs).
void MemDependencyChecker::insertBuffers(
    std::unordered_map<BufPtr, std::shared_ptr<AccessInfo>>& bufs,
    AccessType type) {
  for (auto& pair : bufs) {
    const BufPtr& b = pair.first;
    VarPtr var = b->base_handle();
    IndexBounds bounds;
    for (const auto& d : b->dims()) {
      bounds.emplace_back(
          immLike(d, 0), IRSimplifier::simplify(alloc<Sub>(d, immLike(d, 1))));
    }
    auto info =
        std::make_shared<AccessInfo>(nextAccess_++, type, nullptr, var, bounds);

    bufs[b] = info;

    auto& history = currentScope_->openWrites_[var];
    updateWriteHistory(history, info, info->id());
    currentScope_->accesses_.push_back(info);
  }
}
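
// For example, a buffer with dims {M, N} is recorded with index bounds
// [0, M - 1] x [0, N - 1]: one closed interval per dimension.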

void MemDependencyChecker::visit(const BlockPtr& v) {
  auto prev_scope = currentScope_;

  // Handle kernel inputs.
  if (prev_scope->block == nullptr) {
    insertBuffers(inputs_, AccessType::Input);
  }

  if (currentScope_->block != v) {
    currentScope_ = std::make_shared<Scope>((BlockPtr)v, prev_scope);
  }

  for (const auto& s : *v) {
    s->accept(this);
  }

  for (const auto& v : currentScope_->localVars) {
    knownVarBounds_.erase(v);
  }
  for (auto& pair : currentScope_->shadowedVarBounds) {
    knownVarBounds_[pair.first] = pair.second;
  }

  scopeToAccesses_.emplace(v, currentScope_->accesses_);

  if (currentScope_ != prev_scope) {
    mergeScope(currentScope_, prev_scope, true);
    currentScope_ = prev_scope;
  }

  // Handle kernel outputs.
  if (prev_scope->block == nullptr) {
    insertBuffers(outputs_, AccessType::Output);
  }
}

void MemDependencyChecker::visit(const LetPtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  IRVisitor::visit(v);

  lastStmt_ = last;

  VarPtr var = v->var();
  if (knownVarBounds_.count(var) != 0) {
    currentScope_->shadowedVarBounds[var] = knownVarBounds_[var];
  }

  currentScope_->localVars.insert(var);
  knownVarBounds_[var] = {v->value(), v->value()};
}

// We don't support AtomicAdd yet; it's a bit more complex since it's both a
// read and a write. It's only inserted during Cuda codegen so this should be
// okay.
void MemDependencyChecker::visit(const AtomicAddPtr& v) {
  throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented");
}

void MemDependencyChecker::visit(const AllocatePtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  IRVisitor::visit(v);

  VarPtr var = v->buffer_var();
  IndexBounds bounds;
  // TODO: remove the "buf_flat_size" process below and extend the buf bound
  // check to support N-d indices access and 1-d index access.
  // "Allocate" stmt is based on "Buf" which supports N-d indices access and
  // 1-d index access. Currently the write bound check in memory analysis
  // cannot identify 1-d index access for N-d bufs. Thus we flatten N-d bufs
  // here to avoid failing the bound check. But this is not the correct
  // approach and should be fixed.
  ExprPtr flat_size = buf_flat_size(v->buf());
  flat_size =
      IRSimplifier::simplify(alloc<Sub>(flat_size, immLike(flat_size, 1)));
  bounds.emplace_back(immLike(flat_size, 0), flat_size);

  auto info = std::make_shared<AccessInfo>(
      nextAccess_++, AccessType::Alloc, nullptr, var, bounds);

  intermediates_[var] = info;

  auto& history = currentScope_->openWrites_[var];
  history.emplace_back(std::make_pair(info->bounds(), info));
  currentScope_->accesses_.push_back(info);

  lastStmt_ = last;
}

void MemDependencyChecker::visit(const FreePtr& v) {
  StmtPtr last = lastStmt_;
  lastStmt_ = v;

  IRVisitor::visit(v);

  VarPtr var = v->buffer_var();
  auto it = intermediates_.find(var);
  TORCH_INTERNAL_ASSERT(
      it != intermediates_.end(),
      buildErrorMessage(
          "Expected to find '" + var->name_hint() +
          "' in intermediate vars in mem dep checker in the fuser."));

  IndexBounds bounds = it->second->bounds();
  auto info = std::make_shared<AccessInfo>(
      nextAccess_++, AccessType::Free, nullptr, var, bounds);

  auto& history = currentScope_->openWrites_[var];
  updateWriteHistory(history, info, info->id());
  currentScope_->accesses_.push_back(info);

  lastStmt_ = last;
}

void MemDependencyChecker::updateWriteHistory(
    std::list<BoundRelationship>& writeHistory,
    const std::shared_ptr<AccessInfo>& info,
    size_t latestAccessToClose,
    bool closeOverlapped,
    bool insert) {
  bool isWrite = info->isWrite();

  for (auto it = writeHistory.begin(); it != writeHistory.end();) {
    auto& indexBounds = it->first;
    std::shared_ptr<AccessInfo> other = it->second;
    if (info->hasDependency(other)) {
      ++it;
      continue;
    }

    OverlapKind overlap = overlaps(indexBounds, info->bounds());

    if (overlap == OverlapKind::NoOverlap) {
      ++it;
      continue;
    }

    // Only writes can close open accesses.
    if (!isWrite) {
      info->addDependency(other);
      other->addDependent(info);
      ++it;
      continue;
    }

    // If we're not closing accesses we can stop here.
    if (!closeOverlapped || other->id() > latestAccessToClose) {
      ++it;
      continue;
    }

    if (overlap == OverlapKind::ContainedOrEqual) {
      // Total overlap is easy - the new access totally replaces the old.
      it = writeHistory.erase(it);
    } else {
      // The new write partially overlaps a previous write. We want to keep
      // both, but only track the uncovered part of the earlier write.

      // Determine the slices of the earlier bound not covered by info.
      auto newBounds =
          subtractIndicesBounds(indexBounds, info->bounds(), overlap);

      // Erase the old slice.
      it = writeHistory.erase(it);

      // Add all new slices.
      for (auto& b : newBounds) {
        writeHistory.insert(it, std::make_pair(b, other));
      }
      // No need to increment the iterator since it has been updated after
      // `erase` above.
    }
  }

  if (insert && isWrite) {
    writeHistory.emplace_back(info->bounds(), info);
  }
}
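
// For example, if an open write covers indices [0, 9] and a later write
// covers [4, 9], the earlier entry is partially overlapped: its history entry
// is replaced by the uncovered slice [0, 3], and the new write is appended
// whole.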

void MemDependencyChecker::mergeScope(
    const std::shared_ptr<Scope>& child,
    const std::shared_ptr<Scope>& parent,
    bool closeOverlapped) {
  if (child->accesses_.empty()) {
    return;
  }

  // Update dependencies, but don't add new open writes yet.
  for (auto& info : child->accesses_) {
    // Intentionally using operator[], we want it to be created if it does not
    // exist.
    auto& writeHistory = parent->openWrites_[info->var()];

    size_t latestAccessToClose = child->accesses_.front()->id();
    updateWriteHistory(
        writeHistory, info, latestAccessToClose, closeOverlapped, false);
  }

  // Copy open writes up.
  for (auto& pair : child->openWrites_) {
    VarPtr var = pair.first;

    // Intentionally using operator[], we want it to be created if it does not
    // exist.
    auto& writeHistory = parent->openWrites_[var];

    for (auto& rel : pair.second) {
      writeHistory.push_back(rel);
    }
  }

  // The parent scope is responsible for holding all accesses now.
  parent->accesses_.insert(
      parent->accesses_.end(),
      std::make_move_iterator(child->accesses_.begin()),
      std::make_move_iterator(child->accesses_.end()));
}

// A visitor which applies known Bounds to symbolic expressions.
class VarBoundBinder : public IRVisitor {
 public:
  VarBoundBinder(const VarBoundMap& vars) : vars_(vars) {}

  Bound getBounds(const ExprPtr& e) {
    min_ = e;
    max_ = e;
    e->accept(this);
    min_ = IRSimplifier::simplify(min_);
    max_ = IRSimplifier::simplify(max_);
    return {min_, max_};
  }

 private:
  void visit(const VarPtr& v) override {
    auto it = vars_.find(v);
    if (it == vars_.end()) {
      return;
    }

    min_ = SubstituteInClone(min_, {{v, it->second.start}});
    max_ = SubstituteInClone(max_, {{v, it->second.end}});
  }

  ExprPtr min_{nullptr};
  ExprPtr max_{nullptr};
  const VarBoundMap& vars_;
};
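
// For example, with a known bound i -> [0, 9], getBounds(i + 1) substitutes
// the interval endpoints for i and simplifies, yielding the Bound [1, 10].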

std::vector<Bound> MemDependencyChecker::getIndicesBounds(
    const std::vector<ExprPtr>& indices) {
  std::vector<Bound> bounds;
  bounds.reserve(indices.size());
  VarBoundBinder binder(knownVarBounds_);
  for (const auto& s : indices) {
    bounds.push_back(binder.getBounds(s));
  }
  return bounds;
}

} // namespace torch::jit::tensorexpr::analysis