xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/bounds_overlap.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

#include <cstdint>
#include <iostream>
7 
8 namespace torch::jit::tensorexpr::analysis {
9 
10 // Returns true if the given expression is guaranteed to be positive.
mustBePositive(const ExprPtr & e)11 static bool mustBePositive(const ExprPtr& e) {
12   if (e->isConstant()) {
13     int e_val = immediateAs<int>(e);
14     return e_val > 0;
15   }
16   return false;
17 }
18 
19 // Returns true if the given expression is guaranteed to be negative.
mustBeNegative(const ExprPtr & e)20 static bool mustBeNegative(const ExprPtr& e) {
21   if (e->isConstant()) {
22     int e_val = immediateAs<int>(e);
23     return e_val < 0;
24   }
25   return false;
26 }
27 
28 // Returns true if the given expression is guaranteed to be zero.
mustBeZero(const ExprPtr & e)29 static bool mustBeZero(const ExprPtr& e) {
30   if (e->isConstant()) {
31     int e_val = immediateAs<int>(e);
32     return e_val == 0;
33   }
34   return false;
35 }
36 
print() const37 void Bound::print() const {
38   std::cout << "(" << *start << ", " << *end << ")";
39 }
40 
equals(const Bound & other) const41 bool Bound::equals(const Bound& other) const {
42   return exprEquals(start, other.start) && exprEquals(end, other.end);
43 }
44 
operator ==(const Bound & other) const45 bool Bound::operator==(const Bound& other) const {
46   if (equals(other)) {
47     auto ret_expr = IRSimplifier::simplify(alloc<Sub>(start, end));
48     return mustBeZero(ret_expr);
49   }
50 
51   return false;
52 }
53 
operator !=(const Bound & other) const54 bool Bound::operator!=(const Bound& other) const {
55   return (*this < other) || (*this > other);
56 }
57 
operator >=(const Bound & other) const58 bool Bound::operator>=(const Bound& other) const {
59   if (*this == other) {
60     return true;
61   }
62   auto ret_expr = IRSimplifier::simplify(alloc<Sub>(start, other.end));
63   return mustBePositive(ret_expr) || mustBeZero(ret_expr);
64 }
65 
operator >(const Bound & other) const66 bool Bound::operator>(const Bound& other) const {
67   auto ret_expr = IRSimplifier::simplify(alloc<Sub>(start, other.end));
68   return mustBePositive(ret_expr);
69 }
70 
operator <=(const Bound & other) const71 bool Bound::operator<=(const Bound& other) const {
72   if (*this == other) {
73     return true;
74   }
75   auto ret_expr = IRSimplifier::simplify(alloc<Sub>(end, other.start));
76   return mustBeNegative(ret_expr) || mustBeZero(ret_expr);
77 }
78 
operator <(const Bound & other) const79 bool Bound::operator<(const Bound& other) const {
80   auto ret_expr = IRSimplifier::simplify(alloc<Sub>(end, other.start));
81   return mustBeNegative(ret_expr);
82 }
83 
boundOverlap(const Bound & a,const Bound & b)84 OverlapKind boundOverlap(const Bound& a, const Bound& b) {
85   // If they're equal they're equal.
86   bool startEqual = exprEquals(a.start, b.start);
87   bool endEqual = exprEquals(a.end, b.end);
88   if (startEqual && endEqual) {
89     return OverlapKind::ContainedOrEqual;
90   }
91 
92   // We have to figure out if the bounds fall under the following 2 cases:
93   // 1. a is before b
94   //      a.start ... a.end ... b.start ... b.end
95   // 2. b is before a
96   //      b.start ... b.end ... a.start ... a.end
97   //
98   // So, we compute "a.start - b.end" and "b.start - a.end". If even one of
99   // those is positive, then it is guaranteed that the bounds do not overlap.
100   //
101   // If the diff is a constant, then we can directly check if the constant is
102   // positive. If the diff is not a constant, then it will be made of
103   // variables that correspond to the bounds of buffers involved. These buffer
104   // bounds can never be negative. So, we check if the given expression is
105   // guaranteed to be positive under the assumption that the variables involved
106   // are never negative.
107 
108   ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(a.start, b.end));
109   ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.end));
110 
111   if (mustBePositive(lowDiff)) {
112     return OverlapKind::NoOverlap;
113   }
114   if (mustBePositive(highDiff)) {
115     return OverlapKind::NoOverlap;
116   }
117 
118   ExprPtr diff_start = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
119   ExprPtr diff_end = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));
120 
121   // If one side fully encloses the other, they're adjacent.
122   if (diff_start->isConstant() && diff_end->isConstant()) {
123     int start = immediateAs<int>(diff_start);
124     int end = immediateAs<int>(diff_end);
125     // If diff_start and diff_end have different signs they are enclosing.
126     if (start <= 0 && end >= 0) {
127       return OverlapKind::ContainedOrEqual;
128     }
129 
130     if (start >= 0 && end <= 0) {
131       return OverlapKind::Contains;
132     }
133   }
134 
135   // We can't be sure there's no overlap so the conservative answer is
136   // partial.
137   return OverlapKind::PartialOverlap;
138 }
139 
compareBound(const Bound & a,const Bound & b,const CompareSelectOperation & cmp_op)140 CmpEvalResult TORCH_API compareBound(
141     const Bound& a,
142     const Bound& b,
143     const CompareSelectOperation& cmp_op) {
144   switch (cmp_op) {
145     case CompareSelectOperation::kGT:
146       return (a > b)
147           ? CmpEvalResult::True
148           : (a <= b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
149     case CompareSelectOperation::kGE:
150       return (a >= b)
151           ? CmpEvalResult::True
152           : (a < b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
153     case CompareSelectOperation::kLT:
154       return (a < b)
155           ? CmpEvalResult::True
156           : (a >= b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
157     case CompareSelectOperation::kLE:
158       return (a <= b)
159           ? CmpEvalResult::True
160           : (a > b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
161     case CompareSelectOperation::kNE:
162       return (a != b)
163           ? CmpEvalResult::True
164           : (a == b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
165     default:
166       TORCH_INTERNAL_ASSERT(cmp_op == CompareSelectOperation::kEQ)
167       return (a == b)
168           ? CmpEvalResult::True
169           : (a != b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
170   }
171 }
172 
indexBoundsEquals(const IndexBounds & A,const IndexBounds & B)173 bool indexBoundsEquals(const IndexBounds& A, const IndexBounds& B) {
174   if (A.size() != B.size()) {
175     return false;
176   }
177 
178   for (size_t i = 0; i != A.size(); ++i) {
179     if (!A[i].equals(B[i])) {
180       return false;
181     }
182   }
183   return true;
184 }
185 
flattenBounds(const IndexBounds & a)186 Bound flattenBounds(const IndexBounds& a) {
187   if (a.empty()) {
188     return Bound();
189   }
190   Bound ret = a[0];
191 
192   for (size_t i = 1; i < a.size(); ++i) {
193     ret.start = alloc<Mul>(ret.start, a[i].start);
194     ret.end = alloc<Mul>(ret.end, a[i].end);
195   }
196 
197   ret.start = IRSimplifier::simplify(ret.start);
198   ret.end = IRSimplifier::simplify(ret.end);
199   return ret;
200 }
201 
overlaps(const IndexBounds & a,const IndexBounds & b)202 OverlapKind overlaps(const IndexBounds& a, const IndexBounds& b) {
203   if (a.empty() && b.empty()) {
204     return OverlapKind::ContainedOrEqual;
205   }
206 
207   // All accesses to a buf must have the same dimensionality.
208 
209   if (a.size() != b.size()) {
210     return boundOverlap(flattenBounds(a), flattenBounds(b));
211   }
212   TORCH_INTERNAL_ASSERT(a.size() == b.size());
213 
214   OverlapKind overlap = boundOverlap(a[0], b[0]);
215   for (size_t i = 1; i < a.size(); ++i) {
216     OverlapKind bOverlap = boundOverlap(a[i], b[i]);
217     if (bOverlap == OverlapKind::NoOverlap) {
218       return OverlapKind::NoOverlap;
219     }
220 
221     if (overlap == OverlapKind::ContainedOrEqual &&
222         bOverlap == OverlapKind::Contains) {
223       overlap = OverlapKind::Contains;
224     }
225 
226     if (overlap == OverlapKind::Contains &&
227         bOverlap == OverlapKind::ContainedOrEqual) {
228       continue;
229     }
230 
231     if (bOverlap != overlap) {
232       overlap = OverlapKind::PartialOverlap;
233       break;
234     }
235   }
236 
237   return overlap;
238 }
239 
subtractBound(const Bound & a,const Bound & b)240 std::vector<Bound> subtractBound(const Bound& a, const Bound& b) {
241   OverlapKind overlap = boundOverlap(a, b);
242   if (overlap == OverlapKind::NoOverlap) {
243     return {a};
244   }
245   if (overlap == OverlapKind::ContainedOrEqual) {
246     return {};
247   }
248 
249   // The bounds must overlap.
250   std::vector<Bound> res;
251 
252   if (a.start->isConstant() != b.start->isConstant() ||
253       a.end->isConstant() != b.end->isConstant()) {
254     return {a};
255   }
256 
257   ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
258   ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));
259 
260   // If the diff has only a single var, we can try to guess sign.
261   if (!lowDiff->isConstant()) {
262     auto vars = VarFinder::find(lowDiff);
263     if (vars.size() == 1) {
264       lowDiff = IRSimplifier::simplify(alloc<Sub>(
265           SubstituteInClone(b.start, {{*vars.begin(), immLike(b.start, 1)}}),
266           SubstituteInClone(a.start, {{*vars.begin(), immLike(a.start, 1)}})));
267     }
268   }
269 
270   if (!highDiff->isConstant()) {
271     auto vars = VarFinder::find(highDiff);
272     if (vars.size() == 1) {
273       highDiff = IRSimplifier::simplify(alloc<Sub>(
274           SubstituteInClone(b.end, {{*vars.begin(), immLike(b.end, 1)}}),
275           SubstituteInClone(a.end, {{*vars.begin(), immLike(a.end, 1)}})));
276     }
277   }
278 
279   bool hasHead = lowDiff->isConstant() && immediateAs<int>(lowDiff) > 0;
280   bool hasTail = highDiff->isConstant() && immediateAs<int>(highDiff) < 0;
281 
282   bool constantExtents = lowDiff->isConstant() && highDiff->isConstant();
283 
284   if (!constantExtents) {
285     // If we can't infer the bound lengths, there's no way to create a safe
286     // subset. Just bail out.
287     return {a};
288   }
289 
290   if (hasHead) {
291     res.emplace_back(
292         a.start,
293         IRSimplifier::simplify(alloc<Sub>(b.start, immLike(b.start, 1))));
294   }
295 
296   if (hasTail) {
297     ExprPtr tailStart =
298         IRSimplifier::simplify(alloc<Add>(b.end, immLike(b.end, 1)));
299     res.emplace_back(tailStart, a.end);
300   }
301 
302   return res;
303 }
304 
subtractIndicesBounds(const IndexBounds & A,const IndexBounds & B,OverlapKind overlap)305 std::vector<IndexBounds> subtractIndicesBounds(
306     const IndexBounds& A,
307     const IndexBounds& B,
308     OverlapKind overlap) {
309   if (overlap == OverlapKind::NoOverlap) {
310     return {A};
311   }
312 
313   if (overlap == OverlapKind::ContainedOrEqual) {
314     return {};
315   }
316   // All accesses to a buf must have the same dimensionality.
317   TORCH_INTERNAL_ASSERT(A.size() == B.size(), buildErrorMessage());
318 
319   // Each dimension can be sliced into multiple bound segments.
320   std::vector<IndexBounds> boundSlices;
321   std::vector<Bound> remainingOuterBounds;
322 
323   for (size_t i = 0; i < A.size(); ++i) {
324     auto slices = subtractBound(A[i], B[i]);
325 
326     Bound remaining = A[i];
327 
328     for (const auto& slice : slices) {
329       IndexBounds newRegion;
330       newRegion.reserve(A.size());
331       TORCH_INTERNAL_ASSERT(
332           remainingOuterBounds.size() == i, buildErrorMessage());
333 
334       for (size_t j = 0; j < i; ++j) {
335         newRegion.push_back(remainingOuterBounds[j]);
336       }
337       newRegion.push_back(slice);
338       for (size_t j = i + 1; j < A.size(); ++j) {
339         newRegion.push_back(A[j]);
340       }
341 
342       boundSlices.push_back(newRegion);
343 
344       if (slice.equals(A[i])) {
345         remaining = A[i];
346       } else {
347         auto remainingSlices = subtractBound(remaining, slice);
348         // In some cases, we might end up with empty remainingSlices due to the
349         // optimization done in subtraction while handling diff expressions
350         // that have a single variable in `subtractBound()`.
351         if (!remainingSlices.empty()) {
352           TORCH_INTERNAL_ASSERT(
353               remainingSlices.size() == 1, buildErrorMessage());
354           remaining = remainingSlices[0];
355         }
356       }
357     }
358 
359     remainingOuterBounds.push_back(remaining);
360   }
361 
362   return boundSlices;
363 }
364 
365 std::vector<IndexBounds> TORCH_API
subtractIndicesBounds(const IndexBounds & A,const IndexBounds & B)366 subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B) {
367   return subtractIndicesBounds(A, B, overlaps(A, B));
368 }
369 
370 } // namespace torch::jit::tensorexpr::analysis
371