/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/assert.h>
#include <cstring>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

//
// Helper Functions
//

// Normalize the dimension by adding in_dim if d < 0; for 0-D, clamp to 0
inline size_t _normalize_non_neg_d(ssize_t d, ssize_t in_dim) {
  if (in_dim == 0 && (d == 0 || d == -1)) {
    return 0;
  }
  if (d < 0) {
    return d + in_dim;
  }
  return d;
}
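// For example, with a 4-D input (in_dim == 4), _normalize_non_neg_d(-1, 4)
// returns 3 and _normalize_non_neg_d(2, 4) returns 2; for a 0-D input, both
// d == 0 and d == -1 normalize to 0.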

ET_NODISCARD bool check_dim_list_is_valid(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list) {
  if (dim_list.has_value() && dim_list.value().size() != 0) {
    const auto& reduce_dims = dim_list.value();
    bool dim_exist[kTensorDimensionLimit];
    memset(dim_exist, false, sizeof(dim_exist));
    for (const auto& d : reduce_dims) {
      if (in.dim() == 0) {
        ET_LOG_AND_RETURN_IF_FALSE(d == 0 || d == -1);
      } else {
        ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(d, in.dim()));
      }

      const size_t non_neg_d = _normalize_non_neg_d(d, in.dim());
      ET_LOG_AND_RETURN_IF_FALSE(
          non_neg_d < kTensorDimensionLimit && non_neg_d >= 0);

      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          dim_exist[non_neg_d] == false,
          "dim %zu appears multiple times in the list of dims",
          non_neg_d);
      dim_exist[non_neg_d] = true;
    }
  }

  return true;
}
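// For example, with a 3-D input, dim_list {0, -1} is valid (it normalizes to
// {0, 2}), while {0, -3} is rejected because -3 also normalizes to 0 and no
// dim may appear more than once.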

bool check_dim_in_dim_list(
    const size_t dim,
    const size_t max_dim,
    const exec_aten::ArrayRef<int64_t>& dim_list) {
  for (const auto& d : dim_list) {
    const size_t non_neg_dim = _normalize_non_neg_d(d, max_dim);
    if (dim == non_neg_dim) {
      return true;
    }
  }
  return false;
}
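// For example, with max_dim == 3 and dim_list {-1, 0},
// check_dim_in_dim_list(2, 3, dim_list) is true (-1 normalizes to 2), while
// check_dim_in_dim_list(1, 3, dim_list) is false.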

/**
 * Returns the product of the sizes of all reduction dims.
 */
size_t get_reduced_dim_product(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim) {
  if (in.dim() == 0) {
    return 1;
  }
  size_t dim_product = 1;
  if (!dim.has_value()) {
    for (size_t i = 0; i < in.dim(); ++i) {
      dim_product *= in.size(i);
    }
    return dim_product;
  }
  const size_t d = _normalize_non_neg_d(dim.value(), in.dim());
  return in.size(d);
}

/**
 * Returns the product of the sizes of all reduction dims.
 */
size_t get_reduced_dim_product(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list) {
  if (in.dim() == 0) {
    return 1;
  }
  size_t dim_product = 1;
  const size_t in_dim = in.dim();
  if (!dim_list.has_value() || dim_list.value().size() == 0) {
    for (size_t i = 0; i < in.dim(); ++i) {
      dim_product *= in.size(i);
    }
    return dim_product;
  }
  for (const auto& d : dim_list.value()) {
    const size_t non_neg_d = _normalize_non_neg_d(d, in_dim);
    dim_product *= in.size(non_neg_d);
  }
  return dim_product;
}
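// For example, for an input of shape {2, 3, 4}, dim_list {0, 2} gives
// 2 * 4 == 8, and a null or empty dim_list (reduce over everything) gives
// 2 * 3 * 4 == 24.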

/**
 * Returns the number of elements of the output of reducing `in`
 * over `dim`.
 */
size_t get_out_numel(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim) {
  size_t out_numel = 1;
  if (dim.has_value()) {
    const auto dim_val = dim.value();
    if (in.dim() == 0) {
      ET_CHECK(dim_val == 0 || dim_val == -1);
    } else {
      ET_CHECK_VALID_DIM(dim_val, in.dim());
    }
    const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in.dim());
    for (size_t d = 0; d < in.dim(); ++d) {
      if (d != non_neg_dim) {
        out_numel *= in.size(d);
      }
    }
  }
  return out_numel;
}

/**
 * Returns the number of elements of the output of reducing `in`
 * over `dim_list`.
 */
size_t get_out_numel(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list) {
  size_t out_numel = 1;
  if (dim_list.has_value() && dim_list.value().size() != 0) {
    for (size_t d = 0; d < in.dim(); ++d) {
      if (!check_dim_in_dim_list(d, in.dim(), dim_list.value())) {
        out_numel *= in.size(d);
      }
    }
  }
  return out_numel;
}
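// For example, for an input of shape {2, 3, 4}, dim_list {0, 2} leaves only
// dim 1, so the output has 3 elements; a null or empty dim_list is a full
// reduction, which produces a single element.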

/**
 * Returns the index of the first element in `in` that maps to `out_ix` when
 * reducing over `dim`. If `dim` is empty, returns `0`.
 */
size_t get_init_index(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    const size_t out_ix) {
  if (!dim.has_value()) {
    return 0;
  }
  const auto dim_val = dim.value();
  if (in.dim() == 0) {
    ET_CHECK(dim_val == 0 || dim_val == -1);
  } else {
    ET_CHECK_VALID_DIM(dim_val, in.dim());
  }
  const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in.dim());
  size_t init_ix = 0;
  size_t mutable_out_ix = out_ix;
  auto strides = in.strides();
  for (int64_t d = in.dim() - 1; d >= 0; d--) {
    if (d != non_neg_dim) {
      init_ix += (mutable_out_ix % in.size(d)) * strides[d];
      mutable_out_ix /= in.size(d);
    }
  }
  return init_ix;
}
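// For example, for a contiguous input of shape {2, 3, 4} (strides {12, 4, 1})
// reduced over dim == 1, out_ix == 5 corresponds to coordinates (1, 1) in the
// non-reduced dims, so the first mapped element is in[1][0][1] and this
// returns 1 * 12 + 1 * 1 == 13.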

/**
 * Returns the index of the first element in `in` that maps to `out_ix` when
 * reducing over the list of dimensions in `dim_list`. If `dim_list` is null
 * or empty, returns `0`.
 */
size_t get_init_index(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    const size_t out_ix) {
  if (!dim_list.has_value() || dim_list.value().size() == 0) {
    return 0;
  }
  size_t init_ix = 0;
  size_t mutable_out_ix = out_ix;
  auto strides = in.strides();
  for (int64_t d = in.dim() - 1; d >= 0; d--) {
    if (!check_dim_in_dim_list(d, in.dim(), dim_list.value())) {
      init_ix += (mutable_out_ix % in.size(d)) * strides[d];
      mutable_out_ix /= in.size(d);
    }
  }
  return init_ix;
}

//
// Resize out tensor of reduction op
//

size_t compute_reduced_out_size(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    exec_aten::SizesType* sizes_arr) {
  const auto in_dim = in.dim();
  size_t out_dim = in_dim;

  if (dim.has_value()) {
    const auto dim_val = dim.value();
    const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in_dim);
    for (ssize_t i = 0; i < non_neg_dim; ++i) {
      sizes_arr[i] = in.size(i);
    }
    if (keepdim) {
      sizes_arr[non_neg_dim] = 1;
      for (ssize_t i = non_neg_dim + 1; i < in_dim; ++i) {
        sizes_arr[i] = in.size(i);
      }
    } else {
      for (ssize_t i = non_neg_dim; i < in_dim - 1; ++i) {
        sizes_arr[i] = in.size(i + 1);
      }
      out_dim = in_dim == 0 ? 0 : in_dim - 1;
    }
  } else {
    if (keepdim) {
      for (size_t i = 0; i < in_dim; ++i) {
        sizes_arr[i] = 1;
      }
    } else {
      out_dim = 0;
    }
  }
  return out_dim;
}
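// For example, for an input of shape {2, 3, 4} reduced over dim == 1,
// keepdim == false writes {2, 4} and returns 2, while keepdim == true writes
// {2, 1, 4} and returns 3; with no dim and keepdim == false the output is
// 0-D (a scalar).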

size_t compute_reduced_out_size(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    exec_aten::SizesType* sizes_arr) {
  const auto in_dim = in.dim();
  size_t out_dim = in_dim;

  if (dim_list.has_value() && dim_list.value().size() != 0) {
    const auto& reduce_dims = dim_list.value();
    if (keepdim) {
      for (size_t i = 0; i < in_dim; ++i) {
        if (check_dim_in_dim_list(i, in_dim, reduce_dims)) {
          sizes_arr[i] = 1;
        } else {
          sizes_arr[i] = in.size(i);
        }
      }
    } else {
      size_t out_i = 0;
      for (size_t in_i = 0; in_i < in_dim; ++in_i) {
        if (!check_dim_in_dim_list(in_i, in_dim, reduce_dims)) {
          sizes_arr[out_i] = in.size(in_i);
          out_i++;
        }
      }
      out_dim = out_i;
    }
  } else {
    if (keepdim) {
      for (size_t i = 0; i < in_dim; ++i) {
        sizes_arr[i] = 1;
      }
    } else {
      out_dim = 0;
    }
  }
  return out_dim;
}
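// For example, for an input of shape {2, 3, 4} and dim_list {0, 2},
// keepdim == false writes {3} and returns 1, while keepdim == true writes
// {1, 3, 1} and returns 3.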

Error resize_reduction_out(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    Tensor& out) {
  exec_aten::SizesType sizes_arr[kTensorDimensionLimit];
  const auto out_dim = compute_reduced_out_size(in, dim, keepdim, sizes_arr);
  exec_aten::ArrayRef<exec_aten::SizesType> out_size{
      sizes_arr, static_cast<size_t>(out_dim)};
  return resize_tensor(out, out_size);
}

Error resize_reduction_out(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    Tensor& out) {
  exec_aten::SizesType sizes_arr[kTensorDimensionLimit];
  const auto out_dim =
      compute_reduced_out_size(in, dim_list, keepdim, sizes_arr);
  exec_aten::ArrayRef<exec_aten::SizesType> out_size{
      sizes_arr, static_cast<size_t>(out_dim)};
  return resize_tensor(out, out_size);
}
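// A typical call site in a portable reduction kernel checks the resize result
// before reducing. This is a sketch only; the surrounding kernel boilerplate
// (ctx, in, dim_list, keepdim, out) is assumed:
//
//   ET_KERNEL_CHECK(
//       ctx,
//       resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok,
//       InvalidArgument,
//       out);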

#ifndef USE_ATEN_LIB

/**
 * Check the validity of arguments for reduction operators.
 */
bool check_reduction_args(
    const Tensor& in,
    const optional<ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out) {
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  }
  ET_LOG_AND_RETURN_IF_FALSE(check_dim_list_is_valid(in, dim_list));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  return true;
}

/**
 * Check the validity of arguments for reduction operators that take
 * a single dimension argument.
 */
bool check_reduction_args_single_dim(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out,
    bool allow_empty_dim) {
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  }
  if (in.dim() == 0) {
    if (dim.has_value()) {
      ET_LOG_AND_RETURN_IF_FALSE(dim.value() == 0 || dim.value() == -1);
    }
    return true;
  }

  if (dim.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim.value(), in.dim()));
    if (!allow_empty_dim) {
      ET_LOG_AND_RETURN_IF_FALSE(tensor_has_non_empty_dim(in, dim.value()));
    }
  }

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  return true;
}

bool check_mean_dim_args(
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args(in, dim_list, keepdim, dtype, out));

  if (dtype) {
    ET_LOG_AND_RETURN_IF_FALSE(torch::executor::isFloatingType(dtype.value()));
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
  } else {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_floating_type(in));
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_floating_type(out));
  }

  return true;
}

bool check_amin_amax_args(
    const Tensor& in,
    ArrayRef<int64_t> dim_list,
    bool keepdim,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args(in, dim_list, keepdim, {}, out));
  ET_LOG_AND_RETURN_IF_FALSE(in.scalar_type() == out.scalar_type());

  return true;
}

bool check_argmin_argmax_args(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args_single_dim(in, dim, keepdim, {}, out));

  ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == ScalarType::Long);

  return true;
}

bool check_min_max_args(
    const Tensor& in,
    int64_t dim,
    bool keepdim,
    Tensor& max,
    Tensor& max_indices) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args_single_dim(in, dim, keepdim, {}, max));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, max));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape(max, max_indices));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_is_default_or_channels_last_dim_order(max_indices));
  ET_LOG_AND_RETURN_IF_FALSE(max_indices.scalar_type() == ScalarType::Long);

  return true;
}

bool check_prod_out_args(
    const Tensor& in,
    optional<ScalarType> dtype,
    Tensor& out) {
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  } else if (isIntegralType(in.scalar_type(), /*includeBool*/ true)) {
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == ScalarType::Long);
  } else {
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == in.scalar_type());
  }

  return true;
}
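// For example, with no dtype override, an integral (or bool) input must be
// reduced into a Long (int64) out tensor, matching ATen's type promotion for
// prod, while any other input dtype requires out to have the same dtype as
// the input.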

#endif

} // namespace executor
} // namespace torch