1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/expr.h> 4 #include <torch/csrc/jit/tensorexpr/ir.h> 5 6 #include <utility> 7 #include <vector> 8 9 namespace torch { 10 namespace jit { 11 namespace tensorexpr { 12 namespace analysis { 13 14 // A simple class containing the start and end of a range in a single dimension. 15 struct TORCH_API Bound { 16 ExprPtr start{nullptr}; 17 ExprPtr end{nullptr}; 18 19 // This stores whether or not the start and end of this Bound have previously 20 // been swapped. This occurs when the bound is in a loop with a negative 21 // stride. 22 bool swapped{false}; 23 24 Bound() = default; BoundBound25 Bound(ExprPtr s, ExprPtr e) : start(std::move(s)), end(std::move(e)) {} 26 27 void print() const; 28 bool equals(const Bound& other) const; 29 30 // The comparison operators are conservative. If the compare operator returns 31 // true, it means that all the elements satisfy the logical expression. But 32 // the false does not mean the opposite comparison is satisfied. It could be 33 // but not always. 34 bool operator==(const Bound& other) const; 35 bool operator!=(const Bound& other) const; 36 bool operator<(const Bound& other) const; 37 bool operator<=(const Bound& other) const; 38 bool operator>(const Bound& other) const; 39 bool operator>=(const Bound& other) const; 40 swapBound41 void swap() { 42 std::swap(start, end); 43 swapped = !swapped; 44 } 45 }; 46 47 struct BoundHash { operatorBoundHash48 size_t operator()(const Bound& b) const { 49 return std::hash<ExprPtr>()(b.start) ^ std::hash<ExprPtr>()(b.end); 50 } 51 }; 52 53 // The type of overlap found. Each condition is true only if none of the 54 // previous conditions hold. 55 // ContainedOrEqual: All elements in the Bound A are in the Bound B (this 56 // includes the case where the bounds are equal). 57 // Contains: All elements in the Bound B are in the Bound B. 58 // PartialOverlap: Any elements in the Bound B are in the Bound A. 59 // NoOverlap: No elements in the Bound A are in the bound B. 60 enum class OverlapKind { 61 ContainedOrEqual, 62 Contains, 63 PartialOverlap, 64 NoOverlap 65 }; 66 67 // The Bound comparison result. 68 // True: Every Bound element always satisfies the given comparison operator 69 // False: Every Bound element always does NOT satisfy the given comparison 70 // operator 71 // NotDetermined: Some elements satisfy the given comparison operator and 72 // some elements not 73 enum class CmpEvalResult { True, False, NotDetermined }; 74 75 // Returns the kind of overlap between Bound A and Bound A in a single 76 // dimension. 77 OverlapKind TORCH_API boundOverlap(const Bound& A, const Bound& B); 78 79 // The comparison is conservative and the compare result is deterministic. 80 // It means that every element of the Bound to be compared needs to satisfy 81 // the given comparison operator. 82 CmpEvalResult TORCH_API compareBound( 83 const Bound& a, 84 const Bound& b, 85 const CompareSelectOperation& cmp_op); 86 87 // A multi dimensional bound representing the bound of a set of indices. 88 using IndexBounds = std::vector<Bound>; 89 90 // Returns true if two IndexBounds are equivalent. 91 bool TORCH_API indexBoundsEquals(const IndexBounds& A, const IndexBounds& B); 92 93 // Flattens a multi dimensional bound to a single dimension. The IndexBounds "a" 94 // *must* encapsulate the entire range of the buffer. 95 Bound TORCH_API flattenBounds(const IndexBounds& a); 96 97 // Determines the kind of overlap in X dimensions. 98 OverlapKind TORCH_API overlaps(const IndexBounds& a, const IndexBounds& b); 99 100 // Returns the Bound slices created by subtracing bound B from bound A. 101 // Multiple Bounds can be returned in the case where B slices A into two 102 // distinct regions with no overlap. 103 // 104 // For example: 105 // subtractBound((0, 10), (2, 4)) => [(0, 1), (5, 10)] 106 // bound A: (0, 10) 107 // bound B: (2, 4) 108 // If we remove slice (2, 4) from the slice (0, 10), we will be left 109 // with 2 slices, one at the start (0, 1), and one at the end (5, 10). 110 // So, the result of this subtraction is [(0, 1), (5, 10)]. 111 // 112 // Note: this doesn't use IndexBounds because the Bounds returned do not 113 // represent multiple different dimensions. 114 std::vector<Bound> TORCH_API subtractBound(const Bound& a, const Bound& b); 115 116 // Returns the bound slices created by subtracting the IndexBounds B from A. 117 std::vector<IndexBounds> TORCH_API subtractIndicesBounds( 118 const IndexBounds& A, 119 const IndexBounds& B, 120 OverlapKind overlap); 121 std::vector<IndexBounds> TORCH_API 122 subtractIndicesBounds(const IndexBounds& A, const IndexBounds& B); 123 124 } // namespace analysis 125 } // namespace tensorexpr 126 } // namespace jit 127 } // namespace torch 128