xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/bounds_overlap.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/expr.h>
4 #include <torch/csrc/jit/tensorexpr/ir.h>
5 
6 #include <utility>
7 #include <vector>
8 
9 namespace torch {
10 namespace jit {
11 namespace tensorexpr {
12 namespace analysis {
13 
14 // A simple class containing the start and end of a range in a single dimension.
15 struct TORCH_API Bound {
16   ExprPtr start{nullptr};
17   ExprPtr end{nullptr};
18 
19   // This stores whether or not the start and end of this Bound have previously
20   // been swapped. This occurs when the bound is in a loop with a negative
21   // stride.
22   bool swapped{false};
23 
24   Bound() = default;
BoundBound25   Bound(ExprPtr s, ExprPtr e) : start(std::move(s)), end(std::move(e)) {}
26 
27   void print() const;
28   bool equals(const Bound& other) const;
29 
30   // The comparison operators are conservative. If the compare operator returns
31   // true, it means that all the elements satisfy the logical expression. But
32   // the false does not mean the opposite comparison is satisfied. It could be
33   // but not always.
34   bool operator==(const Bound& other) const;
35   bool operator!=(const Bound& other) const;
36   bool operator<(const Bound& other) const;
37   bool operator<=(const Bound& other) const;
38   bool operator>(const Bound& other) const;
39   bool operator>=(const Bound& other) const;
40 
swapBound41   void swap() {
42     std::swap(start, end);
43     swapped = !swapped;
44   }
45 };
46 
47 struct BoundHash {
operatorBoundHash48   size_t operator()(const Bound& b) const {
49     return std::hash<ExprPtr>()(b.start) ^ std::hash<ExprPtr>()(b.end);
50   }
51 };
52 
// The type of overlap found between a Bound A and a Bound B. Each kind applies
// only when none of the kinds listed before it does.
enum class OverlapKind {
  // All elements in the Bound A are in the Bound B (this includes the case
  // where the bounds are equal).
  ContainedOrEqual,
  // All elements in the Bound B are in the Bound A.
  Contains,
  // Any elements in the Bound B are in the Bound A.
  PartialOverlap,
  // No elements in the Bound A are in the Bound B.
  NoOverlap,
};
66 
// The result of comparing Bounds with a comparison operator.
enum class CmpEvalResult {
  // Every Bound element always satisfies the given comparison operator.
  True,
  // Every Bound element always does NOT satisfy the given comparison
  // operator.
  False,
  // Some elements satisfy the given comparison operator and some do not.
  NotDetermined,
};
74 
// Returns the kind of overlap between Bound A and Bound B in a single
// dimension, expressed as an OverlapKind.
OverlapKind TORCH_API boundOverlap(const Bound& A, const Bound& B);
78 
// Compares Bound a with Bound b under the comparison operation cmp_op and
// reports the outcome as a CmpEvalResult.
// The comparison is conservative and the compare result is deterministic:
// every element of the Bound to be compared needs to satisfy the given
// comparison operator for the comparison to count as satisfied.
CmpEvalResult TORCH_API compareBound(
    const Bound& a,
    const Bound& b,
    const CompareSelectOperation& cmp_op);
86 
// A multi-dimensional bound representing the bound of a set of indices,
// stored as a vector of single-dimension Bounds.
using IndexBounds = std::vector<Bound>;
89 
// Returns true if the two IndexBounds A and B are equivalent.
bool TORCH_API indexBoundsEquals(const IndexBounds& A, const IndexBounds& B);
92 
// Flattens the multi-dimensional bound "a" into a single-dimension Bound.
// Precondition: the IndexBounds "a" *must* encapsulate the entire range of the
// buffer.
Bound TORCH_API flattenBounds(const IndexBounds& a);
96 
// Determines the kind of overlap between the multi-dimensional bounds a and b
// and returns it as an OverlapKind.
OverlapKind TORCH_API overlaps(const IndexBounds& a, const IndexBounds& b);
99 
// Returns the Bound slices created by subtracting bound B from bound A.
// Multiple Bounds can be returned in the case where B slices A into two
// distinct regions with no overlap.
//
// For example:
//    subtractBound((0, 10), (2, 4)) => [(0, 1), (5, 10)]
//       bound A: (0, 10)
//       bound B: (2, 4)
//       If we remove slice (2, 4) from the slice (0, 10), we will be left
//       with 2 slices, one at the start (0, 1), and one at the end (5, 10).
//       So, the result of this subtraction is [(0, 1), (5, 10)].
//
// Note: this returns std::vector<Bound> rather than IndexBounds because the
// Bounds returned do not represent multiple different dimensions.
std::vector<Bound> TORCH_API subtractBound(const Bound& a, const Bound& b);
115 
// Returns the bound slices created by subtracting the IndexBounds B from A.
// Each returned element is itself an IndexBounds.
// NOTE(review): `overlap` is presumably the already-computed OverlapKind of A
// versus B — confirm against the implementation.
std::vector<IndexBounds> TORCH_API subtractIndicesBounds(
    const IndexBounds& A,
    const IndexBounds& B,
    OverlapKind overlap);
// Overload of the above that takes only the two IndexBounds.
std::vector<IndexBounds> TORCH_API
subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B);
123 
124 } // namespace analysis
125 } // namespace tensorexpr
126 } // namespace jit
127 } // namespace torch
128