1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/tensor/tensor_ptr.h>

#include <algorithm>
#include <numeric>

#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
14
15 namespace executorch {
16 namespace extension {
17 namespace {
18 #ifndef USE_ATEN_LIB
19 /**
20 * A structure that consolidates the metadata (sizes, dim_order, strides) and
21 * the data buffer associated with a Tensor. Since Tensor does not own
22 * the memory for these metadata arrays or the data itself, this structure
23 * ensures that they are managed together and have the same lifetime as the
24 * Tensor. When the Tensor is destroyed, the Storage structure ensures
25 * proper cleanup of the associated metadata and data if needed.
26 */
27 struct Storage final {
28 exec_aten::TensorImpl tensor_impl;
29 exec_aten::Tensor tensor;
30 std::vector<exec_aten::SizesType> sizes;
31 std::vector<exec_aten::DimOrderType> dim_order;
32 std::vector<exec_aten::StridesType> strides;
33 std::function<void(void*)> deleter;
34
Storageexecutorch::extension::__anonea639a470111::Storage35 Storage(
36 exec_aten::TensorImpl&& tensor_impl,
37 std::vector<exec_aten::SizesType>&& sizes,
38 std::vector<exec_aten::DimOrderType>&& dim_order,
39 std::vector<exec_aten::StridesType>&& strides,
40 std::function<void(void*)>&& deleter)
41 : tensor_impl(std::move(tensor_impl)),
42 tensor(&this->tensor_impl),
43 sizes(std::move(sizes)),
44 dim_order(std::move(dim_order)),
45 strides(std::move(strides)),
46 deleter(std::move(deleter)) {}
47
~Storageexecutorch::extension::__anonea639a470111::Storage48 ~Storage() {
49 if (deleter) {
50 deleter(tensor_impl.mutable_data());
51 }
52 }
53 };
54 #endif // USE_ATEN_LIB
55 } // namespace
56
make_tensor_ptr(std::vector<exec_aten::SizesType> sizes,void * data,std::vector<exec_aten::DimOrderType> dim_order,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism,std::function<void (void *)> deleter)57 TensorPtr make_tensor_ptr(
58 std::vector<exec_aten::SizesType> sizes,
59 void* data,
60 std::vector<exec_aten::DimOrderType> dim_order,
61 std::vector<exec_aten::StridesType> strides,
62 exec_aten::ScalarType type,
63 exec_aten::TensorShapeDynamism dynamism,
64 std::function<void(void*)> deleter) {
65 const auto dim = sizes.size();
66 ET_CHECK_MSG(
67 dim_order.empty() || dim_order.size() == dim,
68 "dim_order size must match sizes or be empty.");
69 ET_CHECK_MSG(
70 strides.empty() || strides.size() == dim,
71 "strides size must match sizes or be empty.");
72
73 if (dim_order.empty()) {
74 dim_order.resize(dim);
75 std::iota(dim_order.begin(), dim_order.end(), 0);
76 if (!strides.empty()) {
77 std::sort(dim_order.begin(), dim_order.end(), [&](size_t a, size_t b) {
78 return strides[a] > strides[b];
79 });
80 }
81 }
82 std::vector<exec_aten::StridesType> computed_strides(dim);
83 auto error = runtime::dim_order_to_stride(
84 sizes.data(), dim_order.data(), dim, computed_strides.data());
85 ET_CHECK_MSG(error == runtime::Error::Ok, "Failed to compute strides.");
86
87 if (!strides.empty()) {
88 ET_CHECK_MSG(computed_strides == strides, "Invalid strides provided.");
89 } else {
90 strides = std::move(computed_strides);
91 }
92 #ifndef USE_ATEN_LIB
93 exec_aten::TensorImpl tensor_impl(
94 type,
95 dim,
96 sizes.data(),
97 data,
98 dim_order.data(),
99 strides.data(),
100 dim > 0 ? dynamism : exec_aten::TensorShapeDynamism::STATIC);
101 auto storage = std::make_shared<Storage>(
102 std::move(tensor_impl),
103 std::move(sizes),
104 std::move(dim_order),
105 std::move(strides),
106 std::move(deleter));
107 const auto tensor_ptr = &storage->tensor;
108 return std::shared_ptr<exec_aten::Tensor>(std::move(storage), tensor_ptr);
109 #else
110 auto options = c10::TensorOptions()
111 .dtype(c10::scalarTypeToTypeMeta(type))
112 .device(c10::kCPU);
113 auto storage = c10::Storage(
114 c10::Storage::use_byte_size_t(),
115 at::detail::computeStorageNbytes(
116 sizes, strides, options.dtype().itemsize()),
117 c10::InefficientStdFunctionContext::makeDataPtr(
118 data, std::move(deleter), options.device()),
119 nullptr,
120 false);
121 auto tensor_impl = c10::make_intrusive<exec_aten::TensorImpl>(
122 std::move(storage),
123 c10::DispatchKeySet(c10::DispatchKey::CPU),
124 options.dtype());
125 tensor_impl->set_sizes_and_strides(sizes, strides);
126 return std::make_shared<exec_aten::Tensor>(std::move(tensor_impl));
127 #endif // USE_ATEN_LIB
128 }
129
make_tensor_ptr(std::vector<exec_aten::SizesType> sizes,std::vector<uint8_t> data,std::vector<exec_aten::DimOrderType> dim_order,std::vector<exec_aten::StridesType> strides,exec_aten::ScalarType type,exec_aten::TensorShapeDynamism dynamism)130 TensorPtr make_tensor_ptr(
131 std::vector<exec_aten::SizesType> sizes,
132 std::vector<uint8_t> data,
133 std::vector<exec_aten::DimOrderType> dim_order,
134 std::vector<exec_aten::StridesType> strides,
135 exec_aten::ScalarType type,
136 exec_aten::TensorShapeDynamism dynamism) {
137 ET_CHECK_MSG(
138 data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) *
139 exec_aten::elementSize(type),
140 "Data size is smaller than required by sizes and scalar type.");
141 auto data_ptr = data.data();
142 return make_tensor_ptr(
143 std::move(sizes),
144 data_ptr,
145 std::move(dim_order),
146 std::move(strides),
147 type,
148 dynamism,
149 // Data is moved into the deleter and is destroyed together with Storage.
150 [data = std::move(data)](void*) {});
151 }
152
clone_tensor_ptr(const exec_aten::Tensor & tensor)153 TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor) {
154 std::vector<exec_aten::SizesType> sizes(
155 tensor.sizes().begin(), tensor.sizes().end());
156 std::vector<exec_aten::DimOrderType> dim_order{
157 #ifndef USE_ATEN_LIB
158 tensor.dim_order().begin(), tensor.dim_order().end()
159 #endif // USE_ATEN_LIB
160 };
161 std::vector<exec_aten::StridesType> strides(
162 tensor.strides().begin(), tensor.strides().end());
163 auto dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND;
164 #ifndef USE_ATEN_LIB
165 dynamism = tensor.shape_dynamism();
166 #endif // USE_ATEN_LIB
167 return tensor.const_data_ptr()
168 ? make_tensor_ptr(
169 std::move(sizes),
170 std::vector<uint8_t>(
171 (uint8_t*)tensor.const_data_ptr(),
172 (uint8_t*)tensor.const_data_ptr() + tensor.nbytes()),
173 std::move(dim_order),
174 std::move(strides),
175 tensor.scalar_type(),
176 dynamism)
177 : make_tensor_ptr(
178 std::move(sizes),
179 nullptr,
180 std::move(dim_order),
181 std::move(strides),
182 tensor.scalar_type(),
183 dynamism);
184 }
185
clone_tensor_ptr(const exec_aten::Tensor & tensor)186 TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor) {
187 std::vector<exec_aten::SizesType> sizes(
188 tensor.sizes().begin(), tensor.sizes().end());
189 std::vector<exec_aten::DimOrderType> dim_order{
190 #ifndef USE_ATEN_LIB
191 tensor.dim_order().begin(), tensor.dim_order().end()
192 #endif // USE_ATEN_LIB
193 };
194 std::vector<exec_aten::StridesType> strides(
195 tensor.strides().begin(), tensor.strides().end());
196 auto dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND;
197 #ifndef USE_ATEN_LIB
198 dynamism = tensor.shape_dynamism();
199 #endif // USE_ATEN_LIB
200 return tensor.const_data_ptr()
201 ? make_tensor_ptr(
202 std::move(sizes),
203 std::vector<uint8_t>(
204 (uint8_t*)tensor.const_data_ptr(),
205 (uint8_t*)tensor.const_data_ptr() + tensor.nbytes()),
206 std::move(dim_order),
207 std::move(strides),
208 tensor.scalar_type(),
209 dynamism)
210 : make_tensor_ptr(
211 std::move(sizes),
212 nullptr,
213 std::move(dim_order),
214 std::move(strides),
215 tensor.scalar_type(),
216 dynamism);
217 }
218
resize_tensor_ptr(TensorPtr & tensor,const std::vector<exec_aten::SizesType> & sizes)219 runtime::Error resize_tensor_ptr(
220 TensorPtr& tensor,
221 const std::vector<exec_aten::SizesType>& sizes) {
222 return runtime::resize_tensor(
223 *tensor,
224 exec_aten::ArrayRef<exec_aten::SizesType>(sizes.data(), sizes.size()));
225 }
226
227 } // namespace extension
228 } // namespace executorch
229