// torch/csrc/inductor/aoti_runtime/arrayref_tensor.h
#pragma once

#include <torch/csrc/inductor/aoti_runtime/utils.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

namespace torch::aot_inductor {

// Can't use c10::ArrayRef because it's not truly header-only and
// pulls in other c10 headers. This is (sadly) copy-pasted and
// adapted.
template <typename T>
class MiniArrayRef final {
 public:
  using iterator = T*;
  using const_iterator = const T*;
  using size_type = size_t;
  using value_type = T;

  using reverse_iterator = std::reverse_iterator<iterator>;

 private:
  /// The start of the array, in an external buffer.
  T* Data;

  /// The number of elements.
  size_type Length;

 public:
  /// @name Constructors
  /// @{

  /// Construct an empty MiniArrayRef.
  /* implicit */ constexpr MiniArrayRef() : Data(nullptr), Length(0) {}

  /// Construct a MiniArrayRef from a single element.
  // TODO Make this explicit
  constexpr MiniArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}

  /// Construct a MiniArrayRef from a pointer and length.
  constexpr MiniArrayRef(T* data, size_t length) : Data(data), Length(length) {}

  /// Construct a MiniArrayRef from a range.
  constexpr MiniArrayRef(T* begin, T* end) : Data(begin), Length(end - begin) {}

  template <
      typename Container,
      typename = std::enable_if_t<std::is_same_v<
          std::remove_const_t<decltype(std::declval<Container>().data())>,
          T*>>>
  /* implicit */ MiniArrayRef(Container& container)
      : Data(container.data()), Length(container.size()) {}

  /// Construct a MiniArrayRef from a std::vector.
  // The static_assert here makes sure that this isn't used for
  // std::vector<bool>, because MiniArrayRef can't work on a std::vector<bool>
  // bitfield.
  template <typename A>
  /* implicit */ MiniArrayRef(const std::vector<T, A>& Vec)
      : Data(Vec.data()), Length(Vec.size()) {
    static_assert(
        !std::is_same<T, bool>::value,
        "MiniArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
  }

  /// Construct a MiniArrayRef from a std::array.
  template <size_t N>
  /* implicit */ constexpr MiniArrayRef(std::array<T, N>& Arr)
      : Data(Arr.data()), Length(N) {}

  /// Construct a MiniArrayRef from a C array.
  template <size_t N>
  // NOLINTNEXTLINE(*c-array*)
  /* implicit */ constexpr MiniArrayRef(T (&Arr)[N]) : Data(Arr), Length(N) {}

  /// Construct a MiniArrayRef from an empty C array.
  /* implicit */ constexpr MiniArrayRef(const volatile void* Arr)
      : Data(nullptr), Length(0) {}

  /// Construct a MiniArrayRef from a std::initializer_list.
  /* implicit */ constexpr MiniArrayRef(const std::initializer_list<T>& Vec)
      : Data(
            std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
                                             : std::begin(Vec)),
        Length(Vec.size()) {}

  /// @}
  /// @name Simple Operations
  /// @{

  constexpr iterator begin() const {
    return Data;
  }
  constexpr iterator end() const {
    return Data + Length;
  }

  // Unlike c10::ArrayRef, MiniArrayRef's begin()/end() hand out mutable
  // iterators; cbegin()/cend() are the const counterparts.
  constexpr const_iterator cbegin() const {
    return Data;
  }
  constexpr const_iterator cend() const {
    return Data + Length;
  }

  constexpr reverse_iterator rbegin() const {
    return reverse_iterator(end());
  }
  constexpr reverse_iterator rend() const {
    return reverse_iterator(begin());
  }

  /// empty - Check if the array is empty.
  constexpr bool empty() const {
    return Length == 0;
  }

  constexpr T* data() const {
    return Data;
  }

  /// size - Get the array size.
  constexpr size_t size() const {
    return Length;
  }

  /// equals - Check for element-wise equality.
  constexpr bool equals(MiniArrayRef RHS) const {
    return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
  }

  /// @}
  /// @name Operator Overloads
  /// @{
  constexpr const T& operator[](size_t Index) const {
    return Data[Index];
  }

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, MiniArrayRef<T>>& operator=(
      U&& Temporary) = delete;

  /// Disallow accidental assignment from a temporary.
  ///
  /// The declaration here is extra complicated so that "arrayRef = {}"
  /// continues to select the move assignment operator.
  template <typename U>
  std::enable_if_t<std::is_same_v<U, T>, MiniArrayRef<T>>& operator=(
      std::initializer_list<U>) = delete;
};
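
// Usage sketch (illustrative only, not part of the original header):
// MiniArrayRef is a non-owning view, so the referenced storage must outlive
// it. For example:
//
//   std::vector<int64_t> shape = {2, 3, 4};
//   MiniArrayRef<int64_t> view(shape);               // container constructor
//   MiniArrayRef<int64_t> tail(shape.data() + 1, 2); // pointer + length
//   // view.size() == 3, tail[0] == 3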

using MiniIntArrayRef = MiniArrayRef<int64_t>;

static_assert(
    sizeof(MiniIntArrayRef) == sizeof(void*) + sizeof(size_t),
    "changing the size of MiniArrayRef breaks ABI compatibility!");

inline bool is_contiguous_strides_for_shape(
    int64_t ndim,
    const int64_t* strides_ptr,
    const int64_t* sizes_ptr) {
  int64_t z = 1;
  for (int64_t d = ndim - 1; d >= 0; d--) {
    const auto& size_d = sizes_ptr[d];
    if (size_d != 1) {
      if (strides_ptr[d] == z) {
        z *= size_d;
      } else {
        return false;
      }
    }
  }
  return true;
}
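
// Worked example (illustrative only): for a dense row-major 2x3x4 tensor,
// sizes = {2, 3, 4} and strides = {12, 4, 1}; the loop visits d = 2, 1, 0
// with expected strides z = 1, 4, 12 and returns true. Strides {12, 1, 3}
// fail at d = 2 (stride 3 != expected 1). Size-1 dimensions are skipped, so
// sizes {1, 3} with strides {99, 1} still count as contiguous.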

// Shim for AOTI generated code to pretend a raw array works like an
// AtenTensorHandle.
template <typename T>
class ArrayRefTensor {
 public:
  ArrayRefTensor() = default;

  explicit ArrayRefTensor(
      MiniArrayRef<T> arr,
      MiniArrayRef<const int64_t> sizes,
      MiniArrayRef<const int64_t> strides,
      int32_t device_type,
      int32_t device_idx)
      : arrayRef_(arr),
        sizes_(sizes),
        strides_(strides),
        device_type_(device_type),
        device_idx_(device_idx) {
    assert(sizes.size() == strides.size());
    assert(is_contiguous_strides_for_shape(
        sizes.size(), strides.data(), sizes.data()));
  }

  AtenTensorHandle expensiveCopyToTensor() const {
    AtenTensorHandle result = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
        sizes_.size(),
        sizes_.data(),
        strides_.data(),
        aoti_torch_dtype<std::remove_const_t<T>>(),
        device_type_,
        device_idx_,
        &result));
    void* dataPtr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(result, &dataPtr));
    std::memcpy(dataPtr, data(), numel() * sizeof(T));
    return result;
  }

  // We need to look the same as RAIIAtenTensorHandle, which returns
  // an owning AtenTensorHandle from release(). So, we allocate one!
  AtenTensorHandle release() {
    return expensiveCopyToTensor();
  }

  // We don't need to free any memory.
  void reset() {}

  auto sizes() const {
    return sizes_;
  }

  auto strides() const {
    return strides_;
  }

  auto device_type() const {
    return device_type_;
  }

  auto device_idx() const {
    return device_idx_;
  }

  T* data() const {
    return arrayRef_.data();
  }

  auto numel() const {
    return arrayRef_.size();
  }

  void set_arrayref(MiniArrayRef<T> new_arrayref) {
    arrayRef_ = new_arrayref;
  }

 private:
  MiniArrayRef<T> arrayRef_;
  // We expect generated code to have statically available sizes &
  // strides for us.
  MiniArrayRef<const int64_t> sizes_;
  MiniArrayRef<const int64_t> strides_;
  int32_t device_type_ = 0;
  int32_t device_idx_ = 0;
  // We continue to zero-initialize this field in case we repurpose
  // the space later; having predictable contents can only help.
  int32_t unusedDoNotRemoveForABICompatibility_ = 0;
};

static_assert(
    sizeof(ArrayRefTensor<int>) ==
        3 * sizeof(MiniIntArrayRef) + 3 * sizeof(int32_t) +
            (alignof(ArrayRefTensor<int>) > 4 ? sizeof(int32_t) : 0),
    "changing the size of ArrayRefTensor breaks ABI compatibility!");

inline AtenTensorHandle reinterpret_tensor_wrapper(
    AtenTensorHandle self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset) {
  AtenTensorHandle result = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor(
      self, ndim, sizes_ptr, strides_ptr, storage_offset, &result));
  return result;
}

template <typename T>
inline ArrayRefTensor<T> reinterpret_tensor_wrapper(
    const ArrayRefTensor<T>& self,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset) {
  // REVIEW: we should add a way to build the DSO in debug mode during
  // tests so we can have checks like this!
  assert(is_contiguous_strides_for_shape(ndim, strides_ptr, sizes_ptr));
  return ArrayRefTensor<T>(
      MiniArrayRef<T>(
          self.data() + storage_offset, self.numel() - storage_offset),
      MiniArrayRef<const int64_t>(sizes_ptr, ndim),
      MiniArrayRef<const int64_t>(strides_ptr, ndim),
      self.device_type(),
      self.device_idx());
}
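
// Example (illustrative only): given a contiguous 2x2x3 ArrayRefTensor<float>
// t, the last 6 elements can be viewed as a 2x3 tensor without copying:
//
//   static constexpr int64_t new_sizes[] = {2, 3};
//   static constexpr int64_t new_strides[] = {3, 1};
//   auto view = reinterpret_tensor_wrapper(t, 2, new_sizes, new_strides, 6);
//
// Unlike the AtenTensorHandle overload above, no new handle or storage is
// created; the result is another non-owning view over the same buffer,
// offset by storage_offset elements.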

inline void* get_data_ptr_wrapper(AtenTensorHandle tensor) {
  void* result = nullptr;
  AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(tensor, &result));
  return result;
}

template <typename T>
inline T* get_data_ptr_wrapper(ArrayRefTensor<T>& tensor) {
  return tensor.data();
}

template <typename T>
inline T* get_data_ptr_wrapper(const MiniArrayRef<T>& arr) {
  return arr.data();
}

inline AtenTensorHandle unwrap_raii_handle_if_needed(
    const RAIIAtenTensorHandle& handle) {
  return handle.get();
}

template <typename T>
inline const ArrayRefTensor<T>& unwrap_raii_handle_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline ArrayRefTensor<T>& unwrap_raii_handle_if_needed(
    ArrayRefTensor<T>& tensor) {
  return tensor;
}

inline RAIIAtenTensorHandle wrap_with_raii_handle_if_needed(
    AtenTensorHandle handle) {
  return RAIIAtenTensorHandle(handle);
}

template <typename T>
inline const ArrayRefTensor<T>& wrap_with_raii_handle_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline ArrayRefTensor<T>& wrap_with_raii_handle_if_needed(
    ArrayRefTensor<T>& tensor) {
  return tensor;
}

template <typename T>
inline RAIIAtenTensorHandle expensive_copy_to_tensor_if_needed(
    const ArrayRefTensor<T>& tensor) {
  return tensor.expensiveCopyToTensor();
}

inline AtenTensorHandle expensive_copy_to_tensor_if_needed(
    AtenTensorHandle handle) {
  return handle;
}

template <typename T>
const T& convert_arrayref_tensor_to_tensor(const T& t) {
  return t;
}

template <typename T>
RAIIAtenTensorHandle convert_arrayref_tensor_to_tensor(
    const ArrayRefTensor<T>& art) {
  return art.expensiveCopyToTensor();
}
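
// The overload pairs above (unwrap/wrap/expensive_copy/convert) exist so that
// AOTI-generated code can treat ArrayRefTensor<T> and AtenTensorHandle
// uniformly, paying for a copy into a real tensor only in the ArrayRefTensor
// case. Illustrative pattern (hypothetical call site, not from this header):
//
//   template <typename Tensorish>
//   void consume(Tensorish& input) {
//     // No-op for ArrayRefTensor, handle.get() for RAIIAtenTensorHandle.
//     auto&& unwrapped = unwrap_raii_handle_if_needed(input);
//     // Copies only for ArrayRefTensor inputs; passes handles through.
//     auto owned = convert_arrayref_tensor_to_tensor(unwrapped);
//   }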

} // namespace torch::aot_inductor