/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/portable_type/scalar_type.h>
#include <executorch/runtime/core/tensor_shape_dynamism.h>

// Forward declaration of a helper that provides access to internal resizing
// methods of TensorImpl. Real definition is in
// executorch/runtime/core/exec_aten/tensor_util.h.
namespace executorch {
namespace runtime {
namespace internal {
class TensorResizerFriend;
} // namespace internal
} // namespace runtime
} // namespace executorch

namespace executorch {
namespace runtime {
namespace etensor {

/**
 * Manages the storage behind an ETensor (torch::executor::Tensor).
 *
 * Note that instances of this class do not own the arrays given to it
 * (sizes/strides/data), which means that the caller must guarantee that they
 * live longer than a given instance of this class.
 *
 * Note on types:
 *
 * Code that uses ETensor should also be able to build against at::Tensor. So,
 * although the overlapping APIs don't need to be exactly the same, their types
 * should be semantically similar.
 *
 * Many of the methods in at::Tensor use int64_t for parameter and return types.
 * This can be a waste when building for 32-bit environments. So, TensorImpl and
 * ETensor use ssize_t instead: like int64_t it is signed, but it will match the
 * native word size of the target architecture. This will avoid unnecessarily
 * expensive uses of 64-bit integers on 32-bit machines.
 *
 * But, since the types are not identical, code that uses ETensor needs to be
 * generic about the local types it uses when working with these methods. In
 * most cases, `auto` will do the trick. In the worst case, code can be guarded
 * with `#ifdef USE_ATEN_LIB`.
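 *
 * For example (an illustrative sketch, not code from this header), a local
 * variable that must hold a dimension size under either backend can be
 * declared generically:
 *
 *   auto n = tensor.size(0); // int64_t under ATen, ssize_t under ETensor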
 */
class TensorImpl {
 public:
  /**
   * The type used for elements of `sizes()`.
   *
   * This must match the size/signedness of the type used for `Tensor.sizes` in
   * //executorch/schema/program.fbs.
   *
   * Note that at::TensorImpl uses `int64_t` for this type. ExecuTorch uses
   * `int32_t` to save memory, since no single size value will ever be larger
   * than 2 billion.
   */
  using SizesType = int32_t;

  /**
   * The type used for elements of `dim_order()`.
   *
   * This must match the size/signedness of the type used for `Tensor.dim_order`
   * in //executorch/schema/program.fbs.
   */
  using DimOrderType = uint8_t;

  /**
   * The type used for elements of `strides()`.
   *
   * This must match the size/signedness of the type used for `Tensor.strides`
   * in //executorch/schema/program.fbs.
   *
   * Note that at::TensorImpl uses `int64_t` for this type. ExecuTorch uses
   * `int32_t` to save memory, since no single stride value will ever be larger
   * than 2 billion.
   */
  using StridesType = int32_t;

  TensorImpl() = delete;

  /**
   * @param type The type of the data (int, float, bool).
   * @param dim Number of dimensions, and the length of the `sizes` array.
   * @param sizes Sizes of the tensor at each dimension. Must contain `dim`
   *     entries.
   * @param data Pointer to the data, whose size is determined by `type`,
   *     `dim`, and `sizes`. The tensor will not own this memory.
   * @param dim_order Order in which dimensions are laid out in memory.
   * @param strides Strides of the tensor at each dimension. Must contain `dim`
   *     entries.
   * @param dynamism The mutability of the shape of the tensor.
   */
  TensorImpl(
      ScalarType type,
      ssize_t dim,
      SizesType* sizes,
      void* data = nullptr,
      DimOrderType* dim_order = nullptr,
      StridesType* strides = nullptr,
      TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC);
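
  // A minimal construction sketch (illustrative only; the buffer names and
  // values are hypothetical, not defined by this header). The caller owns
  // every array passed in and must keep them alive at least as long as the
  // TensorImpl:
  //
  //   float data[6] = {1, 2, 3, 4, 5, 6};             // 2x3, row-major
  //   TensorImpl::SizesType sizes[2] = {2, 3};
  //   TensorImpl::DimOrderType dim_order[2] = {0, 1};  // contiguous layout
  //   TensorImpl::StridesType strides[2] = {3, 1};
  //   TensorImpl impl(
  //       ScalarType::Float, /*dim=*/2, sizes, data, dim_order, strides);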

  /**
   * Returns the size of the tensor in bytes.
   *
   * NOTE: This returns the size of the data used by the tensor's current shape,
   * not the capacity of the underlying buffer.
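   *
   * For example, a 2x3 Float tensor would report 6 * sizeof(float) = 24 bytes
   * here, even if its underlying buffer was allocated with more capacity (the
   * shape and dtype are illustrative, not taken from this header).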
   */
  size_t nbytes() const;

  /**
   * Returns the size of the tensor at the given dimension.
   *
   * NOTE: size() intentionally does not return SizesType even though it
   * returns an element of an array of SizesType. This is to help make calls of
   * this method more compatible with at::Tensor, and more consistent with the
   * rest of the methods on this class and in ETensor.
   */
  ssize_t size(ssize_t dim) const {
    ET_CHECK_MSG(
        dim < dim_ && dim >= 0,
        "Dimension out of range (expected to be in range of [0, %zd], but got %zd)",
        dim_ - 1,
        dim);
    return sizes_[dim];
  }

  /// Returns the tensor's number of dimensions.
  ssize_t dim() const {
    return dim_;
  }

  /// Returns the number of elements in the tensor.
  ssize_t numel() const {
    return numel_;
  }

  /// Returns the type of the elements in the tensor (int32, float, bool, etc).
  ScalarType scalar_type() const {
    return type_;
  }

  /// Alias for scalar_type().
  inline ScalarType dtype() const {
    return scalar_type();
  }

  /// Returns the size in bytes of one element of the tensor.
  ssize_t element_size() const;

  /// Returns the sizes of the tensor at each dimension.
  const ArrayRef<SizesType> sizes() const {
    return ArrayRef<SizesType>{sizes_, static_cast<size_t>(dim_)};
  }

  /// Returns the order the dimensions are laid out in memory.
  const ArrayRef<DimOrderType> dim_order() const {
    return ArrayRef<DimOrderType>{dim_order_, static_cast<size_t>(dim_)};
  }

  /// Returns the strides of the tensor at each dimension.
  const ArrayRef<StridesType> strides() const {
    return ArrayRef<StridesType>{strides_, static_cast<size_t>(dim_)};
  }
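
  // How dim_order and strides relate, sketched for a hypothetical contiguous
  // 2x3 tensor (values are illustrative, not defined by this header):
  //
  //   sizes     = {2, 3}
  //   dim_order = {0, 1}  // dimension 0 is outermost in memory
  //   strides   = {3, 1}  // element (i, j) lives at data[i * 3 + j]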

  /// Returns the mutability of the shape of the tensor.
  TensorShapeDynamism shape_dynamism() const {
    return shape_dynamism_;
  }

  /// Returns a pointer of type T to the constant underlying data blob.
  template <typename T>
  inline const T* data() const {
    return static_cast<const T*>(data());
  }

  /// Returns a pointer to the constant underlying data blob.
  const void* data() const {
    return data_;
  }

  /// Returns a pointer of type T to the mutable underlying data blob.
  template <typename T>
  inline T* mutable_data() const {
    return static_cast<T*>(mutable_data());
  }

  /// Returns a pointer to the mutable underlying data blob.
  void* mutable_data() const {
    return data_;
  }
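
  // Typed-access sketch (assumes the caller has already checked that
  // scalar_type() is ScalarType::Float; `impl` is an illustrative name):
  //
  //   const float* in = impl.data<float>();
  //   float* out = impl.mutable_data<float>();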

  /// Sets the underlying data blob to the passed in pointer.
  void set_data(void* ptr) {
    data_ = ptr;
  }

  /*
   * DEPRECATED: Use torch::executor::resize_tensor() or
   * torch::executor::resize_tensor_impl().
   */
  ET_DEPRECATED
  void set_sizes_contiguous(ArrayRef<SizesType> new_sizes) {
    Error err = internal_resize_contiguous(new_sizes);
    ET_CHECK_MSG(
        err == Error::Ok, "Could not resize Tensor; see logs for details");
  }
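
  // Callers should prefer the non-aborting path through the helpers named in
  // the deprecation note above; a rough sketch (their exact signatures live in
  // tensor_util.h and are assumed here, not defined by this header):
  //
  //   Error err = torch::executor::resize_tensor(tensor, new_sizes);
  //   if (err != Error::Ok) { /* handle the failure instead of aborting */ }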

 private:
  // For access to internal_resize_contiguous().
  friend class ::executorch::runtime::internal::TensorResizerFriend;

  /**
   * Set the sizes and strides of a tensor assuming contiguous strides.
   * Requires that `new_sizes.size() == this->dim()`.
   *
   * Callers must use torch::executor::resize_tensor() or
   * torch::executor::resize_tensor_impl() instead, defined in TensorUtil.h.
   *
   * Same semantics as at::TensorImpl::set_sizes_contiguous(), but returns an
   * error instead of panicking on failure. This is not part of the at::Tensor
   * API, and can only be used in lean mode.
   */
  ET_NODISCARD Error internal_resize_contiguous(ArrayRef<SizesType> new_sizes);

 private:
  // Keep fields arranged to avoid unnecessary alignment holes.

  /// List of sizes of each dimension in the tensor.
  SizesType* sizes_;

  /// List of the order that dimensions are laid out in memory.
  DimOrderType* dim_order_;

  // TODO(T148356881): Get rid of strides from ETensor
  StridesType* strides_;

  /// Pointer to underlying data blob. NOTE: Can be null.
  void* data_;

  /// Tensor's number of dimensions.
  const ssize_t dim_;

  /// Number of elements in the tensor.
  ssize_t numel_;

  /// Maximum number of elements in the bounded tensor. Used when resizing up
  /// and down.
  size_t numel_bound_;

  /// Scalar type (int, float, bool, etc) of the tensor data.
  const ScalarType type_;

  /// Specifies the mutability of the shape of the tensor.
  const TensorShapeDynamism shape_dynamism_;
};

/**
 * Compute the number of elements based on the sizes of a tensor.
 */
ssize_t compute_numel(
    const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes,
    ssize_t dim);
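
// Sketch of the intended result (illustrative values): with
//   TensorImpl::SizesType sizes[2] = {2, 3};
// compute_numel(sizes, 2) yields 6, the product of all size entries.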

} // namespace etensor
} // namespace runtime
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::etensor::compute_numel;
using ::executorch::runtime::etensor::TensorImpl;
} // namespace executor
} // namespace torch