/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/repeat_util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <string.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

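// Releases the heap-allocated metadata (sizes, dim_order, strides), the data
// buffer, and the TensorImpl of a tensor produced by broadcast_tensor() below.
// Illustrative pairing (variable names are hypothetical):
//   Tensor out = broadcast_tensor(from, to);
//   // ... use out ...
//   free_broadcast_tensor(out);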
void free_broadcast_tensor(const Tensor& broadcast_tensor) {
  free((void*)broadcast_tensor.const_data_ptr());
  free((void*)broadcast_tensor.sizes().data());
  free((void*)broadcast_tensor.dim_order().data());
  free((void*)broadcast_tensor.strides().data());
  free(broadcast_tensor.unsafeGetTensorImpl());
}

namespace {

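// Helper that heap-allocates every piece of a new tensor: the sizes,
// dim_order, and strides arrays, the TensorImpl, and the data buffer are each
// malloc'd, so the result must be released with free_broadcast_tensor() above.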
Tensor make_tensor(
    const ArrayRef<Tensor::SizesType>& sizes,
    const ArrayRef<Tensor::DimOrderType>& dim_order,
    const ArrayRef<Tensor::StridesType>& strides,
    const ScalarType& dtype) {
  int dim = sizes.size();
  int size_nbytes = dim * sizeof(Tensor::SizesType);
  void* size_data_ptr = malloc(size_nbytes);
  ET_CHECK_MSG(size_data_ptr != nullptr, "Failed to malloc for size bytes");
  memcpy(size_data_ptr, sizes.data(), size_nbytes);

  // TODO(T145322324): can we remove the static cast once size is unsigned?
  size_t dim_order_nbytes =
      static_cast<size_t>(dim) * sizeof(Tensor::DimOrderType);
  // This is leaking memory?
  // TODO(T147221312)
  void* dim_order_data_ptr = malloc(dim_order_nbytes);
  ET_CHECK_MSG(
      dim_order_data_ptr != nullptr, "Failed to malloc for dim order bytes");
  memcpy(dim_order_data_ptr, dim_order.data(), dim_order_nbytes);

  int strides_nbytes = dim * sizeof(Tensor::StridesType);
  void* strides_data_ptr = malloc(strides_nbytes);
  ET_CHECK_MSG(
      strides_data_ptr != nullptr, "Failed to malloc for strides bytes");
  memcpy(strides_data_ptr, strides.data(), strides_nbytes);

  auto tensor_impl = static_cast<TensorImpl*>(malloc(sizeof(TensorImpl)));
  ET_CHECK_MSG(tensor_impl != nullptr, "Failed to malloc for data TensorImpl");

  new (tensor_impl) TensorImpl(
      dtype,
      dim,
      reinterpret_cast<Tensor::SizesType*>(size_data_ptr),
      nullptr,
      reinterpret_cast<Tensor::DimOrderType*>(dim_order_data_ptr),
      reinterpret_cast<Tensor::StridesType*>(strides_data_ptr));

  void* data_ptr = malloc(tensor_impl->nbytes());
  ET_CHECK_MSG(data_ptr != nullptr, "Failed to malloc for data buffer");
  tensor_impl->set_data(data_ptr);

  return Tensor{tensor_impl};
}

} // namespace

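// Illustrative example (shapes are hypothetical): broadcast_from_shape {3, 1}
// is broadcastable to broadcast_to_shape {2, 3, 4}, since, aligning dims from
// the right, each broadcast_from dim either matches or is 1; {3, 2} is not.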
bool tensor_is_broadcastable_to(
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> broadcast_to_shape) {
  bool feasible_bcast = true;

  if (broadcast_to_shape.size() < broadcast_from_shape.size()) {
    return false;
  }

  for (int i = broadcast_to_shape.size() - 1,
           j = broadcast_from_shape.size() - 1;
       j >= 0;
       --i, --j) {
    auto broadcast_to_s = broadcast_to_shape[i],
         broadcast_from_s = broadcast_from_shape[j];
    feasible_bcast &=
        broadcast_to_s == broadcast_from_s || broadcast_from_s == 1;
    if (!feasible_bcast) {
      return false;
    }
  }

  return feasible_bcast;
}

bool tensor_is_broadcastable_to(
    const Tensor& broadcast_from,
    const Tensor& broadcast_to) {
  return tensor_is_broadcastable_to(
      broadcast_from.sizes(), broadcast_to.sizes());
}

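// Illustrative example (shapes are hypothetical): {2, 1, 4} and {3, 1} are
// broadcastable between each other (a dim of 1 on either side is accepted),
// while {2, 3} and {4, 3} are not.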
bool tensors_are_broadcastable_between(
    const exec_aten::ArrayRef<Tensor::SizesType> a_shape,
    const exec_aten::ArrayRef<Tensor::SizesType> b_shape) {
  auto a_dim = a_shape.size();
  auto b_dim = b_shape.size();

  // Although the documentation (https://fburl.com/n9wl4d0o) says that a 0-dim
  // tensor cannot be broadcast, experiments show that it actually can
  // (https://www.internalfb.com/intern/px/p/2pMT0), so we do not check the
  // dimension here.

  for (int a_index = a_dim - 1, b_index = b_dim - 1;
       a_index >= 0 && b_index >= 0;
       a_index--, b_index--) {
    if (a_shape[a_index] == b_shape[b_index] || a_shape[a_index] == 1 ||
        b_shape[b_index] == 1) {
      continue;
    }
    return false;
  }

  return true;
}

bool tensors_are_broadcastable_between(const Tensor& a, const Tensor& b) {
  return tensors_are_broadcastable_between(a.sizes(), b.sizes());
}

// Broadcast tensor broadcast_from to match broadcast_to's shape, and return
// the broadcasted tensor.
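// Usage sketch (variable names are illustrative): callers should check
// tensor_is_broadcastable_to(from, to) beforehand, since the ET_CHECKs below
// abort on failure, and must release the result with free_broadcast_tensor():
//   Tensor expanded = broadcast_tensor(from, to);
//   // ... use expanded ...
//   free_broadcast_tensor(expanded);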
Tensor broadcast_tensor(
    const Tensor& broadcast_from,
    const Tensor& broadcast_to) {
  auto broadcast_to_shape = broadcast_to.sizes();
  auto broadcast_from_shape = broadcast_from.sizes();
  auto broadcast_to_dim_order = broadcast_to.dim_order();
  auto broadcast_to_strides = broadcast_to.strides();

  // First check if broadcast_from is broadcastable to broadcast_to.
  // Essentially, we can broadcast broadcast_from if it meets three conditions
  // along any dimension i: (1) broadcast_to[i] = broadcast_from[i]; (2)
  // broadcast_from[i] = 1; or (3) broadcast_from[i] does not exist.
  // For torch.tensor(11), the dim is 0, so we can't rely on *.sizes().empty()
  // alone for this check.
  ET_CHECK_MSG(
      broadcast_from.numel() != 0 || !(broadcast_from).sizes().empty(),
      "Input tensor must be non-empty");
  // There would never be a broadcast_to with only 1 element, so we check dim
  // here.
  ET_CHECK_MSG(
      !(broadcast_to).sizes().empty(), "Input tensor must be non-empty");
  ET_CHECK_MSG(
      broadcast_to_shape.size() >= broadcast_from_shape.size(),
      "For broadcast, tensor broadcast_to must have at least as many dimensions as tensor broadcast_from");

  bool feasible_bcast =
      tensor_is_broadcastable_to(broadcast_from, broadcast_to);

  ET_CHECK_MSG(
      feasible_bcast,
      "Cannot broadcast tensor broadcast_from into tensor broadcast_to along some dimensions");

  // Once we have discovered that broadcast_from can be broadcasted into
  // broadcast_to, use repeat() to do the broadcast.
  Tensor out = make_tensor(
      broadcast_to_shape,
      broadcast_to_dim_order,
      broadcast_to_strides,
      broadcast_from.scalar_type());

  // We need to pass IntArrayRef (i.e. ArrayRef<int64_t>) to cpu::repeat() but
  // .sizes() is ArrayRef<int32_t>
  using T = IntArrayRef::value_type;
  auto ndim = broadcast_to.dim();

  // repeat is int64_t* but broadcast_to_shape is at::ArrayRef<int32_t>
  T* repeats = static_cast<T*>(malloc(ndim * sizeof(T)));
  ET_CHECK_MSG(repeats != nullptr, "Failed to malloc for repeats");
  for (int i = 0; i < ndim; ++i) {
    repeats[i] = broadcast_to_shape[i];
  }

  // Compute the repeat factor along each dimension
  for (int i = broadcast_to_shape.size() - 1,
           j = broadcast_from_shape.size() - 1;
       j >= 0;
       --i, --j) {
    if (broadcast_to_shape[i] == broadcast_from_shape[j]) {
      repeats[i] = 1;
    }
  }

  ET_CHECK(
      repeat_tensor(broadcast_from, makeArrayRef(repeats, ndim), out) ==
      Error::Ok);

  free(repeats);

  return out;
}

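// Worked example (shapes are hypothetical): for a_size {2, 1, 3} and
// b_size {5, 1}, out_sizes is filled with {2, 5, 3} and *out_dim is set to 3.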
ET_NODISCARD Error get_broadcast_target_size(
    const exec_aten::ArrayRef<Tensor::SizesType> a_size,
    const exec_aten::ArrayRef<Tensor::SizesType> b_size,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim) {
  ET_CHECK_OR_RETURN_ERROR(
      tensors_are_broadcastable_between(a_size, b_size),
      InvalidArgument,
      "Two input tensors should be broadcastable.\n");

  auto a_dim = a_size.size();
  auto b_dim = b_size.size();

  ET_CHECK_OR_RETURN_ERROR(
      a_dim <= out_sizes_len && b_dim <= out_sizes_len,
      InvalidArgument,
      "Dim of input tensors should not exceed the limit, but got %zu, %zu with limit %zu.",
      a_dim,
      b_dim,
      out_sizes_len);

  *out_dim = a_dim > b_dim ? a_dim : b_dim;

  for (int a_idx = a_dim - 1,
           b_idx = b_dim - 1,
           expected_target_idx = *out_dim - 1;
       expected_target_idx >= 0;
       a_idx--, b_idx--, expected_target_idx--) {
    if (a_idx >= 0 && b_idx >= 0) {
      out_sizes[expected_target_idx] =
          b_size[b_idx] == 1 ? a_size[a_idx] : b_size[b_idx];
    } else {
      out_sizes[expected_target_idx] =
          a_idx >= 0 ? a_size[a_idx] : b_size[b_idx];
    }
  }

  return Error::Ok;
}

ET_NODISCARD Error get_broadcast_target_size(
    const Tensor& a,
    const Tensor& b,
    Tensor::SizesType* out_sizes,
    const size_t out_sizes_len,
    size_t* out_dim) {
  return get_broadcast_target_size(
      a.sizes(), b.sizes(), out_sizes, out_sizes_len, out_dim);
}

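// Worked example (values are hypothetical): for shape {2, 3, 4} and
// linear_index 17, out_indexes becomes {1, 1, 1}, since
// 17 = 1 * (3 * 4) + 1 * 4 + 1.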
void delinearize_index(
    size_t linear_index,
    exec_aten::ArrayRef<Tensor::SizesType> shape,
    size_t* out_indexes,
    const size_t out_indexes_len) {
  ET_CHECK(shape.size() <= out_indexes_len);
  for (size_t i = 0; i < shape.size(); ++i) {
    auto dim = shape.size() - 1 - i;
    auto dim_size = shape[dim];
    out_indexes[dim] = linear_index % dim_size;
    linear_index /= dim_size;
  }
}

void delinearize_index(
    size_t linear_index,
    const Tensor& t,
    size_t* out_indexes,
    const size_t out_indexes_len) {
  delinearize_index(linear_index, t.sizes(), out_indexes, out_indexes_len);
}

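// Worked example (values are hypothetical): with broadcast_to_ndim = 2,
// broadcast_from_shape {3}, and broadcast_from_strides {1}, the index {1, 2}
// into the broadcast result maps to linear index 2 in broadcast_from: the
// leading broadcasted dimension is skipped, and broadcasted dims contribute 0.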
size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    exec_aten::ArrayRef<Tensor::SizesType> broadcast_from_shape,
    exec_aten::ArrayRef<Tensor::StridesType> broadcast_from_strides) {
  size_t num_skip_dims = broadcast_to_ndim - broadcast_from_shape.size();
  ArrayRef<size_t> indexes_broadcast_from = indexes_broadcast_to.slice(
      num_skip_dims, broadcast_to_ndim - num_skip_dims);

  ET_CHECK(indexes_broadcast_from.size() == broadcast_from_shape.size());

  size_t linear_index = 0;
  for (size_t i = 0; i < indexes_broadcast_from.size(); ++i) {
    // If this dimension is broadcasted, add zero to the linear address.
    if (indexes_broadcast_from[i] >= broadcast_from_shape[i]) {
      ET_CHECK_MSG(
          broadcast_from_shape[i] == 1,
          "Expected dim size == 1 if broadcasted, but actual dim size is %zu",
          static_cast<size_t>(broadcast_from_shape[i]));
      continue;
    }
    linear_index += indexes_broadcast_from[i] * broadcast_from_strides[i];
  }
  return linear_index;
}

size_t linearize_access_indexes(
    ArrayRef<size_t> indexes_broadcast_to,
    ssize_t broadcast_to_ndim,
    const Tensor& broadcast_from) {
  return linearize_access_indexes(
      indexes_broadcast_to,
      broadcast_to_ndim,
      broadcast_from.sizes(),
      broadcast_from.strides());
}

} // namespace executor
} // namespace torch