#include <torch/csrc/jit/tensorexpr/bounds_overlap.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/stmt.h>

#include <cstdint>
#include <iostream>
7
8 namespace torch::jit::tensorexpr::analysis {
9
10 // Returns true if the given expression is guaranteed to be positive.
mustBePositive(const ExprPtr & e)11 static bool mustBePositive(const ExprPtr& e) {
12 if (e->isConstant()) {
13 int e_val = immediateAs<int>(e);
14 return e_val > 0;
15 }
16 return false;
17 }
18
19 // Returns true if the given expression is guaranteed to be negative.
mustBeNegative(const ExprPtr & e)20 static bool mustBeNegative(const ExprPtr& e) {
21 if (e->isConstant()) {
22 int e_val = immediateAs<int>(e);
23 return e_val < 0;
24 }
25 return false;
26 }
27
28 // Returns true if the given expression is guaranteed to be zero.
mustBeZero(const ExprPtr & e)29 static bool mustBeZero(const ExprPtr& e) {
30 if (e->isConstant()) {
31 int e_val = immediateAs<int>(e);
32 return e_val == 0;
33 }
34 return false;
35 }
36
print() const37 void Bound::print() const {
38 std::cout << "(" << *start << ", " << *end << ")";
39 }
40
equals(const Bound & other) const41 bool Bound::equals(const Bound& other) const {
42 return exprEquals(start, other.start) && exprEquals(end, other.end);
43 }
44
operator ==(const Bound & other) const45 bool Bound::operator==(const Bound& other) const {
46 if (equals(other)) {
47 auto ret_expr = IRSimplifier::simplify(alloc<Sub>(start, end));
48 return mustBeZero(ret_expr);
49 }
50
51 return false;
52 }
53
operator !=(const Bound & other) const54 bool Bound::operator!=(const Bound& other) const {
55 return (*this < other) || (*this > other);
56 }
57
operator >=(const Bound & other) const58 bool Bound::operator>=(const Bound& other) const {
59 if (*this == other) {
60 return true;
61 }
62 auto ret_expr = IRSimplifier::simplify(alloc<Sub>(start, other.end));
63 return mustBePositive(ret_expr) || mustBeZero(ret_expr);
64 }
65
operator >(const Bound & other) const66 bool Bound::operator>(const Bound& other) const {
67 auto ret_expr = IRSimplifier::simplify(alloc<Sub>(start, other.end));
68 return mustBePositive(ret_expr);
69 }
70
operator <=(const Bound & other) const71 bool Bound::operator<=(const Bound& other) const {
72 if (*this == other) {
73 return true;
74 }
75 auto ret_expr = IRSimplifier::simplify(alloc<Sub>(end, other.start));
76 return mustBeNegative(ret_expr) || mustBeZero(ret_expr);
77 }
78
operator <(const Bound & other) const79 bool Bound::operator<(const Bound& other) const {
80 auto ret_expr = IRSimplifier::simplify(alloc<Sub>(end, other.start));
81 return mustBeNegative(ret_expr);
82 }
83
boundOverlap(const Bound & a,const Bound & b)84 OverlapKind boundOverlap(const Bound& a, const Bound& b) {
85 // If they're equal they're equal.
86 bool startEqual = exprEquals(a.start, b.start);
87 bool endEqual = exprEquals(a.end, b.end);
88 if (startEqual && endEqual) {
89 return OverlapKind::ContainedOrEqual;
90 }
91
92 // We have to figure out if the bounds fall under the following 2 cases:
93 // 1. a is before b
94 // a.start ... a.end ... b.start ... b.end
95 // 2. b is before a
96 // b.start ... b.end ... a.start ... a.end
97 //
98 // So, we compute "a.start - b.end" and "b.start - a.end". If even one of
99 // those is positive, then it is guaranteed that the bounds do not overlap.
100 //
101 // If the diff is a constant, then we can directly check if the constant is
102 // positive. If the diff is not a constant, then it will be made of
103 // variables that correspond to the bounds of buffers involved. These buffer
104 // bounds can never be negative. So, we check if the given expression is
105 // guaranteed to be positive under the assumption that the variables involved
106 // are never negative.
107
108 ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(a.start, b.end));
109 ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.end));
110
111 if (mustBePositive(lowDiff)) {
112 return OverlapKind::NoOverlap;
113 }
114 if (mustBePositive(highDiff)) {
115 return OverlapKind::NoOverlap;
116 }
117
118 ExprPtr diff_start = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
119 ExprPtr diff_end = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));
120
121 // If one side fully encloses the other, they're adjacent.
122 if (diff_start->isConstant() && diff_end->isConstant()) {
123 int start = immediateAs<int>(diff_start);
124 int end = immediateAs<int>(diff_end);
125 // If diff_start and diff_end have different signs they are enclosing.
126 if (start <= 0 && end >= 0) {
127 return OverlapKind::ContainedOrEqual;
128 }
129
130 if (start >= 0 && end <= 0) {
131 return OverlapKind::Contains;
132 }
133 }
134
135 // We can't be sure there's no overlap so the conservative answer is
136 // partial.
137 return OverlapKind::PartialOverlap;
138 }
139
compareBound(const Bound & a,const Bound & b,const CompareSelectOperation & cmp_op)140 CmpEvalResult TORCH_API compareBound(
141 const Bound& a,
142 const Bound& b,
143 const CompareSelectOperation& cmp_op) {
144 switch (cmp_op) {
145 case CompareSelectOperation::kGT:
146 return (a > b)
147 ? CmpEvalResult::True
148 : (a <= b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
149 case CompareSelectOperation::kGE:
150 return (a >= b)
151 ? CmpEvalResult::True
152 : (a < b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
153 case CompareSelectOperation::kLT:
154 return (a < b)
155 ? CmpEvalResult::True
156 : (a >= b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
157 case CompareSelectOperation::kLE:
158 return (a <= b)
159 ? CmpEvalResult::True
160 : (a > b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
161 case CompareSelectOperation::kNE:
162 return (a != b)
163 ? CmpEvalResult::True
164 : (a == b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
165 default:
166 TORCH_INTERNAL_ASSERT(cmp_op == CompareSelectOperation::kEQ)
167 return (a == b)
168 ? CmpEvalResult::True
169 : (a != b ? CmpEvalResult::False : CmpEvalResult::NotDetermined);
170 }
171 }
172
indexBoundsEquals(const IndexBounds & A,const IndexBounds & B)173 bool indexBoundsEquals(const IndexBounds& A, const IndexBounds& B) {
174 if (A.size() != B.size()) {
175 return false;
176 }
177
178 for (size_t i = 0; i != A.size(); ++i) {
179 if (!A[i].equals(B[i])) {
180 return false;
181 }
182 }
183 return true;
184 }
185
flattenBounds(const IndexBounds & a)186 Bound flattenBounds(const IndexBounds& a) {
187 if (a.empty()) {
188 return Bound();
189 }
190 Bound ret = a[0];
191
192 for (size_t i = 1; i < a.size(); ++i) {
193 ret.start = alloc<Mul>(ret.start, a[i].start);
194 ret.end = alloc<Mul>(ret.end, a[i].end);
195 }
196
197 ret.start = IRSimplifier::simplify(ret.start);
198 ret.end = IRSimplifier::simplify(ret.end);
199 return ret;
200 }
201
overlaps(const IndexBounds & a,const IndexBounds & b)202 OverlapKind overlaps(const IndexBounds& a, const IndexBounds& b) {
203 if (a.empty() && b.empty()) {
204 return OverlapKind::ContainedOrEqual;
205 }
206
207 // All accesses to a buf must have the same dimensionality.
208
209 if (a.size() != b.size()) {
210 return boundOverlap(flattenBounds(a), flattenBounds(b));
211 }
212 TORCH_INTERNAL_ASSERT(a.size() == b.size());
213
214 OverlapKind overlap = boundOverlap(a[0], b[0]);
215 for (size_t i = 1; i < a.size(); ++i) {
216 OverlapKind bOverlap = boundOverlap(a[i], b[i]);
217 if (bOverlap == OverlapKind::NoOverlap) {
218 return OverlapKind::NoOverlap;
219 }
220
221 if (overlap == OverlapKind::ContainedOrEqual &&
222 bOverlap == OverlapKind::Contains) {
223 overlap = OverlapKind::Contains;
224 }
225
226 if (overlap == OverlapKind::Contains &&
227 bOverlap == OverlapKind::ContainedOrEqual) {
228 continue;
229 }
230
231 if (bOverlap != overlap) {
232 overlap = OverlapKind::PartialOverlap;
233 break;
234 }
235 }
236
237 return overlap;
238 }
239
subtractBound(const Bound & a,const Bound & b)240 std::vector<Bound> subtractBound(const Bound& a, const Bound& b) {
241 OverlapKind overlap = boundOverlap(a, b);
242 if (overlap == OverlapKind::NoOverlap) {
243 return {a};
244 }
245 if (overlap == OverlapKind::ContainedOrEqual) {
246 return {};
247 }
248
249 // The bounds must overlap.
250 std::vector<Bound> res;
251
252 if (a.start->isConstant() != b.start->isConstant() ||
253 a.end->isConstant() != b.end->isConstant()) {
254 return {a};
255 }
256
257 ExprPtr lowDiff = IRSimplifier::simplify(alloc<Sub>(b.start, a.start));
258 ExprPtr highDiff = IRSimplifier::simplify(alloc<Sub>(b.end, a.end));
259
260 // If the diff has only a single var, we can try to guess sign.
261 if (!lowDiff->isConstant()) {
262 auto vars = VarFinder::find(lowDiff);
263 if (vars.size() == 1) {
264 lowDiff = IRSimplifier::simplify(alloc<Sub>(
265 SubstituteInClone(b.start, {{*vars.begin(), immLike(b.start, 1)}}),
266 SubstituteInClone(a.start, {{*vars.begin(), immLike(a.start, 1)}})));
267 }
268 }
269
270 if (!highDiff->isConstant()) {
271 auto vars = VarFinder::find(highDiff);
272 if (vars.size() == 1) {
273 highDiff = IRSimplifier::simplify(alloc<Sub>(
274 SubstituteInClone(b.end, {{*vars.begin(), immLike(b.end, 1)}}),
275 SubstituteInClone(a.end, {{*vars.begin(), immLike(a.end, 1)}})));
276 }
277 }
278
279 bool hasHead = lowDiff->isConstant() && immediateAs<int>(lowDiff) > 0;
280 bool hasTail = highDiff->isConstant() && immediateAs<int>(highDiff) < 0;
281
282 bool constantExtents = lowDiff->isConstant() && highDiff->isConstant();
283
284 if (!constantExtents) {
285 // If we can't infer the bound lengths, there's no way to create a safe
286 // subset. Just bail out.
287 return {a};
288 }
289
290 if (hasHead) {
291 res.emplace_back(
292 a.start,
293 IRSimplifier::simplify(alloc<Sub>(b.start, immLike(b.start, 1))));
294 }
295
296 if (hasTail) {
297 ExprPtr tailStart =
298 IRSimplifier::simplify(alloc<Add>(b.end, immLike(b.end, 1)));
299 res.emplace_back(tailStart, a.end);
300 }
301
302 return res;
303 }
304
subtractIndicesBounds(const IndexBounds & A,const IndexBounds & B,OverlapKind overlap)305 std::vector<IndexBounds> subtractIndicesBounds(
306 const IndexBounds& A,
307 const IndexBounds& B,
308 OverlapKind overlap) {
309 if (overlap == OverlapKind::NoOverlap) {
310 return {A};
311 }
312
313 if (overlap == OverlapKind::ContainedOrEqual) {
314 return {};
315 }
316 // All accesses to a buf must have the same dimensionality.
317 TORCH_INTERNAL_ASSERT(A.size() == B.size(), buildErrorMessage());
318
319 // Each dimension can be sliced into multiple bound segments.
320 std::vector<IndexBounds> boundSlices;
321 std::vector<Bound> remainingOuterBounds;
322
323 for (size_t i = 0; i < A.size(); ++i) {
324 auto slices = subtractBound(A[i], B[i]);
325
326 Bound remaining = A[i];
327
328 for (const auto& slice : slices) {
329 IndexBounds newRegion;
330 newRegion.reserve(A.size());
331 TORCH_INTERNAL_ASSERT(
332 remainingOuterBounds.size() == i, buildErrorMessage());
333
334 for (size_t j = 0; j < i; ++j) {
335 newRegion.push_back(remainingOuterBounds[j]);
336 }
337 newRegion.push_back(slice);
338 for (size_t j = i + 1; j < A.size(); ++j) {
339 newRegion.push_back(A[j]);
340 }
341
342 boundSlices.push_back(newRegion);
343
344 if (slice.equals(A[i])) {
345 remaining = A[i];
346 } else {
347 auto remainingSlices = subtractBound(remaining, slice);
348 // In some cases, we might end up with empty remainingSlices due to the
349 // optimization done in subtraction while handling diff expressions
350 // that have a single variable in `subtractBound()`.
351 if (!remainingSlices.empty()) {
352 TORCH_INTERNAL_ASSERT(
353 remainingSlices.size() == 1, buildErrorMessage());
354 remaining = remainingSlices[0];
355 }
356 }
357 }
358
359 remainingOuterBounds.push_back(remaining);
360 }
361
362 return boundSlices;
363 }
364
365 std::vector<IndexBounds> TORCH_API
subtractIndicesBounds(const IndexBounds & A,const IndexBounds & B)366 subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B) {
367 return subtractIndicesBounds(A, B, overlaps(A, B));
368 }
369
370 } // namespace torch::jit::tensorexpr::analysis
371