xref: /aosp_15_r20/external/pytorch/functorch/csrc/dim/arena.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #pragma once
8 #include <ATen/ATen.h>
9 #include "minpybind.h"
10 
11 #ifdef _WIN32
12 #include <intrin.h>
13 // https://stackoverflow.com/questions/355967/how-to-use-msvc-intrinsics-to-get-the-equivalent-of-this-gcc-code
__builtin_clz(unsigned int x)14 inline unsigned int __builtin_clz(unsigned int x) {
15     unsigned long r = 0;
16     _BitScanReverse(&r, x);
17     return (31 - r);
18 }
19 #endif
20 
// Rounds `num` up to the next power of two, with a floor of 8.
//
// Every value <= 8 (including zero and negative values) maps to 8; handling
// those up front also keeps the shift below strictly under 32 bits, which the
// bit trick alone does not guarantee for num <= 0.
// Values above 2^30 are not representable as an int power of two and are not
// supported.
inline int round2min8(int num) {
    if (num <= 8) {
        return 8;
    }
    // num - 1 >= 8 here, so it has a set bit and clz is well defined.
    const int leading_zeros = __builtin_clz(static_cast<unsigned int>(num - 1));
    return 1 << (32 - leading_zeros);
}
25 
26 struct Arena;
27 template<typename T>
28 struct OwnedSlice;
29 
30 template<typename T>
31 struct Slice {
SliceSlice32     Slice()
33     :  begin_(nullptr), size_(0), capacity_(0) {}
34 
35     template<typename... Args>
36     Slice(Arena& arena, Args&&... args);
37 
beginSlice38     T* begin() const {
39         return begin_;
40     }
endSlice41     T* end() const {
42         return begin_ + size_;
43     }
sizeSlice44     int size() const {
45         return size_;
46     }
capacitySlice47     int capacity() const {
48         return capacity_;
49     }
50 
51     T& back(int i=-1) {
52         return begin_[size_ + i];
53     }
54 
55     T& operator[](int i) const {
56         return begin_[i];
57     }
indexSlice58     std::optional<int> index(const T& value) {
59         for (int i : enumerate()) {
60             if (begin_[i] == value) {
61                 return i;
62             }
63         }
64         return std::nullopt;
65     }
containsSlice66     bool contains(const T& value) {
67         return index(value).has_value();
68     }
69 
70     void insert(Arena& arena, Slice where, Slice to_insert);
insertSlice71     void insert(Arena& arena, Slice where, T v) {
72         return insert(arena, where, Slice(&v, &v + 1));
73     }
insertSlice74     void insert(Arena& arena, int where, T v) {
75         return insert(arena, slice(where, where), v);
76     }
77     void append(Arena& arena, T value);
78     void extend(Arena& arena, Slice to_insert);
extendSlice79     void extend(Arena& arena, const T* begin, const T* end) {
80         return extend(arena, Slice<T>((T*)begin, (T*)end));
81     }
82 
removeSlice83     bool remove(Arena& A, T value) {
84         auto idx = index(value);
85         if (idx) {
86             insert(A, slice(*idx, *idx + 1), Slice());
87         }
88         return idx.has_value();
89     }
90 
sliceSlice91     Slice slice(int begin) {
92         return slice(begin, size_);
93     }
94 
sliceSlice95     Slice slice(int begin, int end) {
96         if (begin < 0) {
97             begin += size_;
98         }
99         if (end < 0) {
100             end += size_;
101         }
102         Slice result;
103         result.begin_ = begin_ + begin;
104         result.size_ = end - begin;
105         result.capacity_ = result.size_;
106         return result;
107     }
108 
insideSlice109     bool inside(Slice where) {
110         return begin() <= where.begin() && where.end() <= end();
111     }
112 
enumerateSlice113     irange enumerate() const {
114         return irange(size_);
115     }
116 
reversed_enumerateSlice117     irange reversed_enumerate() const {
118         return irange(size_ - 1, -1, -1);
119     }
120 
121     bool operator==(const Slice<T>& rhs) const {
122         if (size() != rhs.size()) {
123             return false;
124         }
125         return std::equal(begin(), end(), rhs.begin());
126     }
127 
SliceSlice128     Slice(T* begin, T* end)
129     : begin_(begin), size_(end - begin), capacity_(size_) {}
130 
131 protected:
_lengthSlice132     static int _length(const T& t) {
133         return 1;
134     }
_lengthSlice135     static int _length(Slice t) {
136         return t.size_;
137     }
_insertSlice138     static T* _insert(T*& dst, T t) {
139         *dst = std::move(t);
140         return ++dst;
141     }
_insertSlice142     static T* _insert(T*& dst, Slice t) {
143         std::memcpy(dst, t.begin_, sizeof(T)*t.size_);
144         dst += t.size_;
145         return dst;
146     }
147     T* begin_;
148     int size_;
149     int capacity_;
150     friend struct OwnedSlice<T>;
151 };
152 
153 template<typename T>
154 struct OwnedSlice {
155     typedef void (*deleter_t)(Slice<T>);
156     static void _no_delete(Slice<T>) {}
157     OwnedSlice()
158     : deleter_(_no_delete) {}
159     OwnedSlice(const OwnedSlice&) = delete;
160     OwnedSlice& operator=(const OwnedSlice&) = delete;
161     ~OwnedSlice() {
162         deleter_(slice_);
163         if (slice_.size_ > 8) {
164             delete [] slice_.begin_;
165         }
166     }
167     void set(Slice<T> to_own, deleter_t deleter = _no_delete) {
168         slice_.size_ = slice_.capacity_ = to_own.size();
169         slice_.begin_ = (slice_.size_ > 8) ? new T[slice_.size_] : &small_buf[0];
170         std::memcpy(slice_.begin_, to_own.begin(), slice_.size_ * sizeof(T));
171         deleter_ = deleter;
172     }
173     Slice<T> slice() const {
174         return slice_;
175     }
176 private:
177     Slice<T> slice_;
178     deleter_t deleter_;
179     T small_buf[8];
180 };
181 
182 template<typename T>
183 inline std::ostream& operator<<(std::ostream& s, const Slice<T>& v) {
184     s << "[";
185     for (int i : v.enumerate()) {
186         if (i > 0) {
187             s << ", ";
188         }
189         s << v[i];
190     }
191     s << "]";
192     return s;
193 }
194 
195 struct TensorRef {
196     TensorRef()
197     : impl_(nullptr){}
198     TensorRef(const at::Tensor& t)
199     : impl_(t.unsafeGetTensorImpl()) {}
200     const at::Tensor& operator*() const {
201         return *(at::Tensor*)this;
202     }
203     at::Tensor* operator->() const {
204         return (at::Tensor*)this;
205     }
206     operator bool() const {
207         return impl_ != nullptr;
208     }
209 private:
210     at::TensorImpl* impl_;
211 };
212 
213 constexpr int ARENA_MAX_SIZE = 4096;
214 constexpr int ALIGNMENT = 8;
215 struct Arena {
216     Arena()
217     : allocated_(0) {}
218     template<typename T>
219     T* allocate(int n) {
220         if (!n) {
221             return nullptr;
222         }
223         int to_allocate = sizeof(T)*n;
224         int to_allocate_rounded = ALIGNMENT * ((to_allocate - 1) / ALIGNMENT + 1);
225         auto prev_allocated = allocated_;
226         allocated_ += to_allocate_rounded;
227         if (C10_UNLIKELY_OR_CONST(allocated_ > ARENA_MAX_SIZE)) {
228             overflow_.emplace_back(new char[to_allocate]);
229             return (T*) &overflow_.back()[0];
230         }
231         return (T*) (buffer_ + prev_allocated);
232     }
233     TensorRef autorelease(at::Tensor s) {
234         auto ref = TensorRef(s);
235         s.unsafeReleaseTensorImpl();
236         ar_tensors_.append(*this, ref);
237         return ref;
238     }
239     mpy::handle autorelease(mpy::object obj) {
240         ar_objects_.append(*this, obj);
241         obj.release();
242         return ar_objects_.back();
243     }
244     ~Arena() {
245         for(TensorRef t: ar_tensors_) {
246             c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(t->unsafeGetTensorImpl());
247         }
248         for(mpy::handle h: ar_objects_) {
249             mpy::object::steal(h);
250         }
251     }
252 private:
253     int64_t allocated_;
254     char buffer_[ARENA_MAX_SIZE];
255     Slice<TensorRef> ar_tensors_;
256     Slice<mpy::handle> ar_objects_;
257     std::vector<std::unique_ptr<char[]>> overflow_;
258 };
259 
260 template<typename T>
261 inline void Slice<T>::insert(Arena& arena, Slice where, Slice to_insert) {
262     AT_ASSERT(inside(where));
263     Slice result = *this;
264     /// b------sb---se-----e,  0----n
265     T* body_dest = where.begin();
266     if (where.size() != to_insert.size()) {
267         int new_size = size() - where.size() + to_insert.size();
268         T* tail_dest = where.begin() + to_insert.size();
269         if (new_size >= capacity_) {
270             int new_capacity = new_size ? round2min8(new_size) : 0;
271             result.capacity_ = new_capacity;
272             result.begin_ = arena.allocate<T>(new_capacity);
273             body_dest = result.begin_ + (where.begin() - begin());
274             tail_dest = body_dest + to_insert.size();
275             //std::memcpy(result.begin_, begin_, sizeof(T)*(where.begin() - begin()));
276             std::copy(begin_, begin_ + (where.begin() - begin()), result.begin_);
277         }
278         std::memmove(tail_dest, where.end(), sizeof(T)*(end() - where.end()));
279         result.size_ = new_size;
280     }
281 
282     //std::memcpy(body_dest, to_insert.begin(), sizeof(T)*to_insert.size());
283     std::copy(to_insert.begin(), to_insert.end(), body_dest);
284     *this = result;
285 }
286 
287 template<typename T>
288 inline void Slice<T>::append(Arena& arena, T value) {
289     Slice result = *this;
290     if (size_ == capacity_) {
291         int new_size = size_ ? round2min8(size_)*2 : 8;
292         T* n = arena.allocate<T>(new_size);
293         //memcpy(n, begin_, size_*sizeof(T));
294         std::copy(begin_, begin_ + size_, n);
295         result.begin_ = n;
296         result.capacity_ = new_size;
297     }
298     result[result.size_++] = std::move(value);
299     *this = result;
300 }
301 
302 template<typename T>
303 inline void Slice<T>::extend(Arena& arena, Slice<T> rhs) {
304     Slice result = *this;
305     result.size_ = size_ + rhs.size();
306     if (result.size_ > capacity_) {
307         int new_size = round2min8(result.size_);
308         T* n = arena.allocate<T>(new_size);
309         //memcpy(n, begin_, size_*sizeof(T));
310         std::copy(begin_, begin_+size_, n);
311         result.begin_ = n;
312         result.capacity_ = new_size;
313     }
314     //memcpy(result.begin_ + size_, rhs.begin(), sizeof(T)*rhs.size());
315     std::copy(rhs.begin(), rhs.end(), result.begin_ + size_);
316     *this = result;
317 }
318 
319 template<typename T>
320 template<typename... Args>
321 Slice<T>::Slice(Arena& arena, Args&&... args) {
322     int lens[] = {_length(args)...};
323     size_ = 0;
324     for (auto i : lens) {
325         size_ += i;
326     }
327     capacity_ = size_ ? round2min8(size_) : 0;
328     begin_ = arena.allocate<T>(capacity_);
329     T* dst_ = begin_;
330     T* unused[] = {_insert(dst_, args)...};
331     (void) unused;
332 }
333