xref: /aosp_15_r20/external/pytorch/c10/core/impl/SizesAndStrides.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <algorithm>
4 #include <cstdint>
5 
6 #include <c10/macros/Macros.h>
7 #include <c10/util/ArrayRef.h>
8 #include <c10/util/SmallVector.h>
9 
// Maximum number of dimensions whose sizes and strides are kept in
// SizesAndStrides' inline storage; instances with more dimensions keep
// their data in a malloc'ed (out-of-line) array instead.
10 #define C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE 5
11 
12 namespace c10::impl {
13 
14 // Packed container for TensorImpl sizes and strides.
15 // This design improves on the previous approach of using a pair of
16 // c10::SmallVector<int64_t, 5> by specializing for the operations we
17 // actually use and enforcing that the number of sizes is the same as
18 // the number of strides. The memory layout is as follows:
19 //
20 // 1 size_t for the size
21 // 5 eightbytes of inline sizes and 5 eightbytes of inline strides, OR pointer
22 // to out-of-line array
23 class C10_API SizesAndStrides {
24  public:
25   // TODO: different iterator types for sizes & strides to prevent
26   // mixing the two accidentally.
27   using sizes_iterator = int64_t*;
28   using sizes_const_iterator = const int64_t*;
29   using strides_iterator = int64_t*;
30   using strides_const_iterator = const int64_t*;
31 
32   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SizesAndStrides()33   SizesAndStrides() {
34     size_at_unchecked(0) = 0;
35     stride_at_unchecked(0) = 1;
36   }
37 
~SizesAndStrides()38   ~SizesAndStrides() {
39     if (C10_UNLIKELY(!isInline())) {
40       // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
41       free(outOfLineStorage_);
42     }
43   }
44 
45   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SizesAndStrides(const SizesAndStrides & rhs)46   SizesAndStrides(const SizesAndStrides& rhs) : size_(rhs.size_) {
47     if (C10_LIKELY(rhs.isInline())) {
48       copyDataInline(rhs);
49     } else {
50       allocateOutOfLineStorage(size_);
51       copyDataOutline(rhs);
52     }
53   }
54 
55   SizesAndStrides& operator=(const SizesAndStrides& rhs) {
56     if (this == &rhs) {
57       return *this;
58     }
59     if (C10_LIKELY(rhs.isInline())) {
60       if (C10_UNLIKELY(!isInline())) {
61         // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
62         free(outOfLineStorage_);
63       }
64       copyDataInline(rhs);
65     } else {
66       if (isInline()) {
67         allocateOutOfLineStorage(rhs.size_);
68       } else {
69         resizeOutOfLineStorage(rhs.size_);
70       }
71       copyDataOutline(rhs);
72     }
73     size_ = rhs.size_;
74     return *this;
75   }
76 
77   // Move from rhs. rhs.size() == 0 afterwards.
SizesAndStrides(SizesAndStrides && rhs)78   SizesAndStrides(SizesAndStrides&& rhs) noexcept : size_(rhs.size_) {
79     if (C10_LIKELY(isInline())) {
80       memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_));
81     } else {
82       outOfLineStorage_ = rhs.outOfLineStorage_;
83       rhs.outOfLineStorage_ = nullptr;
84     }
85 
86     rhs.size_ = 0;
87   }
88 
89   // Move from rhs. rhs.size() == 0 afterwards.
90   SizesAndStrides& operator=(SizesAndStrides&& rhs) noexcept {
91     if (this == &rhs) {
92       return *this;
93     }
94     if (C10_LIKELY(rhs.isInline())) {
95       if (C10_UNLIKELY(!isInline())) {
96         // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
97         free(outOfLineStorage_);
98       }
99       copyDataInline(rhs);
100     } else {
101       // They're outline. We're going to steal their vector.
102       if (!isInline()) {
103         // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
104         free(outOfLineStorage_);
105       }
106       outOfLineStorage_ = rhs.outOfLineStorage_;
107       rhs.outOfLineStorage_ = nullptr;
108     }
109     size_ = rhs.size_;
110     rhs.size_ = 0;
111 
112     return *this;
113   }
114 
size()115   size_t size() const noexcept {
116     return size_;
117   }
118 
sizes_data()119   const int64_t* sizes_data() const noexcept {
120     if (C10_LIKELY(isInline())) {
121       return &inlineStorage_[0];
122     } else {
123       return &outOfLineStorage_[0];
124     }
125   }
126 
sizes_data()127   int64_t* sizes_data() noexcept {
128     if (C10_LIKELY(isInline())) {
129       return &inlineStorage_[0];
130     } else {
131       return &outOfLineStorage_[0];
132     }
133   }
134 
sizes_begin()135   sizes_const_iterator sizes_begin() const noexcept {
136     return sizes_data();
137   }
138 
sizes_begin()139   sizes_iterator sizes_begin() noexcept {
140     return sizes_data();
141   }
142 
sizes_end()143   sizes_const_iterator sizes_end() const noexcept {
144     return sizes_begin() + size();
145   }
146 
sizes_end()147   sizes_iterator sizes_end() noexcept {
148     return sizes_begin() + size();
149   }
150 
sizes_arrayref()151   IntArrayRef sizes_arrayref() const noexcept {
152     return IntArrayRef{sizes_data(), size()};
153   }
154 
set_sizes(IntArrayRef newSizes)155   void set_sizes(IntArrayRef newSizes) {
156     resize(newSizes.size());
157     std::copy(newSizes.begin(), newSizes.end(), sizes_begin());
158   }
159 
set_strides(IntArrayRef strides)160   void set_strides(IntArrayRef strides) {
161     TORCH_INTERNAL_ASSERT(strides.size() == size());
162     std::copy(strides.begin(), strides.end(), strides_begin());
163   }
164 
strides_data()165   const int64_t* strides_data() const noexcept {
166     if (C10_LIKELY(isInline())) {
167       return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
168     } else {
169       return &outOfLineStorage_[size()];
170     }
171   }
172 
strides_data()173   int64_t* strides_data() noexcept {
174     if (C10_LIKELY(isInline())) {
175       return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
176     } else {
177       return &outOfLineStorage_[size()];
178     }
179   }
180 
strides_begin()181   strides_const_iterator strides_begin() const noexcept {
182     if (C10_LIKELY(isInline())) {
183       return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
184     } else {
185       return &outOfLineStorage_[size()];
186     }
187   }
188 
strides_begin()189   strides_iterator strides_begin() noexcept {
190     if (C10_LIKELY(isInline())) {
191       return &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE];
192     } else {
193       return &outOfLineStorage_[size()];
194     }
195   }
196 
strides_end()197   strides_const_iterator strides_end() const noexcept {
198     return strides_begin() + size();
199   }
200 
strides_end()201   strides_iterator strides_end() noexcept {
202     return strides_begin() + size();
203   }
204 
strides_arrayref()205   IntArrayRef strides_arrayref() const noexcept {
206     return IntArrayRef{strides_data(), size()};
207   }
208 
209   // Size accessors.
size_at(size_t idx)210   int64_t size_at(size_t idx) const noexcept {
211     assert(idx < size());
212     return sizes_data()[idx];
213   }
214 
size_at(size_t idx)215   int64_t& size_at(size_t idx) noexcept {
216     assert(idx < size());
217     return sizes_data()[idx];
218   }
219 
size_at_unchecked(size_t idx)220   int64_t size_at_unchecked(size_t idx) const noexcept {
221     return sizes_data()[idx];
222   }
223 
size_at_unchecked(size_t idx)224   int64_t& size_at_unchecked(size_t idx) noexcept {
225     return sizes_data()[idx];
226   }
227 
228   // Size accessors.
stride_at(size_t idx)229   int64_t stride_at(size_t idx) const noexcept {
230     assert(idx < size());
231     return strides_data()[idx];
232   }
233 
stride_at(size_t idx)234   int64_t& stride_at(size_t idx) noexcept {
235     assert(idx < size());
236     return strides_data()[idx];
237   }
238 
stride_at_unchecked(size_t idx)239   int64_t stride_at_unchecked(size_t idx) const noexcept {
240     return strides_data()[idx];
241   }
242 
stride_at_unchecked(size_t idx)243   int64_t& stride_at_unchecked(size_t idx) noexcept {
244     return strides_data()[idx];
245   }
246 
resize(size_t newSize)247   void resize(size_t newSize) {
248     const auto oldSize = size();
249     if (newSize == oldSize) {
250       return;
251     }
252     if (C10_LIKELY(
253             newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE && isInline())) {
254       if (oldSize < newSize) {
255         const auto bytesToZero =
256             (newSize - oldSize) * sizeof(inlineStorage_[0]);
257         memset(&inlineStorage_[oldSize], 0, bytesToZero);
258         memset(
259             &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE + oldSize],
260             0,
261             bytesToZero);
262       }
263       size_ = newSize;
264     } else {
265       resizeSlowPath(newSize, oldSize);
266     }
267   }
268 
269   void resizeSlowPath(size_t newSize, size_t oldSize);
270 
271  private:
isInline()272   bool isInline() const noexcept {
273     return size_ <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE;
274   }
275 
copyDataInline(const SizesAndStrides & rhs)276   void copyDataInline(const SizesAndStrides& rhs) {
277     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rhs.isInline());
278     memcpy(inlineStorage_, rhs.inlineStorage_, sizeof(inlineStorage_));
279   }
280 
storageBytes(size_t size)281   static size_t storageBytes(size_t size) noexcept {
282     return size * 2 * sizeof(int64_t);
283   }
284 
allocateOutOfLineStorage(size_t size)285   void allocateOutOfLineStorage(size_t size) {
286     // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
287     outOfLineStorage_ = static_cast<int64_t*>(malloc(storageBytes(size)));
288     TORCH_CHECK(
289         outOfLineStorage_,
290         "Could not allocate memory for Tensor SizesAndStrides!");
291   }
292 
resizeOutOfLineStorage(size_t newSize)293   void resizeOutOfLineStorage(size_t newSize) {
294     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!isInline());
295     outOfLineStorage_ = static_cast<int64_t*>(
296         // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
297         realloc(outOfLineStorage_, storageBytes(newSize)));
298     TORCH_CHECK(
299         outOfLineStorage_,
300         "Could not allocate memory for Tensor SizesAndStrides!");
301   }
302 
copyDataOutline(const SizesAndStrides & rhs)303   void copyDataOutline(const SizesAndStrides& rhs) noexcept {
304     memcpy(outOfLineStorage_, rhs.outOfLineStorage_, storageBytes(rhs.size_));
305   }
306 
307   size_t size_{1};
308   union {
309     int64_t* outOfLineStorage_;
310     // NOLINTNEXTLINE(*c-array*)
311     int64_t inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * 2]{};
312   };
313 };
314 
315 } // namespace c10::impl
316