1 #pragma once 2 #include <c10/core/SymBool.h> 3 #include <c10/core/SymInt.h> 4 #include <c10/macros/Export.h> 5 #include <c10/macros/Macros.h> 6 #include <c10/util/DimVector.h> 7 8 #include <atomic> 9 #include <cstdint> 10 #include <mutex> 11 #include <utility> 12 13 namespace c10 { 14 15 class C10_API SymbolicShapeMeta { 16 public: 17 // Basic metadata from which other quantities are derived 18 SymDimVector sizes_ = {0}; 19 SymDimVector strides_ = {1}; 20 SymInt storage_offset_ = 0; 21 22 bool strides_valid_ = true; // e.g. for sparse where there are no strides 23 24 SymbolicShapeMeta() = default; 25 SymbolicShapeMeta(const SymbolicShapeMeta& other); 26 SymbolicShapeMeta& operator=(const SymbolicShapeMeta& other) = delete; 27 SymbolicShapeMeta& operator=(SymbolicShapeMeta&& other) = delete; 28 refresh_numel()29 void refresh_numel() { 30 // Non-const, don't need to hold mutables_ lock 31 available_.fetch_and(~numel_avail); 32 numel_ = 1; 33 } 34 refresh_contiguous()35 void refresh_contiguous() { 36 // Non-const, don't need to hold mutables_ lock 37 available_.fetch_and(numel_avail); 38 is_contiguous_ = false; 39 is_channels_last_contiguous_ = false; 40 is_channels_last_3d_contiguous_ = false; 41 is_channels_last_ = false; 42 is_channels_last_3d_ = false; 43 is_non_overlapping_and_dense_ = false; 44 } 45 dim()46 int64_t dim() const { 47 return static_cast<int64_t>(sizes_.size()); 48 } 49 50 // Accessors for derived quantities, computed lazily on first access 51 has_numel()52 bool has_numel() const { 53 return available_.load() & numel_avail; 54 } has_is_contiguous()55 bool has_is_contiguous() const { 56 return available_.load() & is_contiguous_avail; 57 } has_is_channels_last_contiguous()58 bool has_is_channels_last_contiguous() const { 59 return available_.load() & is_channels_last_contiguous_avail; 60 } has_is_channels_last_3d_contiguous()61 bool has_is_channels_last_3d_contiguous() const { 62 return available_.load() & is_channels_last_3d_contiguous_avail; 63 } has_is_channels_last()64 bool has_is_channels_last() const { 65 return available_.load() & is_channels_last_avail; 66 } has_is_channels_last_3d()67 bool has_is_channels_last_3d() const { 68 return available_.load() & is_channels_last_3d_avail; 69 } has_is_non_overlapping_and_dense()70 bool has_is_non_overlapping_and_dense() const { 71 return available_.load() & is_non_overlapping_and_dense_avail; 72 } 73 74 // Accessors to cached derived properties 75 // DO NOT call with mutables_ lock held numel()76 const SymInt& numel() const { 77 if (C10_UNLIKELY(!has_numel())) { 78 init_numel(); 79 } 80 return numel_; 81 } 82 is_contiguous()83 const SymBool& is_contiguous() const { 84 if (C10_UNLIKELY(!has_is_contiguous())) { 85 init_is_contiguous(); 86 } 87 return is_contiguous_; 88 } 89 is_channels_last_contiguous()90 const SymBool& is_channels_last_contiguous() const { 91 if (C10_UNLIKELY(!has_is_channels_last_contiguous())) { 92 init_is_channels_last_contiguous(); 93 } 94 return is_channels_last_contiguous_; 95 } 96 is_channels_last_3d_contiguous()97 const SymBool& is_channels_last_3d_contiguous() const { 98 if (C10_UNLIKELY(!has_is_channels_last_3d_contiguous())) { 99 init_is_channels_last_3d_contiguous(); 100 } 101 return is_channels_last_3d_contiguous_; 102 } 103 is_channels_last()104 const SymBool& is_channels_last() const { 105 if (C10_UNLIKELY(!has_is_channels_last())) { 106 init_is_channels_last(); 107 } 108 return is_channels_last_; 109 } 110 is_channels_last_3d()111 const SymBool& is_channels_last_3d() const { 112 if (C10_UNLIKELY(!has_is_channels_last_3d())) { 113 init_is_channels_last_3d(); 114 } 115 return is_channels_last_3d_; 116 } 117 is_non_overlapping_and_dense()118 const SymBool& is_non_overlapping_and_dense() const { 119 if (C10_UNLIKELY(!has_is_non_overlapping_and_dense())) { 120 init_is_non_overlapping_and_dense(); 121 } 122 return is_non_overlapping_and_dense_; 123 } 124 125 // Assumptions so we can short-circuit computation 126 // NOTE: Don't need to lock mutables_ since these aren't const 127 void assume_contiguous(SymBool val = true) { 128 is_contiguous_ = std::move(val); 129 available_.fetch_or(is_contiguous_avail); 130 } 131 void assume_channels_last_contiguous(SymBool val = true) { 132 is_contiguous_ = std::move(val); 133 available_.fetch_or(is_channels_last_contiguous_avail); 134 } 135 void assume_channels_last_3d_contiguous(SymBool val = true) { 136 is_channels_last_3d_contiguous_ = std::move(val); 137 available_.fetch_or(is_channels_last_3d_contiguous_avail); 138 } 139 void assume_channels_last(SymBool val = true) { 140 is_channels_last_ = std::move(val); 141 available_.fetch_or(is_channels_last_avail); 142 } 143 void assume_channels_last_3d(SymBool val = true) { 144 is_channels_last_3d_ = std::move(val); 145 available_.fetch_or(is_channels_last_3d_avail); 146 } 147 void assume_non_overlapping_and_dense(SymBool val = true) { 148 is_non_overlapping_and_dense_ = std::move(val); 149 available_.fetch_or(is_non_overlapping_and_dense_avail); 150 } 151 152 private: 153 SymBool compute_contiguous() const; 154 SymBool compute_channels_last_contiguous_2d() const; 155 SymBool compute_channels_last_contiguous_3d() const; 156 SymBool compute_strides_like_channels_last_2d() const; 157 SymBool compute_strides_like_channels_last_3d() const; 158 SymBool compute_non_overlapping_and_dense() const; 159 160 // These are little wrappers over the real compute_ functions that 161 // can make use of other contiguity fields to short circuit. 162 // They need to be implemented separately for SymBool, as SymBool does 163 // not short circuit. 164 // TODO: should the SymBool cases avoid the short circuit? Need to reason 165 // if its correct, and reason if the simpler expressions are better for 166 // analysis (maybe not!) 167 168 SymBool compute_channels_last_contiguous_3d_dim5() const; 169 SymBool compute_channels_last_2d_dim5() const; 170 SymBool compute_channels_last_3d_dim5() const; 171 SymBool compute_is_non_overlapping_and_dense_dim4() const; 172 SymBool compute_is_non_overlapping_and_dense_dim5() const; 173 SymBool compute_is_non_overlapping_and_dense_anydim() const; 174 175 void init_numel() const; 176 void init_is_contiguous() const; 177 void init_is_channels_last_contiguous() const; 178 void init_is_channels_last_3d_contiguous() const; 179 void init_is_channels_last() const; 180 void init_is_channels_last_3d() const; 181 void init_is_non_overlapping_and_dense() const; 182 183 // NOTE: These only set if !has_foo() 184 void set_numel(SymInt val) const; 185 void set_is_contiguous(SymBool val) const; 186 void set_is_channels_last_contiguous(SymBool val) const; 187 void set_is_channels_last_3d_contiguous(SymBool val) const; 188 void set_is_channels_last(SymBool val) const; 189 void set_is_channels_last_3d(SymBool val) const; 190 void set_is_non_overlapping_and_dense(SymBool val) const; 191 192 // Lazily initialized variables, with the corresponding available_ flag 193 // indicating whether the value has been initialized 194 mutable std::atomic<int> available_{0}; 195 enum avail { 196 numel_avail = 1 << 0, 197 is_contiguous_avail = 1 << 1, 198 is_channels_last_contiguous_avail = 1 << 2, 199 is_channels_last_3d_contiguous_avail = 1 << 3, 200 is_channels_last_avail = 1 << 4, 201 is_channels_last_3d_avail = 1 << 5, 202 is_non_overlapping_and_dense_avail = 1 << 6, 203 }; 204 205 // Mutex to prevent races when initializing the variable from const accessors 206 mutable std::mutex mutables_; 207 mutable SymInt numel_ = 1; 208 mutable SymBool is_contiguous_{true}; 209 mutable SymBool is_channels_last_contiguous_{false}; 210 mutable SymBool is_channels_last_3d_contiguous_{false}; 211 mutable SymBool is_channels_last_{false}; 212 mutable SymBool is_channels_last_3d_{false}; 213 mutable SymBool is_non_overlapping_and_dense_{true}; 214 }; 215 216 } // namespace c10 217