#pragma once

#include <torch/csrc/inductor/aoti_runtime/utils.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

namespace torch::aot_inductor {

// Can't use c10::ArrayRef because it's not truly header-only and
// pulls in other c10 headers. This is (sadly) copy-pasted and
// adapted.
template <typename T>
class MiniArrayRef final {
 public:
  using iterator = T*;
  using const_iterator = const T*;
  using size_type = size_t;
  using value_type = T;

  using reverse_iterator = std::reverse_iterator<iterator>;

 private:
  /// The start of the array, in an external buffer.
  T* Data;

  /// The number of elements.
  size_type Length;

 public:
  /// @name Constructors
  /// @{

  /// Construct an empty MiniArrayRef.
  /* implicit */ constexpr MiniArrayRef() : Data(nullptr), Length(0) {}

  /// Construct a MiniArrayRef from a single element.
  // TODO Make this explicit
  constexpr MiniArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}

  /// Construct a MiniArrayRef from a pointer and length.
  constexpr MiniArrayRef(T* data, size_t length)
      : Data(data), Length(length) {}

  /// Construct a MiniArrayRef from a range.
  constexpr MiniArrayRef(T* begin, T* end)
      : Data(begin), Length(end - begin) {}

  template <
      typename Container,
      typename = std::enable_if_t<std::is_same_v<
          std::remove_const_t<decltype(std::declval<Container>().data())>,
          T*>>>
  /* implicit */ MiniArrayRef(Container& container)
      : Data(container.data()), Length(container.size()) {}

  /// Construct a MiniArrayRef from a std::vector.
  // The enable_if stuff here makes sure that this isn't used for
  // std::vector<bool>, because MiniArrayRef can't work on a std::vector<bool>
  // bitfield.
  template <typename A>
  /* implicit */ MiniArrayRef(const std::vector<T, A>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    static_assert(
        !std::is_same<T, bool>::value,
        "MiniArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
  }

  /// Construct a MiniArrayRef from a std::array.
  template <size_t N>
  /* implicit */ constexpr MiniArrayRef(std::array<T, N>& Arr)
      : Data(Arr.data()), Length(N) {}

  /// Construct a MiniArrayRef from a C array.
  template <size_t N>
  // NOLINTNEXTLINE(*c-array*)
  /* implicit */ constexpr MiniArrayRef(T (&Arr)[N]) : Data(Arr), Length(N) {}

  /// Construct a MiniArrayRef from an empty C array.
  /* implicit */ constexpr MiniArrayRef(const volatile void* Arr)
      : Data(nullptr), Length(0) {}

  /// Construct a MiniArrayRef from a std::initializer_list.
  /* implicit */ constexpr MiniArrayRef(const std::initializer_list<T>& Vec)
      : Data(
            std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
                                             : std::begin(Vec)),
        Length(Vec.size()) {}

  /// @}
  /// @name Simple Operations
  /// @{

  constexpr iterator begin() const {
    return Data;
  }
  constexpr iterator end() const {
    return Data + Length;
  }

  // These are actually the same as iterator, since MiniArrayRef only
  // gives you const iterators.
  constexpr const_iterator cbegin() const {
    return Data;
  }
  constexpr const_iterator cend() const {
    return Data + Length;
  }

  constexpr reverse_iterator rbegin() const {
    return reverse_iterator(end());
  }
  constexpr reverse_iterator rend() const {
    return reverse_iterator(begin());
  }

  /// empty - Check if the array is empty.
  constexpr bool empty() const {
    return Length == 0;
  }

  constexpr T* data() const {
    return Data;
  }

  /// size - Get the array size.
  constexpr size_t size() const {
    return Length;
  }

  /// equals - Check for element-wise equality.
  constexpr bool equals(MiniArrayRef RHS) const {
    return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
  }

  /// @}
  /// @name Operator Overloads
  /// @{
  constexpr const T& operator[](size_t Index) const {
    return Data[Index];
  }

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, MiniArrayRef<T>>& operator=(
      U&& Temporary) = delete;

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, MiniArrayRef<T>>& operator=(
      std::initializer_list<U>) = delete;
};

using MiniIntArrayRef = MiniArrayRef<int64_t>;

static_assert(
    sizeof(MiniIntArrayRef) == sizeof(void*) + sizeof(size_t),
    "changing the size of MiniArrayRef breaks ABI compatibility!");

inline bool is_contiguous_strides_for_shape(
    int64_t ndim,
    const int64_t* strides_ptr,
    const int64_t* sizes_ptr) {
  int64_t z = 1;
  for (int64_t d = ndim - 1; d >= 0; d--) {
    const auto& size_d = sizes_ptr[d];
    if (size_d != 1) {
      if (strides_ptr[d] == z) {
        z *= size_d;
      } else {
        return false;
      }
    }
  }
  return true;
}
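
// For example (illustrative): a shape of {2, 3, 4} is contiguous exactly when
// its strides are {12, 4, 1}; dimensions of size 1 may carry any stride.
//
//   const int64_t sizes[] = {2, 3, 4};
//   const int64_t good_strides[] = {12, 4, 1};
//   const int64_t bad_strides[] = {12, 1, 4};
//   assert(is_contiguous_strides_for_shape(3, good_strides, sizes));
//   assert(!is_contiguous_strides_for_shape(3, bad_strides, sizes));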

// Shim for AOTI generated code to pretend a raw array works like an
// AtenTensorHandle.
template <typename T>
class ArrayRefTensor {
 public:
  ArrayRefTensor() = default;

  explicit ArrayRefTensor(
      MiniArrayRef<T> arr,
      MiniArrayRef<const int64_t> sizes,
      MiniArrayRef<const int64_t> strides,
      int32_t device_type,
      int32_t device_idx)
      : arrayRef_(arr),
        sizes_(sizes),
        strides_(strides),
        device_type_(device_type),
        device_idx_(device_idx) {
    assert(sizes.size() == strides.size());
    assert(is_contiguous_strides_for_shape(
        sizes.size(), strides.data(), sizes.data()));
  }

  AtenTensorHandle expensiveCopyToTensor() const {
    AtenTensorHandle result = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
        sizes_.size(),
        sizes_.data(),
        strides_.data(),
        aoti_torch_dtype<std::remove_const_t<T>>(),
        device_type_,
        device_idx_,
        &result));
    void* dataPtr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(result, &dataPtr));
    std::memcpy(dataPtr, data(), numel() * sizeof(T));
    return result;
  }

  // We need to look the same as RAIIAtenTensorHandle, which returns
  // an owning AtenTensorHandle from release(). So, we allocate one!
  AtenTensorHandle release() {
    return expensiveCopyToTensor();
  }

  // We don't need to free any memory.
  void reset() {}

  auto sizes() const {
    return sizes_;
  }

  auto strides() const {
    return strides_;
  }

  auto device_type() const {
    return device_type_;
  }

  auto device_idx() const {
    return device_idx_;
  }

  T* data() const {
    return arrayRef_.data();
  }

  auto numel() const {
    return arrayRef_.size();
  }

  void set_arrayref(MiniArrayRef<T> new_arrayref) {
    arrayRef_ = new_arrayref;
  }

 private:
  MiniArrayRef<T> arrayRef_;
  // We expect generated code to have statically available sizes &
  // strides for us.
  MiniArrayRef<const int64_t> sizes_;
  MiniArrayRef<const int64_t> strides_;
  int32_t device_type_ = 0;
  int32_t device_idx_ = 0;
  // We continue to zero-initialize this field in case we repurpose
  // the space later; having predictable contents can only help.
  int32_t unusedDoNotRemoveForABICompatibility_ = 0;
};

static_assert(
    sizeof(ArrayRefTensor<int>) ==
        3 * sizeof(MiniIntArrayRef) + 3 * sizeof(int32_t) +
            (alignof(ArrayRefTensor<int>) > 4 ? sizeof(int32_t) : 0),
    "changing the size of ArrayRefTensor breaks ABI compatibility!");

inline AtenTensorHandle reinterpret_tensor_wrapper(
    AtenTensorHandle self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset) {
  AtenTensorHandle result = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor(
      self, ndim, sizes_ptr, strides_ptr, storage_offset, &result));
  return result;
}

template <typename T>
inline ArrayRefTensor<T> reinterpret_tensor_wrapper(
    const ArrayRefTensor<T>& self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset) {
  // REVIEW: we should add a way to build the DSO in debug mode during
  // tests so we can have checks like this!
  assert(is_contiguous_strides_for_shape(ndim, strides_ptr, sizes_ptr));
  return ArrayRefTensor<T>(
      MiniArrayRef<T>(
          self.data() + storage_offset, self.numel() - storage_offset),
      MiniArrayRef<const int64_t>(sizes_ptr, ndim),
      MiniArrayRef<const int64_t>(strides_ptr, ndim),
      self.device_type(),
      self.device_idx());
}
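
// Illustrative example (hypothetical names): given an ArrayRefTensor<float>
// `t` over a contiguous 2x3 buffer (such as the sketch above), the overload
// reinterprets it as a contiguous 3x2 view of the same storage, with no copy:
//
//   static constexpr int64_t new_sizes[] = {3, 2};
//   static constexpr int64_t new_strides[] = {2, 1};
//   ArrayRefTensor<float> view = reinterpret_tensor_wrapper(
//       t, 2, new_sizes, new_strides, /*storage_offset=*/0);
//   // view.data() == t.data(); view.numel() == t.numel()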

inline void* get_data_ptr_wrapper(AtenTensorHandle tensor) {
  void* result = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(tensor, &result));
  return result;
}

template <typename T>
inline T* get_data_ptr_wrapper(ArrayRefTensor<T>& tensor) {
  return tensor.data();
}

template <typename T>
inline T* get_data_ptr_wrapper(const MiniArrayRef<T>& arr) {
  return arr.data();
}

inline AtenTensorHandle unwrap_raii_handle_if_needed(
    const RAIIAtenTensorHandle& handle) {
  return handle.get();
}

template <typename T>
inline const ArrayRefTensor<T>& unwrap_raii_handle_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline ArrayRefTensor<T>& unwrap_raii_handle_if_needed(
    ArrayRefTensor<T>& tensor) {
  return tensor;
}

inline RAIIAtenTensorHandle wrap_with_raii_handle_if_needed(
    AtenTensorHandle handle) {
  return RAIIAtenTensorHandle(handle);
}

template <typename T>
inline const ArrayRefTensor<T>& wrap_with_raii_handle_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline ArrayRefTensor<T>& wrap_with_raii_handle_if_needed(
    ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline RAIIAtenTensorHandle expensive_copy_to_tensor_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor.expensiveCopyToTensor();
}

inline AtenTensorHandle expensive_copy_to_tensor_if_needed(
    AtenTensorHandle handle) {
  return handle;
}

template <typename T>
const T& convert_arrayref_tensor_to_tensor(const T& t) {
  return t;
}

template <typename T>
RAIIAtenTensorHandle convert_arrayref_tensor_to_tensor(
    const ArrayRefTensor<T>& art) {
  return art.expensiveCopyToTensor();
}
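
// These overload sets let AOTI-generated code use the same call sites whether
// a value is an (RAII)AtenTensorHandle or an ArrayRefTensor<T>; copies into
// real tensors happen only where an owning handle is genuinely required. A
// hypothetical generated snippet might read:
//
//   auto* p = get_data_ptr_wrapper(buf0);  // no copy for either representation
//   auto&& owned = convert_arrayref_tensor_to_tensor(buf0);
//   // `owned` is a freshly allocated tensor copy only when buf0 is an
//   // ArrayRefTensor; otherwise it is just a reference to buf0.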

} // namespace torch::aot_inductor