/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/portable_type/scalar_type.h>
#include <executorch/runtime/core/tensor_shape_dynamism.h>

// Forward declaration of a helper that provides access to internal resizing
// methods of TensorImpl. Real definition is in
// executorch/runtime/core/exec_aten/tensor_util.h.
namespace executorch {
namespace runtime {
namespace internal {
class TensorResizerFriend;
} // namespace internal
} // namespace runtime
} // namespace executorch

namespace executorch {
namespace runtime {
namespace etensor {

/**
 * Manages the storage behind an ETensor (torch::executor::Tensor).
 *
 * Note that instances of this class do not own the arrays given to it
 * (sizes/strides/data), which means that the caller must guarantee that they
 * live longer than a given instance of this class.
 *
 * Note on types:
 *
 * Code that uses ETensor should also be able to build against at::Tensor. So,
 * although the overlapping APIs don't need to be exactly the same, their types
 * should be semantically similar.
 *
 * Many of the methods in at::Tensor use int64_t for parameter and return
 * types. This can be a waste when building for 32-bit environments. So,
 * TensorImpl and ETensor use ssize_t instead: like int64_t it is signed, but
 * it will match the native word size of the target architecture. This will
 * avoid unnecessarily expensive uses of 64-bit integers on 32-bit machines.
 *
 * But, since the types are not identical, code that uses ETensor needs to be
 * generic about the local types it uses when working with these methods. In
 * most cases, `auto` will do the trick. In the worst case, code can be guarded
 * with `#ifdef USE_ATEN_LIB`.
 */
class TensorImpl {
 public:
  /**
   * The type used for elements of `sizes()`.
   *
   * This must match the size/signedness of the type used for `Tensor.sizes`
   * in //executorch/schema/program.fbs.
   *
   * Note that at::TensorImpl uses `int64_t` for this type. ExecuTorch uses
   * `int32_t` to save memory, since no single size value will ever be larger
   * than 2 billion.
   */
  using SizesType = int32_t;

  /**
   * The type used for elements of `dim_order()`.
   *
   * This must match the size/signedness of the type used for
   * `Tensor.dim_order` in //executorch/schema/program.fbs.
   */
  using DimOrderType = uint8_t;

  /**
   * The type used for elements of `strides()`.
   *
   * This must match the size/signedness of the type used for `Tensor.strides`
   * in //executorch/schema/program.fbs.
   *
   * Note that at::TensorImpl uses `int64_t` for this type. ExecuTorch uses
   * `int32_t` to save memory, since no single stride value will ever be
   * larger than 2 billion.
   */
  using StridesType = int32_t;

  TensorImpl() = delete;

  /**
   * @param type The type of the data (int, float, bool).
   * @param dim Number of dimensions, and the length of the `sizes` array.
   * @param sizes Sizes of the tensor at each dimension. Must contain `dim`
   *     entries.
   * @param data Pointer to the data, whose size is determined by `type`,
   *     `dim`, and `sizes`. The tensor will not own this memory.
   * @param dim_order Order in which dimensions are laid out in memory.
   * @param strides Strides of the tensor at each dimension. Must contain
   *     `dim` entries.
   * @param dynamism The mutability of the shape of the tensor.
   */
  TensorImpl(
      ScalarType type,
      ssize_t dim,
      SizesType* sizes,
      void* data = nullptr,
      DimOrderType* dim_order = nullptr,
      StridesType* strides = nullptr,
      TensorShapeDynamism dynamism = TensorShapeDynamism::STATIC);
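  /**
   * Example: wrapping caller-owned buffers in a TensorImpl. This is a minimal
   * sketch for illustration only, not part of the API. Because TensorImpl
   * does not take ownership, `sizes`, `data`, `dim_order`, and `strides` must
   * all outlive `impl`.
   *
   *   float data[6] = {1, 2, 3, 4, 5, 6};
   *   TensorImpl::SizesType sizes[2] = {2, 3};
   *   TensorImpl::DimOrderType dim_order[2] = {0, 1};
   *   TensorImpl::StridesType strides[2] = {3, 1}; // contiguous for {2, 3}
   *   TensorImpl impl(ScalarType::Float, 2, sizes, data, dim_order, strides);
   */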
  /**
   * Returns the size of the tensor in bytes.
   *
   * NOTE: This returns the size of the data used by the tensor's current
   * shape, not the capacity of the underlying buffer.
   */
  size_t nbytes() const;

  /**
   * Returns the size of the tensor at the given dimension.
   *
   * NOTE: size() intentionally does not return SizesType even though it
   * returns an element of an array of SizesType. This is to help make calls
   * of this method more compatible with at::Tensor, and more consistent with
   * the rest of the methods on this class and in ETensor.
   */
  ssize_t size(ssize_t dim) const {
    ET_CHECK_MSG(
        dim < dim_ && dim >= 0,
        "Dimension out of range (expected to be in range of [0, %zd], but got %zd)",
        dim_ - 1,
        dim);
    return sizes_[dim];
  }

  /// Returns the tensor's number of dimensions.
  ssize_t dim() const {
    return dim_;
  }

  /// Returns the number of elements in the tensor.
  ssize_t numel() const {
    return numel_;
  }

  /// Returns the type of the elements in the tensor (int32, float, bool, etc).
  ScalarType scalar_type() const {
    return type_;
  }

  /// Alias for scalar_type(), for compatibility with at::Tensor.
  inline ScalarType dtype() const {
    return scalar_type();
  }

  /// Returns the size in bytes of one element of the tensor.
  ssize_t element_size() const;

  /// Returns the sizes of the tensor at each dimension.
  const ArrayRef<SizesType> sizes() const {
    return ArrayRef<SizesType>{sizes_, static_cast<size_t>(dim_)};
  }

  /// Returns the order the dimensions are laid out in memory.
  const ArrayRef<DimOrderType> dim_order() const {
    return ArrayRef<DimOrderType>{dim_order_, static_cast<size_t>(dim_)};
  }

  /// Returns the strides of the tensor at each dimension.
  const ArrayRef<StridesType> strides() const {
    return ArrayRef<StridesType>{strides_, static_cast<size_t>(dim_)};
  }

  /// Returns the mutability of the shape of the tensor.
  TensorShapeDynamism shape_dynamism() const {
    return shape_dynamism_;
  }

  /// Returns a pointer of type T to the constant underlying data blob.
  template <typename T>
  inline const T* data() const {
    return static_cast<const T*>(data());
  }

  /// Returns a pointer to the constant underlying data blob.
  const void* data() const {
    return data_;
  }

  /// Returns a pointer of type T to the mutable underlying data blob.
  template <typename T>
  inline T* mutable_data() const {
    return static_cast<T*>(mutable_data());
  }

  /// Returns a pointer to the mutable underlying data blob.
  void* mutable_data() const {
    return data_;
  }
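  /**
   * Example: typed access to the underlying data. A minimal sketch, assuming
   * `impl` wraps float data as in the constructor example above; it doubles
   * every element in place.
   *
   *   const float* in = impl.data<float>();
   *   float* out = impl.mutable_data<float>();
   *   for (ssize_t i = 0; i < impl.numel(); ++i) {
   *     out[i] = in[i] * 2.0f;
   *   }
   */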
  /// Sets the underlying data blob to the passed in pointer.
  void set_data(void* ptr) {
    data_ = ptr;
  }

  /*
   * DEPRECATED: Use torch::executor::resize_tensor() or
   * torch::executor::resize_tensor_impl().
   */
  ET_DEPRECATED
  void set_sizes_contiguous(ArrayRef<SizesType> new_sizes) {
    Error err = internal_resize_contiguous(new_sizes);
    ET_CHECK_MSG(
        err == Error::Ok, "Could not resize Tensor; see logs for details");
  }

 private:
  // For access to internal_resize_contiguous().
  friend class ::executorch::runtime::internal::TensorResizerFriend;

  /**
   * Set the sizes and strides of a tensor assuming contiguous strides.
   * Requires that `new_sizes.size() == this->dim()`.
   *
   * Callers must use torch::executor::resize_tensor() or
   * torch::executor::resize_tensor_impl() instead, defined in tensor_util.h.
   *
   * Same semantics as at::TensorImpl::set_sizes_contiguous(), but returns an
   * error instead of panicking on failure. This is not part of the at::Tensor
   * API, and can only be used in lean mode.
   */
  ET_NODISCARD Error internal_resize_contiguous(ArrayRef<SizesType> new_sizes);

 private:
  // Keep fields arranged to avoid unnecessary alignment holes.

  /// List of sizes of each dimension in the tensor.
  SizesType* sizes_;

  /// List of the order that dimensions are laid out in memory.
  DimOrderType* dim_order_;

  // TODO(T148356881): Get rid of strides from ETensor
  StridesType* strides_;

  /// Pointer to underlying data blob. NOTE: Can be null.
  void* data_;

  /// Tensor's number of dimensions.
  const ssize_t dim_;

  /// Number of elements in the tensor.
  ssize_t numel_;

  /// Maximum number of elements in the bounded tensor. Used when resizing up
  /// and down.
  size_t numel_bound_;

  /// Scalar type (int, float, bool, etc) of the tensor data.
  const ScalarType type_;

  /// Specifies the mutability of the shape of the tensor.
  const TensorShapeDynamism shape_dynamism_;
};

/**
 * Compute the number of elements based on the sizes of a tensor.
 */
ssize_t compute_numel(
    const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes,
    ssize_t dim);

} // namespace etensor
} // namespace runtime
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::etensor::compute_numel;
using ::executorch::runtime::etensor::TensorImpl;
} // namespace executor
} // namespace torch
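/*
 * Example: compute_numel() returns the product of the first `dim` entries of
 * `sizes`. A minimal usage sketch, shown here via the deprecated
 * torch::executor aliases above:
 *
 *   torch::executor::TensorImpl::SizesType sizes[3] = {2, 3, 4};
 *   ssize_t numel = torch::executor::compute_numel(sizes, 3); // 24
 */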