/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/assert.h>
#include <cstring>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

//
// Helper Functions
//

// Normalize the dimension by adding in_dim if d < 0; for 0-D, clamp to 0
inline size_t _normalize_non_neg_d(ssize_t d, ssize_t in_dim) {
  if (in_dim == 0 && (d == 0 || d == -1)) {
    return 0;
  }
  if (d < 0) {
    return d + in_dim;
  }
  return d;
}

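// Example: _normalize_non_neg_d(-1, 4) == 3 and _normalize_non_neg_d(2, 4) == 2;
// for a 0-D input (in_dim == 0), both 0 and -1 normalize to 0.
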
ET_NODISCARD bool check_dim_list_is_valid(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list) {
  if (dim_list.has_value() && dim_list.value().size() != 0) {
    const auto& reduce_dims = dim_list.value();
    bool dim_exist[kTensorDimensionLimit];
    memset(dim_exist, false, sizeof(dim_exist));
    for (const auto& d : reduce_dims) {
      if (in.dim() == 0) {
        ET_LOG_AND_RETURN_IF_FALSE(d == 0 || d == -1);
      } else {
        ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(d, in.dim()));
      }

      const size_t non_neg_d = _normalize_non_neg_d(d, in.dim());
      ET_LOG_AND_RETURN_IF_FALSE(
          non_neg_d < kTensorDimensionLimit && non_neg_d >= 0);

      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          dim_exist[non_neg_d] == false,
          "dim %zd appears multiple times in the list of dims",
          non_neg_d);
      dim_exist[non_neg_d] = true;
    }
  }

  return true;
}

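// Example: for a 3-D input, a dim_list of {0, -1} is valid (it normalizes to
// {0, 2}), while {1, -2} is rejected because both entries normalize to dim 1.
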
bool check_dim_in_dim_list(
    const size_t dim,
    const size_t max_dim,
    const exec_aten::ArrayRef<int64_t>& dim_list) {
  for (const auto& d : dim_list) {
    const size_t non_neg_dim = _normalize_non_neg_d(d, max_dim);
    if (dim == non_neg_dim) {
      return true;
    }
  }
  return false;
}

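// Example: with max_dim = 3 and a dim_list containing {0, -1}, dim 2 is
// reported as present, since -1 normalizes to 2.
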
/**
 * Returns the product of the sizes of all reduction dims.
 */
size_t get_reduced_dim_product(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim) {
  if (in.dim() == 0) {
    return 1;
  }
  size_t dim_product = 1;
  if (!dim.has_value()) {
    for (size_t i = 0; i < in.dim(); ++i) {
      dim_product *= in.size(i);
    }
    return dim_product;
  }
  const size_t d = _normalize_non_neg_d(dim.value(), in.dim());
  return in.size(d);
}

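// Example: for an input of shape {2, 3, 4}, dim = 1 gives 3, while an unset
// dim reduces over every dimension and gives 2 * 3 * 4 = 24.
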
/**
 * Returns the product of the sizes of all reduction dims.
 */
size_t get_reduced_dim_product(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list) {
  if (in.dim() == 0) {
    return 1;
  }
  size_t dim_product = 1;
  const size_t in_dim = in.dim();
  if (!dim_list.has_value() || dim_list.value().size() == 0) {
    for (size_t i = 0; i < in.dim(); ++i) {
      dim_product *= in.size(i);
    }
    return dim_product;
  }
  for (const auto& d : dim_list.value()) {
    const size_t non_neg_d = _normalize_non_neg_d(d, in_dim);
    dim_product *= in.size(non_neg_d);
  }
  return dim_product;
}

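// Example: for an input of shape {2, 3, 4}, dim_list {0, 2} gives 2 * 4 = 8;
// a null or empty dim_list gives 24.
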
/**
 * Returns the number of elements of the output of reducing `in`
 * over `dim`.
 */
size_t get_out_numel(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim) {
  size_t out_numel = 1;
  if (dim.has_value()) {
    const auto dim_val = dim.value();
    if (in.dim() == 0) {
      ET_CHECK(dim_val == 0 || dim_val == -1);
    } else {
      ET_CHECK_VALID_DIM(dim_val, in.dim());
    }
    const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in.dim());
    for (size_t d = 0; d < in.dim(); ++d) {
      if (d != non_neg_dim) {
        out_numel *= in.size(d);
      }
    }
  }
  return out_numel;
}

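// Example: for an input of shape {2, 3, 4}, dim = 1 gives 2 * 4 = 8; when dim
// is not set the reduction produces a single element, so the result is 1.
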
/**
 * Returns the number of elements of the output of reducing `in`
 * over `dim_list`.
 */
size_t get_out_numel(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list) {
  size_t out_numel = 1;
  if (dim_list.has_value() && dim_list.value().size() != 0) {
    for (size_t d = 0; d < in.dim(); ++d) {
      if (!check_dim_in_dim_list(d, in.dim(), dim_list.value())) {
        out_numel *= in.size(d);
      }
    }
  }
  return out_numel;
}

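// Example: for an input of shape {2, 3, 4}, dim_list {0, 2} gives 3; a null or
// empty dim_list gives 1.
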
/**
 * Returns the index of the first element in `in` that maps to `out_ix` when
 * reducing over `dim`. If `dim` is empty, returns `0`.
 */
size_t get_init_index(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    const size_t out_ix) {
  if (!dim.has_value()) {
    return 0;
  }
  const auto dim_val = dim.value();
  if (in.dim() == 0) {
    ET_CHECK(dim_val == 0 || dim_val == -1);
  } else {
    ET_CHECK_VALID_DIM(dim_val, in.dim());
  }
  const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in.dim());
  size_t init_ix = 0;
  size_t mutable_out_ix = out_ix;
  auto strides = in.strides();
  for (int64_t d = in.dim() - 1; d >= 0; d--) {
    if (d != non_neg_dim) {
      init_ix += (mutable_out_ix % in.size(d)) * strides[d];
      mutable_out_ix /= in.size(d);
    }
  }
  return init_ix;
}

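// Example: for a contiguous input of shape {2, 3} (strides {3, 1}) reduced
// over dim = 1, out_ix = 1 corresponds to element (1, 0), so the returned flat
// index is (1 % 2) * 3 = 3.
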
/**
 * Returns the index of the first element in `in` that maps to `out_ix` when
 * reducing over the list of dimensions in `dim_list`. If `dim_list` is null
 * or empty, returns `0`.
 */
size_t get_init_index(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    const size_t out_ix) {
  if (!dim_list.has_value() || dim_list.value().size() == 0) {
    return 0;
  }
  size_t init_ix = 0;
  size_t mutable_out_ix = out_ix;
  auto strides = in.strides();
  for (int64_t d = in.dim() - 1; d >= 0; d--) {
    if (!check_dim_in_dim_list(d, in.dim(), dim_list.value())) {
      init_ix += (mutable_out_ix % in.size(d)) * strides[d];
      mutable_out_ix /= in.size(d);
    }
  }
  return init_ix;
}

//
// Resize out tensor of reduction op
//

size_t compute_reduced_out_size(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    exec_aten::SizesType* sizes_arr) {
  const auto in_dim = in.dim();
  size_t out_dim = in_dim;

  if (dim.has_value()) {
    const auto dim_val = dim.value();
    const size_t non_neg_dim = _normalize_non_neg_d(dim_val, in_dim);
    for (ssize_t i = 0; i < non_neg_dim; ++i) {
      sizes_arr[i] = in.size(i);
    }
    if (keepdim) {
      sizes_arr[non_neg_dim] = 1;
      for (ssize_t i = non_neg_dim + 1; i < in_dim; ++i) {
        sizes_arr[i] = in.size(i);
      }
    } else {
      for (ssize_t i = non_neg_dim; i < in_dim - 1; ++i) {
        sizes_arr[i] = in.size(i + 1);
      }
      out_dim = in_dim == 0 ? 0 : in_dim - 1;
    }
  } else {
    if (keepdim) {
      for (size_t i = 0; i < in_dim; ++i) {
        sizes_arr[i] = 1;
      }
    } else {
      out_dim = 0;
    }
  }
  return out_dim;
}

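// Example: for an input of shape {2, 3, 4} and dim = 1, sizes_arr becomes
// {2, 1, 4} with keepdim = true (out_dim = 3) and {2, 4} with keepdim = false
// (out_dim = 2).
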
size_t compute_reduced_out_size(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    exec_aten::SizesType* sizes_arr) {
  const auto in_dim = in.dim();
  size_t out_dim = in_dim;

  if (dim_list.has_value() && dim_list.value().size() != 0) {
    const auto& reduce_dims = dim_list.value();
    if (keepdim) {
      for (size_t i = 0; i < in_dim; ++i) {
        if (check_dim_in_dim_list(i, in_dim, reduce_dims)) {
          sizes_arr[i] = 1;
        } else {
          sizes_arr[i] = in.size(i);
        }
      }
    } else {
      size_t out_i = 0;
      for (size_t in_i = 0; in_i < in_dim; ++in_i) {
        if (!check_dim_in_dim_list(in_i, in_dim, reduce_dims)) {
          sizes_arr[out_i] = in.size(in_i);
          out_i++;
        }
      }
      out_dim = out_i;
    }
  } else {
    if (keepdim) {
      for (size_t i = 0; i < in_dim; ++i) {
        sizes_arr[i] = 1;
      }
    } else {
      out_dim = 0;
    }
  }
  return out_dim;
}

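// Example: for an input of shape {2, 3, 4} and dim_list {0, 2}, sizes_arr
// becomes {1, 3, 1} with keepdim = true (out_dim = 3) and {3} with
// keepdim = false (out_dim = 1).
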
Error resize_reduction_out(
    const Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    Tensor& out) {
  exec_aten::SizesType sizes_arr[kTensorDimensionLimit];
  const auto out_dim = compute_reduced_out_size(in, dim, keepdim, sizes_arr);
  exec_aten::ArrayRef<exec_aten::SizesType> out_size{
      sizes_arr, static_cast<size_t>(out_dim)};
  return resize_tensor(out, out_size);
}

Error resize_reduction_out(
    const Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    Tensor& out) {
  exec_aten::SizesType sizes_arr[kTensorDimensionLimit];
  const auto out_dim =
      compute_reduced_out_size(in, dim_list, keepdim, sizes_arr);
  exec_aten::ArrayRef<exec_aten::SizesType> out_size{
      sizes_arr, static_cast<size_t>(out_dim)};
  return resize_tensor(out, out_size);
}

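// Usage sketch (illustrative only, following the pattern used by the portable
// reduction kernels; `ctx`, `in`, `dim_list`, `keepdim`, and `out` are assumed
// to come from the operator entry point):
//
//   ET_KERNEL_CHECK(
//       ctx,
//       resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok,
//       InvalidArgument,
//       out);
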
#ifndef USE_ATEN_LIB

/**
 * Check the validity of arguments for reduction operators.
 */
bool check_reduction_args(
    const Tensor& in,
    const optional<ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out) {
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  }
  ET_LOG_AND_RETURN_IF_FALSE(check_dim_list_is_valid(in, dim_list));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  return true;
}

/**
 * Check the validity of arguments for reduction operators that take
 * a single dimension argument.
 */
bool check_reduction_args_single_dim(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out,
    bool allow_empty_dim) {
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  }
  if (in.dim() == 0) {
    if (dim.has_value()) {
      ET_LOG_AND_RETURN_IF_FALSE(dim.value() == 0 || dim.value() == -1);
    }
    return true;
  }

  if (dim.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim.value(), in.dim()));
    if (!allow_empty_dim) {
      ET_LOG_AND_RETURN_IF_FALSE(tensor_has_non_empty_dim(in, dim.value()));
    }
  }

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  return true;
}

bool check_mean_dim_args(
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args(in, dim_list, keepdim, dtype, out));

  if (dtype) {
    ET_LOG_AND_RETURN_IF_FALSE(torch::executor::isFloatingType(dtype.value()));
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
  } else {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_floating_type(in));
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_floating_type(out));
  }

  return true;
}

bool check_amin_amax_args(
    const Tensor& in,
    ArrayRef<int64_t> dim_list,
    bool keepdim,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args(in, dim_list, keepdim, {}, out));
  ET_LOG_AND_RETURN_IF_FALSE(in.scalar_type() == out.scalar_type());

  return true;
}

bool check_argmin_argmax_args(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args_single_dim(in, dim, keepdim, {}, out));

  ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == ScalarType::Long);

  return true;
}

bool check_min_max_args(
    const Tensor& in,
    int64_t dim,
    bool keepdim,
    Tensor& max,
    Tensor& max_indices) {
  ET_LOG_AND_RETURN_IF_FALSE(
      check_reduction_args_single_dim(in, dim, keepdim, {}, max));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, max));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape(max, max_indices));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_is_default_or_channels_last_dim_order(max_indices));
  ET_LOG_AND_RETURN_IF_FALSE(max_indices.scalar_type() == ScalarType::Long);

  return true;
}

bool check_prod_out_args(
    const Tensor& in,
    optional<ScalarType> dtype,
    Tensor& out) {
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  } else if (isIntegralType(in.scalar_type(), /*includeBool*/ true)) {
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == ScalarType::Long);
  } else {
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == in.scalar_type());
  }

  return true;
}

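// Example: with no explicit dtype, an integral (or bool) input requires a Long
// output, while a floating-point input requires an output of the same dtype;
// an explicit dtype must always match the output's dtype.
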
#endif

} // namespace executor
} // namespace torch