#pragma once

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>

#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>

#define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5

namespace c10::impl {

// Packed container for TensorImpl sizes and strides.
// This design improves on the previous approach of using a pair of
// c10::SmallVector<int64_t, 5> by specializing for the operations we
// actually use and enforcing that the number of sizes is the same as
// the number of strides. The memory layout is as follows:
//
// 1 size_t for the size
// 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer
// to out-of-line array
//
// (A usage sketch appears after the class definition below.)
class C10_API SizesAndStrides {
 public:
  // TODO: different iterator types for sizes & strides to prevent
  // mixing the two accidentally.
  using sizes_iterator = int64_t*;
  using sizes_const_iterator = const int64_t*;
  using strides_iterator = int64_t*;
  using strides_const_iterator = const int64_t*;

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  SizesAndStrides() {
    size_at_unchecked(0) = 0;
    stride_at_unchecked(0) = 1;
  }

  ~SizesAndStrides() {
    if (C10_UNLIKELY(!isInline())) {
      // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
      free(outOfLineStorage_);
    }
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) {
    if (C10_LIKELY(rhs.isInline())) {
      copyDataInline(rhs);
    } else {
      allocateOutOfLineStorage(size_);
      copyDataOutline(rhs);
    }
  }

  SizesAndStrides& operator=(const SizesAndStrides& rhs) {
    if (this == &rhs) {
      return *this;
    }
    if (C10_LIKELY(rhs.isInline())) {
      if (C10_UNLIKELY(!isInline())) {
        // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
        free(outOfLineStorage_);
      }
      copyDataInline(rhs);
    } else {
      if (isInline()) {
        allocateOutOfLineStorage(rhs.size_);
      } else {
        resizeOutOfLineStorage(rhs.size_);
      }
      copyDataOutline(rhs);
    }
    size_ = rhs.size_;
    return *this;
  }

  // Move from rhs. rhs.size() == 0 afterwards.
  SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) {
    if (C10_LIKELY(isInline())) {
      memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_));
    } else {
      outOfLineStorage_ = rhs.outOfLineStorage_;
      rhs.outOfLineStorage_ = nullptr;
    }

    rhs.size_ = 0;
  }

  // Move from rhs. rhs.size() == 0 afterwards.
  SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept {
    if (this == &rhs) {
      return *this;
    }
    if (C10_LIKELY(rhs.isInline())) {
      if (C10_UNLIKELY(!isInline())) {
        // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
        free(outOfLineStorage_);
      }
      copyDataInline(rhs);
    } else {
      // rhs is out-of-line. Steal its buffer.
      if (!isInline()) {
        // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
        free(outOfLineStorage_);
      }
      outOfLineStorage_ = rhs.outOfLineStorage_;
      rhs.outOfLineStorage_ = nullptr;
    }
    size_ = rhs.size_;
    rhs.size_ = 0;

    return *this;
  }

  size_t size() const noexcept {
    return size_;
  }

  const int64_t* sizes_data() const noexcept {
    if (C10_LIKELY(isInline())) {
      return &inlineStorage_[0];
    } else {
      return &outOfLineStorage_[0];
    }
  }

  int64_t* sizes_data() noexcept {
    if (C10_LIKELY(isInline())) {
      return &inlineStorage_[0];
    } else {
      return &outOfLineStorage_[0];
    }
  }

  sizes_const_iterator sizes_begin() const noexcept {
    return sizes_data();
  }

  sizes_iterator sizes_begin() noexcept {
    return sizes_data();
  }

  sizes_const_iterator sizes_end() const noexcept {
    return sizes_begin() + size();
  }

  sizes_iterator sizes_end() noexcept {
    return sizes_begin() + size();
  }

  IntArrayRef sizes_arrayref() const noexcept {
    return IntArrayRef{sizes_data(), size()};
  }

  void set_sizes(IntArrayRef newSizes) {
    resize(newSizes.size());
    std::copy(newSizes.begin(), newSizes.end(), sizes_begin());
  }

  void set_strides(IntArrayRef strides) {
    TORCH_INTERNAL_ASSERT(strides.size() == size());
    std::copy(strides.begin(), strides.end(), strides_begin());
  }

  const int64_t* strides_data() const noexcept {
    if (C10_LIKELY(isInline())) {
      return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
    } else {
      return &outOfLineStorage_[size()];
    }
  }

  int64_t* strides_data() noexcept {
    if (C10_LIKELY(isInline())) {
      return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
    } else {
      return &outOfLineStorage_[size()];
    }
  }

  strides_const_iterator strides_begin() const noexcept {
    if (C10_LIKELY(isInline())) {
      return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
    } else {
      return &outOfLineStorage_[size()];
    }
  }

  strides_iterator strides_begin() noexcept {
    if (C10_LIKELY(isInline())) {
      return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
    } else {
      return &outOfLineStorage_[size()];
    }
  }

  strides_const_iterator strides_end() const noexcept {
    return strides_begin() + size();
  }

  strides_iterator strides_end() noexcept {
    return strides_begin() + size();
  }

  IntArrayRef strides_arrayref() const noexcept {
    return IntArrayRef{strides_data(), size()};
  }

  // Size accessors.
  int64_t size_at(size_t idx) const noexcept {
    assert(idx < size());
    return sizes_data()[idx];
  }

  int64_t& size_at(size_t idx) noexcept {
    assert(idx < size());
    return sizes_data()[idx];
  }

  int64_t size_at_unchecked(size_t idx) const noexcept {
    return sizes_data()[idx];
  }

  int64_t& size_at_unchecked(size_t idx) noexcept {
    return sizes_data()[idx];
  }
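  // Storage layout recap (informal note, derived from the accessors above):
  // with N = size(),
  //   inline:      sizes  occupy inlineStorage_[0 .. N)
  //                strides occupy inlineStorage_[MAX_INLINE .. MAX_INLINE + N)
  //   out-of-line: sizes  occupy outOfLineStorage_[0 .. N)
  //                strides occupy outOfLineStorage_[N .. 2N)
  // This is why the stride accessors offset by the constant
  // C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE when inline, but by size() when
  // out of line.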
  // Stride accessors.
  int64_t stride_at(size_t idx) const noexcept {
    assert(idx < size());
    return strides_data()[idx];
  }

  int64_t& stride_at(size_t idx) noexcept {
    assert(idx < size());
    return strides_data()[idx];
  }

  int64_t stride_at_unchecked(size_t idx) const noexcept {
    return strides_data()[idx];
  }

  int64_t& stride_at_unchecked(size_t idx) noexcept {
    return strides_data()[idx];
  }

  void resize(size_t newSize) {
    const auto oldSize = size();
    if (newSize == oldSize) {
      return;
    }
    if (C10_LIKELY(
            newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
      if (oldSize < newSize) {
        const auto bytesToZero =
            (newSize - oldSize) * sizeof(inlineStorage_[0]);
        memset(&inlineStorage_[oldSize], 0, bytesToZero);
        memset(
            &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize],
            0,
            bytesToZero);
      }
      size_ = newSize;
    } else {
      resizeSlowPath(newSize, oldSize);
    }
  }

  void resizeSlowPath(size_t newSize, size_t oldSize);

 private:
  bool isInline() const noexcept {
    return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE;
  }

  void copyDataInline(const SizesAndStrides& rhs) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline());
    memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_));
  }

  static size_t storageBytes(size_t size) noexcept {
    return size * 2 * sizeof(int64_t);
  }

  void allocateOutOfLineStorage(size_t size) {
    // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
    outOfLineStorage_ = static_cast<int64_t*>(malloc(storageBytes(size)));
    TORCH_CHECK(
        outOfLineStorage_,
        "Could not allocate memory for Tensor SizesAndStrides!");
  }

  void resizeOutOfLineStorage(size_t newSize) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline());
    outOfLineStorage_ = static_cast<int64_t*>(
        // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
        realloc(outOfLineStorage_, storageBytes(newSize)));
    TORCH_CHECK(
        outOfLineStorage_,
        "Could not allocate memory for Tensor SizesAndStrides!");
  }

  void copyDataOutline(const SizesAndStrides& rhs) noexcept {
    memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_));
  }

  size_t size_{1};
  union {
    int64_t* outOfLineStorage_;
    // NOLINTNEXTLINE(*c-array*)
    int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{};
  };
};

} // namespace c10::impl
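// Example usage (an illustrative sketch only, not part of the public API;
// SizesAndStrides is an internal type normally driven through TensorImpl):
//
//   c10::impl::SizesAndStrides sz; // default-constructed: 1-d, size 0,
//                                  // stride 1
//   sz.set_sizes({2, 3, 4});       // 3 dims <= 5, so storage stays inline;
//                                  // newly exposed strides are zero-filled
//   sz.set_strides({12, 4, 1});    // must match size(); contiguous strides
//   assert(sz.size_at(1) == 3 && sz.stride_at(1) == 4);
//   sz.resize(7);                  // > 5 dims: resizeSlowPath (defined in the
//                                  // .cpp) moves the data to heap storage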