/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/repeat_util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <stdlib.h> // for malloc/free
#include <string.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

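// Frees the data buffer and metadata (sizes, dim_order, strides) of a tensor
// created by broadcast_tensor() below; each pointer was allocated with a
// separate malloc() in make_tensor(), so each is freed individually.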
void free_broadcast_tensor(const Tensor& broadcast_tensor) {
  free((void*)broadcast_tensor.const_data_ptr());
  free((void*)broadcast_tensor.sizes().data());
  free((void*)broadcast_tensor.dim_order().data());
  free((void*)broadcast_tensor.strides().data());
  free(broadcast_tensor.unsafeGetTensorImpl());
}

namespace {

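// Builds a Tensor whose metadata (sizes, dim_order, strides), data buffer, and
// TensorImpl are each individually malloc'd. Tensors produced here are meant
// to be released with free_broadcast_tensor() above.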
Tensor make_tensor(
    const ArrayRef<Tensor::SizesType>& sizes,
    const ArrayRef<Tensor::DimOrderType>& dim_order,
    const ArrayRef<Tensor::StridesType>& strides,
    const ScalarType& dtype) {
  int dim = sizes.size();
  int size_nbytes = dim * sizeof(Tensor::SizesType);
  void* size_data_ptr = malloc(size_nbytes);
  ET_CHECK_MSG(size_data_ptr != nullptr, "Failed to malloc for size bytes");
  memcpy(size_data_ptr, sizes.data(), size_nbytes);

  // TODO(T145322324): can we remove the static cast once size is unsigned?
  size_t dim_order_nbytes =
      static_cast<size_t>(dim) * sizeof(Tensor::DimOrderType);
  // This is leaking memory?
  // TODO(T147221312)
  void* dim_order_data_ptr = malloc(dim_order_nbytes);
  ET_CHECK_MSG(
      dim_order_data_ptr != nullptr, "Failed to malloc for dim order bytes");
  memcpy(dim_order_data_ptr, dim_order.data(), dim_order_nbytes);

  int strides_nbytes = dim * sizeof(Tensor::StridesType);
  void* strides_data_ptr = malloc(strides_nbytes);
  ET_CHECK_MSG(
      strides_data_ptr != nullptr, "Failed to malloc for strides bytes");
  memcpy(strides_data_ptr, strides.data(), strides_nbytes);

  auto tensor_impl = static_cast<TensorImpl*>(malloc(sizeof(TensorImpl)));
  ET_CHECK_MSG(tensor_impl != nullptr, "Failed to malloc for data TensorImpl");

  new (tensor_impl) TensorImpl(
      dtype,
      dim,
      reinterpret_cast<Tensor::SizesType*>(size_data_ptr),
      nullptr,
      reinterpret_cast<Tensor::DimOrderType*>(dim_order_data_ptr),
      reinterpret_cast<Tensor::StridesType*>(strides_data_ptr));

  void* data_ptr = malloc(tensor_impl->nbytes());
  ET_CHECK_MSG(data_ptr != nullptr, "Failed to malloc for data buffer");
  tensor_impl->set_data(data_ptr);

  return Tensor{tensor_impl};
}

} // namespace

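// Shape-level broadcastability check following standard broadcasting rules:
// walking both shapes from the last dimension backwards, each pair of sizes
// must either match or the source size must be 1; missing leading dimensions
// in broadcast_from_shape are implicitly size 1. For example, [3, 1] is
// broadcastable to [2, 3, 4], but [3, 2] is not.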
bool tensor_is_broadcastable_to(
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape) {
  bool feasible_bcast = true;

  if (broadcast_to_shape.size() < broadcast_from_shape.size()) {
    return false;
  }

  for (int i = broadcast_to_shape.size() - 1,
           j = broadcast_from_shape.size() - 1;
       j >= 0;
       --i, --j) {
    auto broadcast_to_s = broadcast_to_shape[i],
         broadcast_from_s = broadcast_from_shape[j];
    feasible_bcast &=
        broadcast_to_s == broadcast_from_s || broadcast_from_s == 1;
    if (!feasible_bcast) {
      return false;
    }
  }

  return feasible_bcast;
}

bool tensor_is_broadcastable_to(
    const Tensor& broadcast_from,
    const Tensor& broadcast_to) {
  return tensor_is_broadcastable_to(
      broadcast_from.sizes(), broadcast_to.sizes());
}

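// Symmetric check: two shapes can be broadcast against each other when, for
// every trailing dimension pair, the sizes match or either one is 1. For
// example, [2, 1, 4] and [3, 1] are broadcastable between each other
// (yielding [2, 3, 4]), while [2, 3] and [4] are not.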
bool tensors_are_broadcastable_between(
    const exec_aten::ArrayRef<Tensor::SizesType> a_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> b_shape) {
  auto a_dim = a_shape.size();
  auto b_dim = b_shape.size();

  // Although the documentation (https://fburl.com/n9wl4d0o) says that a 0-dim
  // tensor cannot be broadcast, experiments show that it actually can
  // (https://www.internalfb.com/intern/px/p/2pMT0), so we do not check the
  // number of dimensions here.

  for (int a_index = a_dim - 1, b_index = b_dim - 1;
       a_index >= 0 && b_index >= 0;
       a_index--, b_index--) {
    if (a_shape[a_index] == b_shape[b_index] || a_shape[a_index] == 1 ||
        b_shape[b_index] == 1) {
      continue;
    }
    return false;
  }

  return true;
}

bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b) {
  return tensors_are_broadcastable_between(a.sizes(), b.sizes());
}

// Broadcast tensor broadcast_from to match broadcast_to's shape, and return
// the broadcasted tensor.
Tensor broadcast_tensor(
    const Tensor& broadcast_from,
    const Tensor& broadcast_to) {
  auto broadcast_to_shape = broadcast_to.sizes();
  auto broadcast_from_shape = broadcast_from.sizes();
  auto broadcast_to_dim_order = broadcast_to.dim_order();
  auto broadcast_to_strides = broadcast_to.strides();

  // First check whether broadcast_from is broadcastable to broadcast_to.
  // broadcast_from can be broadcast if, for every dimension i, at least one of
  // the following holds: (1) broadcast_to[i] == broadcast_from[i];
  // (2) broadcast_from[i] == 1; or (3) dimension i does not exist in
  // broadcast_from.
  // Note: a 0-dim tensor such as torch.tensor(11) has an empty sizes() but is
  // not empty, so sizes().empty() alone cannot be used for the emptiness check
  // below.
  ET_CHECK_MSG(
      broadcast_from.numel() != 0 || !(broadcast_from).sizes().empty(),
      "Input tensor must be non-empty");
  // broadcast_to should never be a 0-dim (single-element) tensor, so checking
  // that it has at least one dimension is sufficient here.
  ET_CHECK_MSG(
      !(broadcast_to).sizes().empty(), "Input tensor must be non-empty");
  ET_CHECK_MSG(
      broadcast_to_shape.size() >= broadcast_from_shape.size(),
      "For broadcast, tensor broadcast_to must be at least as high dimensional as tensor broadcast_from");

  bool feasible_bcast =
      tensor_is_broadcastable_to(broadcast_from, broadcast_to);

  ET_CHECK_MSG(
      feasible_bcast,
      "Cannot broadcast tensor broadcast_from into tensor broadcast_to along some dimensions");

  // Once we have established that broadcast_from can be broadcast into
  // broadcast_to, use repeat_tensor() to do the broadcast.
  Tensor out = make_tensor(
      broadcast_to_shape,
      broadcast_to_dim_order,
      broadcast_to_strides,
      broadcast_from.scalar_type());

  // We need to pass IntArrayRef (i.e. ArrayRef<int64_t>) to repeat_tensor(),
  // but .sizes() is ArrayRef<int32_t>.
  using T = IntArrayRef::value_type;
  auto ndim = broadcast_to.dim();

  // repeats must hold int64_t while broadcast_to_shape is ArrayRef<int32_t>,
  // so copy the target sizes element by element.
  T* repeats = static_cast<T*>(malloc(ndim * sizeof(T)));
  ET_CHECK_MSG(repeats != nullptr, "Failed to malloc for repeats");
  for (int i = 0; i < ndim; ++i) {
    repeats[i] = broadcast_to_shape[i];
  }

  // Compute the repeat factor along each dimension.
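  // For example, broadcasting a [3, 1] tensor to shape [2, 3, 4] starts with
  // repeats = [2, 3, 4] and then resets the matching dimension below, giving
  // repeats = [2, 1, 4].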
  for (int i = broadcast_to_shape.size() - 1,
           j = broadcast_from_shape.size() - 1;
       j >= 0;
       --i, --j) {
    if (broadcast_to_shape[i] == broadcast_from_shape[j]) {
      repeats[i] = 1;
    }
  }

  ET_CHECK(
      repeat_tensor(broadcast_from, makeArrayRef(repeats, ndim), out) ==
      Error::Ok);

  free(repeats);

  return out;
}

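// Computes the shape produced by broadcasting a_size and b_size against each
// other, writing it into out_sizes (capacity out_sizes_len) and its rank into
// *out_dim. For example, a_size [2, 1, 3] and b_size [5, 1] yield out_sizes
// [2, 5, 3] and *out_dim == 3.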
ET_NODISCARD Error get_broadcast_target_size(
    const exec_aten::ArrayRef<Tensor::SizesType> a_size,
    const exec_aten::ArrayRef<Tensor::SizesType> b_size,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim) {
  ET_CHECK_OR_RETURN_ERROR(
      tensors_are_broadcastable_between(a_size, b_size),
      InvalidArgument,
      "Two input tensors should be broadcastable.\n");

  auto a_dim = a_size.size();
  auto b_dim = b_size.size();

  ET_CHECK_OR_RETURN_ERROR(
      a_dim <= out_sizes_len && b_dim <= out_sizes_len,
      InvalidArgument,
      "Input tensor dims (%zu and %zu) should not exceed out_sizes_len (%zu).",
      a_dim,
      b_dim,
      out_sizes_len);

  *out_dim = a_dim > b_dim ? a_dim : b_dim;

  for (int a_idx = a_dim - 1,
           b_idx = b_dim - 1,
           expected_target_idx = *out_dim - 1;
       expected_target_idx >= 0;
       a_idx--, b_idx--, expected_target_idx--) {
    if (a_idx >= 0 && b_idx >= 0) {
      out_sizes[expected_target_idx] =
          b_size[b_idx] == 1 ? a_size[a_idx] : b_size[b_idx];
    } else {
      out_sizes[expected_target_idx] =
          a_idx >= 0 ? a_size[a_idx] : b_size[b_idx];
    }
  }

  return Error::Ok;
}

ET_NODISCARD Error get_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim) {
  return get_broadcast_target_size(
      a.sizes(), b.sizes(), out_sizes, out_sizes_len, out_dim);
}

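// Converts a flat index into per-dimension indexes for the given shape,
// assuming a contiguous row-major layout. For example, with shape [2, 3],
// linear_index 5 maps to out_indexes [1, 2].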
void delinearize_index(
    size_t linear_index,
    exec_aten::ArrayRef<Tensor::SizesType> shape,
    size_t* out_indexes,
    const size_t out_indexes_len) {
  ET_CHECK(shape.size() <= out_indexes_len);
  for (size_t i = 0; i < shape.size(); ++i) {
    auto dim = shape.size() - 1 - i;
    auto dim_size = shape[dim];
    out_indexes[dim] = linear_index % dim_size;
    linear_index /= dim_size;
  }
}

void delinearize_index(
    size_t linear_index,
    const Tensor& t,
    size_t* out_indexes,
    const size_t out_indexes_len) {
  delinearize_index(linear_index, t.sizes(), out_indexes, out_indexes_len);
}

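// Maps per-dimension indexes into the broadcasted (output) tensor back to a
// linear element index into broadcast_from, skipping the leading dimensions
// that broadcast_from lacks and contributing nothing for broadcasted (size-1)
// dimensions. For example, with broadcast_to_ndim == 3, broadcast_from_shape
// [3, 1], and broadcast_from_strides [1, 1], output indexes [1, 2, 0] map to
// linear index 2.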
size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides) {
  size_t num_skip_dims = broadcast_to_ndim - broadcast_from_shape.size();
  ArrayRef<size_t> indexes_broadcast_from = indexes_broadcast_to.slice(
      num_skip_dims, broadcast_to_ndim - num_skip_dims);

  ET_CHECK(indexes_broadcast_from.size() == broadcast_from_shape.size());

  size_t linear_index = 0;
  for (size_t i = 0; i < indexes_broadcast_from.size(); ++i) {
    // If this dimension is broadcasted, add zero to the linear address.
    if (indexes_broadcast_from[i] >= broadcast_from_shape[i]) {
      ET_CHECK_MSG(
          broadcast_from_shape[i] == 1,
          "Expected dim size == 1 if broadcasted, but actual dim size is %zu",
          static_cast<size_t>(broadcast_from_shape[i]));
      continue;
    }
    linear_index += indexes_broadcast_from[i] * broadcast_from_strides[i];
  }
  return linear_index;
}

size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    const Tensor& broadcast_from) {
  return linearize_access_indexes(
      indexes_broadcast_to,
      broadcast_to_ndim,
      broadcast_from.sizes(),
      broadcast_from.strides());
}

} // namespace executor
} // namespace torch