xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/reduce_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <algorithm>
#include <cstring>
#include <tuple>

namespace torch {
namespace executor {
namespace {

template <typename Fn>
void apply_on_flat_ix_with_stride_and_base(
    const Fn& fn,
    const size_t stride,
    const size_t base,
    const size_t start,
    const size_t end) {
  for (size_t i = start; i <= end; i++) {
    fn(base + i * stride);
  }
}

template <typename Fn>
void apply_on_flat_and_dim_ix_with_stride_and_base(
    const Fn& fn,
    const size_t stride,
    const size_t base,
    const size_t start,
    const size_t end) {
  for (size_t i = start; i <= end; i++) {
    fn(base + i * stride, i);
  }
}

template <typename Fn>
void apply_on_flat_ix_with_dim_mask_and_base(
    const Fn& fn,
    const Tensor& in,
    bool* dim_mask,
    const size_t base,
    const size_t start,
    const size_t end) {
  // Compute innermost dim from dim list
  size_t inner_dim = in.dim() - 1;
  while (!dim_mask[inner_dim]) {
    inner_dim--;
  }

  // Initialize array of indices per dimension. This array is used to maintain
  // the per-dimension index of the element in `in` that is being reduced over.
  // Only the dims that are in the dim list are relevant.
  size_t dim_index[kTensorDimensionLimit];
  for (int64_t d = 0; d < in.dim(); d++) {
    dim_index[d] = 0;
  }

  // Gather strides
  const auto strides = in.strides();

  // curr_index will always be the index of the element from `in` we are
  // currently reducing. Initialized to the first index from `in` that maps to
  // `out_ix`.
  size_t curr_index = base;

  size_t apply_fun_counter = 0;
  while (true) {
    // Apply reduction to current index
    if (apply_fun_counter >= start && apply_fun_counter <= end) {
      fn(curr_index);
    }
    apply_fun_counter += 1;
    if (apply_fun_counter > end) {
      return;
    }

    // Next index to reduce. Increase dim_index[inner_dim] by 1, and curr_index
    // by strides[inner_dim].
    dim_index[inner_dim]++;
    curr_index += strides[inner_dim];

    // Check if we have reached the end of the innermost dimension
    if (dim_index[inner_dim] == in.size(inner_dim)) {
      // If we reached the end, we need to update the indices in dim_index. We
      // do this by resetting dim_index[inner_dim] to 0, and then incrementing
      // the index of the next innermost dimension from the dim list by 1.
      // If, when we do this increment, we also reach the end of that
      // dimension, we keep repeating the procedure.
      // This is similar to carrying over when adding 1 to a number.

      // curr_dim will be the dim from the dim list we are currently updating
      int64_t curr_dim = inner_dim;

      while (dim_index[curr_dim] == in.size(curr_dim)) {
        if (curr_dim == 0) {
          // Exit function if we've reached the end of the outermost dimension
          return;
        }
        // Reset dim_index[curr_dim] to 0. We need to update curr_index
        // accordingly. Resetting dim_index[curr_dim] from in.size(curr_dim)
        // to 0 means we need to subtract in.size(curr_dim) * strides[curr_dim]
        // from curr_index. However, in.size(curr_dim) * strides[curr_dim] is
        // equal to strides[curr_dim - 1]. Notice that curr_dim > 0 at this
        // point in the execution.
        dim_index[curr_dim] = 0;
        curr_index -= strides[curr_dim - 1];

        // Decrease current dim
        curr_dim--;
        while (curr_dim >= 0) {
          // Stop if curr_dim is in the dim list
          if (dim_mask[curr_dim]) {
            break;
          }
          // Keep decreasing if curr_dim is not in the dim list
          curr_dim--;
        }
        // Exit function if curr_dim was decreased to -1. This means we have
        // reduced over all the elements we needed to.
        if (curr_dim < 0) {
          return;
        }

        // At this point in the execution, curr_dim is the next innermost
        // dimension. Increase dim_index[curr_dim] by 1 and update curr_index
        // accordingly.
        dim_index[curr_dim]++;
        curr_index += strides[curr_dim];
      }
    }
  }
}
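
// Worked example of the carry-over iteration above (a sketch, assuming a
// contiguous input of sizes {2, 3, 2}, i.e. strides {6, 2, 1}, and
// dim_mask = {true, false, true}): starting from base = 2 * j, the flat index
// of the element at coordinates (0, j, 0), the loop visits the flat indices
//   2j, 2j + 1, 2j + 6, 2j + 7
// i.e. every element whose coordinate along the unmasked dim 1 stays fixed at
// j, iterated in row-major order over the masked dims 0 and 2.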

} // namespace

//
// Helper Functions
//

ET_NODISCARD bool check_dim_list_is_valid(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list);

bool check_dim_in_dim_list(
    const size_t dim,
    const size_t max_dim,
    const exec_aten::ArrayRef<int64_t>& dim_list);

size_t get_reduced_dim_product(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim);

size_t get_reduced_dim_product(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list);

size_t get_out_numel(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim);

size_t get_out_numel(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list);

size_t get_init_index(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    const size_t out_ix);

size_t get_init_index(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    const size_t out_ix);
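
// Illustrative values for the helpers above (a sketch; the contiguous
// 2 x 3 x 4 input `in` used here is hypothetical):
//   get_reduced_dim_product(in, 1)       -> 3   (size of the reduced dim)
//   get_reduced_dim_product(in, {0, 2})  -> 8   (2 * 4)
//   get_out_numel(in, 1)                 -> 8   (2 * 4 output elements)
//   get_out_numel(in, {0, 2})            -> 3
//   get_init_index(in, dim, out_ix)      -> flat index into `in` of the first
//                                           element that reduces into `out_ix`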

//
// Iteration Functions
//

/**
 * Useful to reduce a tensor `in` over a given dimension `dim` using the
 * reduce function `fn`, which should have the following signature:
 * `void fn(const size_t size, const size_t stride, const size_t base_ix)`
 * where `size` and `stride` are the size and stride of the dimension being
 * reduced and `base_ix` is the index of the first element of the reduction.
 */
template <typename Fn>
void apply_over_dim(
    const Fn& fn,
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim) {
  // If dim is null, apply fn over the entire tensor
  if (!dim.has_value()) {
    fn(in.numel(), 1, 0);
    return;
  }

  if (in.dim() != 0) {
    ET_CHECK_VALID_DIM(dim.value(), in.dim());
  } else {
    // Special handling for 0-D tensors; 0 or -1 is valid, matching PyTorch,
    // e.g. `torch.mean(torch.tensor(2, dtype=float), dim=-1)`
    ET_CHECK(dim.value() == 0 || dim.value() == -1);
    fn(in.numel(), 1, 0);
    return;
  }

  if (in.numel() == 0) {
    return;
  }

  const size_t d = ET_NORMALIZE_IX(dim.value(), in.dim());

  const size_t size = in.size(d);
  const size_t stride = in.strides()[d];
  const size_t outer_size = getLeadingDims(in, d);
  const size_t outer_stride = size * stride;
  // Loop through all outer dimensions
  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    size_t outer = outer_idx * outer_stride;
    // Loop through all inner dimensions
    for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
      size_t base = outer + inner_idx;
      fn(size, stride, base);
    }
  }
}
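
/*
 * A minimal usage sketch for the overload above (assuming a contiguous float
 * input `in`, a dim `d` to reduce, and a preallocated output buffer
 * `out_data`; all three names are placeholders for this example). It writes
 * the sum along `d` into consecutive output elements, relying on the callbacks
 * arriving in output order for contiguous inputs:
 *
 *   const float* in_data = in.const_data_ptr<float>();
 *   size_t out_ix = 0;
 *   apply_over_dim(
 *       [in_data, out_data, &out_ix](
 *           const size_t size, const size_t stride, const size_t base) {
 *         float acc = 0;
 *         for (size_t k = 0; k < size; ++k) {
 *           acc += in_data[base + k * stride];
 *         }
 *         out_data[out_ix++] = acc;
 *       },
 *       in,
 *       d);
 */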

/**
 * Useful to reduce a tensor `in` over a given dimension `dim` for the output
 * element at index `out_ix` using the reduce function `fn`, which
 * should have the following signature:
 * `void fn(const size_t in_ix, const size_t dim_ix)`
 * where `in_ix` is the flat index of each element from `in` that maps to
 * `out_ix` and `dim_ix` is its index along `dim`.
 */
template <typename Fn>
void apply_over_dim(
    const Fn& fn,
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    const size_t out_ix,
    const int64_t start = 0,
    const int64_t end = -1) {
  if (dim.has_value()) {
    if (in.dim() != 0) {
      ET_CHECK_VALID_DIM(dim.value(), in.dim());
    } else {
      ET_CHECK(dim.value() == 0 || dim.value() == -1);
    }
  }
  ET_CHECK_MSG(
      out_ix < get_out_numel(in, dim),
      "Out index %zd is out of bounds",
      out_ix);

  if (in.numel() == 0) {
    return;
  }

  const size_t iter_length = get_reduced_dim_product(in, dim);
  const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
  const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
  const size_t ustart = std::max(normalized_start, size_t(0));
  const size_t uend = std::min(normalized_end, iter_length - 1);

  // If dim is null, iterate over the entire tensor
  if (!dim.has_value()) {
    apply_on_flat_and_dim_ix_with_stride_and_base(
        fn, /*stride=*/1, /*base=*/0, ustart, uend);
    return;
  }

  // Compute the starting base index
  const size_t base = get_init_index(in, dim, out_ix);

  // Compute non-negative dimension value from dim value
  const size_t d = ET_NORMALIZE_IX(dim.value(), in.dim());

  if (in.dim() == 0) {
    fn(base, ustart);
  } else {
    apply_on_flat_and_dim_ix_with_stride_and_base(
        fn, in.strides()[d], base, ustart, uend);
  }
}
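
/*
 * A minimal usage sketch for the overload above (assuming float data and a
 * placeholder accumulator): sums every input element that maps to `out_ix`.
 *
 *   const float* in_data = in.const_data_ptr<float>();
 *   float acc = 0;
 *   apply_over_dim(
 *       [in_data, &acc](const size_t in_ix, const size_t dim_ix) {
 *         (void)dim_ix; // only the flat index is needed for a plain sum
 *         acc += in_data[in_ix];
 *       },
 *       in,
 *       dim,
 *       out_ix);
 */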

/**
 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
 * for the output element at index `out_ix` using the reduce function
 * `fn`, which should have the following signature:
 * `void fn(const size_t in_ix)`
 * where `in_ix` is the index of each element from `in` that maps to `out_ix`.
 */
template <typename Fn>
void apply_over_dim_list(
    const Fn& fn,
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    const size_t out_ix,
    const int64_t start = 0,
    const int64_t end = -1) {
  ET_CHECK(check_dim_list_is_valid(in, dim_list));
  ET_CHECK_MSG(
      out_ix < get_out_numel(in, dim_list),
      "Out index %zd is out of bounds",
      out_ix);

  if (in.numel() == 0) {
    return;
  }

  const size_t iter_length = get_reduced_dim_product(in, dim_list);
  const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
  const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
  const size_t ustart = std::max(normalized_start, size_t(0));
  const size_t uend = std::min(normalized_end, iter_length - 1);

  // If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
  if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
    apply_on_flat_ix_with_stride_and_base(
        fn, /*stride=*/1, /*base=*/0, ustart, uend);
    return;
  }

  // Create is_in_dim_list to check whether each dimension is in the dim list
  bool is_in_dim_list[kTensorDimensionLimit];
  memset(is_in_dim_list, false, sizeof(is_in_dim_list));
  for (const auto& d : dim_list.value()) {
    const size_t non_neg_d = d < 0 ? d + in.dim() : d;
    is_in_dim_list[non_neg_d] = true;
  }

  // Compute the starting base index
  const size_t base = get_init_index(in, dim_list, out_ix);

  apply_on_flat_ix_with_dim_mask_and_base(
      fn, in, is_in_dim_list, base, ustart, uend);
}
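
/*
 * A minimal usage sketch for apply_over_dim_list (assuming float data and a
 * placeholder accumulator): sums every input element that maps to `out_ix`
 * across all dims in `dim_list`.
 *
 *   const float* in_data = in.const_data_ptr<float>();
 *   float acc = 0;
 *   apply_over_dim_list(
 *       [in_data, &acc](const size_t in_ix) { acc += in_data[in_ix]; },
 *       in,
 *       dim_list,
 *       out_ix);
 */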

//
// Reduce Functions
//

/**
 * Useful to reduce a tensor `in` over a dimension `dim` for the output element
 * at index `out_ix`, first applying the map `map_fun` to each element of `in`,
 * which should have the signature: `CTYPE_OUT map_fun(CTYPE_IN v)`
 * and then reducing using `reduce_fun`, which should have the signature:
 * `std::tuple<CTYPE_OUT, long> reduce_fun(CTYPE_OUT val, long ix,
 * CTYPE_OUT acc_val, long acc_ix)`
 *
 * Common usage:
 *
 * CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
 *   std::tuple<CTYPE_OUT, long> acc =
 *       map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
 *           [](CTYPE_IN v) {
 *             // map operation on `v`, outputs `val`
 *           },
 *           [](CTYPE_OUT val, long ix, CTYPE_OUT acc_val, long acc_ix) {
 *             // reduce operation on `acc_val` and `acc_ix` using `val` and
 *             // `ix`, outputs an updated {`acc_val`, `acc_ix`} pair
 *           },
 *           in,
 *           dim,
 *           out_ix);
 *   out_data[out_ix] = std::get<0>(acc);
 * }
 */
template <
    typename CTYPE_IN,
    typename CTYPE_OUT,
    typename MapOp,
    typename ReduceOp>
std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
    const MapOp& map_fun,
    const ReduceOp& reduce_fun,
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    const size_t out_ix) {
  if (dim.has_value()) {
    if (in.dim() != 0) {
      ET_CHECK_VALID_DIM(dim.value(), in.dim());
    } else {
      ET_CHECK(dim.value() == 0 || dim.value() == -1);
    }
  }

  ET_CHECK_MSG(
      out_ix < get_out_numel(in, dim),
      "Out index %zd is out of bounds",
      out_ix);

  ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");

  const size_t init_index = get_init_index(in, dim, out_ix);

  const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
  CTYPE_OUT acc_val = map_fun(in_data[init_index]);
  long acc_ix = 0;

  if (in.numel() == 1) {
    return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
  }

  apply_over_dim(
      [&acc_val, &acc_ix, reduce_fun, map_fun, in_data](
          const size_t in_ix, const size_t dim_ix) {
        std::tuple<CTYPE_OUT, long> res =
            reduce_fun(map_fun(in_data[in_ix]), dim_ix, acc_val, acc_ix);
        acc_val = std::get<0>(res);
        acc_ix = std::get<1>(res);
      },
      in,
      dim,
      out_ix,
      1,
      -1);

  return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
}
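
/*
 * A concrete instantiation sketch (assuming float data): computes the maximum
 * over `dim` together with its index along `dim`, in the style of a max.dim
 * kernel.
 *
 *   std::tuple<float, long> res = map_reduce_over_dim<float, float>(
 *       [](float v) { return v; },
 *       [](float v, long ix, float acc_val, long acc_ix) {
 *         return (v > acc_val) ? std::tuple<float, long>{v, ix}
 *                              : std::tuple<float, long>{acc_val, acc_ix};
 *       },
 *       in,
 *       dim,
 *       out_ix);
 *   float max_val = std::get<0>(res);
 *   long max_ix = std::get<1>(res);
 */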

/**
 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
 * for the output element at index `out_ix`, first applying the map `map_fun`
 * to each element of `in`, which should have the signature:
 * `CTYPE_OUT map_fun(CTYPE_IN v)`
 * and then reducing using `reduce_fun`, which should have the signature:
 * `CTYPE_OUT reduce_fun(CTYPE_OUT v, CTYPE_OUT acc)`
 *
 * Common usage:
 *
 * CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
 *   out_data[out_ix] = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
 *       [](CTYPE_IN v) {
 *         // map operation on `v`, outputs `outv`
 *       },
 *       [](CTYPE_OUT outv, CTYPE_OUT acc) {
 *         // reduce operation on `acc` using `outv`, outputs an updated `acc`
 *       },
 *       in,
 *       dim_list,
 *       out_ix);
 * }
 */
template <
    typename CTYPE_IN,
    typename CTYPE_OUT,
    typename MapOp,
    typename ReduceOp>
CTYPE_OUT map_reduce_over_dim_list(
    const MapOp& map_fun,
    const ReduceOp& reduce_fun,
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    const size_t out_ix) {
  ET_CHECK(check_dim_list_is_valid(in, dim_list));

  ET_CHECK_MSG(
      out_ix < get_out_numel(in, dim_list),
      "Out index %zd is out of bounds",
      out_ix);

  ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");

  const size_t init_index = get_init_index(in, dim_list, out_ix);

  const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
  CTYPE_OUT acc_val = map_fun(in_data[init_index]);

  if (in.numel() == 1) {
    return acc_val;
  }

  apply_over_dim_list(
      [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
        acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
      },
      in,
      dim_list,
      out_ix,
      1,
      -1);

  return acc_val;
}
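
/*
 * A concrete instantiation sketch (assuming float data): computes the mean
 * over `dim_list` by summing with map_reduce_over_dim_list and dividing by
 * the number of reduced elements.
 *
 *   float sum = map_reduce_over_dim_list<float, float>(
 *       [](float v) { return v; },
 *       [](float outv, float acc) { return acc + outv; },
 *       in,
 *       dim_list,
 *       out_ix);
 *   float mean =
 *       sum / static_cast<float>(get_reduced_dim_product(in, dim_list));
 */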

/**
 * Useful to reduce a tensor `in` over a dimension `dim` for the output element
 * at index `out_ix` using the reduce function `reduce_fun`, which should have
 * the following signature:
 * `std::tuple<CTYPE, long> reduce_fun(CTYPE val, long ix, CTYPE acc_val,
 * long acc_ix)`
 *
 * Common usage:
 *
 * CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
 *   std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
 *       [](CTYPE val, long ix, CTYPE acc_val, long acc_ix) {
 *         // reduce operation on `acc_val` and `acc_ix` using `val` and `ix`,
 *         // outputs an updated {`acc_val`, `acc_ix`} pair
 *       },
 *       in,
 *       dim,
 *       out_ix);
 *   out_data[out_ix] = std::get<0>(acc);
 * }
 */
template <typename CTYPE, typename ReduceOp>
std::tuple<CTYPE, long> reduce_over_dim(
    const ReduceOp& reduce_fun,
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    const size_t out_ix) {
  return map_reduce_over_dim<CTYPE, CTYPE>(
      [](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
}

/**
 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
 * for the output element at index `out_ix` using the reduce function
 * `reduce_fun`, which should have the following signature:
 * `CTYPE reduce_fun(CTYPE v, CTYPE acc)`
 *
 * Common usage:
 *
 * CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
 *   out_data[out_ix] = reduce_over_dim_list<CTYPE>(
 *       [](CTYPE v, CTYPE acc) {
 *         // reduce operation on `acc` using `v`, outputs `acc`
 *       },
 *       in,
 *       dim_list,
 *       out_ix);
 * }
 */
template <typename CTYPE, typename ReduceOp>
CTYPE reduce_over_dim_list(
    const ReduceOp& reduce_fun,
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    const size_t out_ix) {
  return map_reduce_over_dim_list<CTYPE, CTYPE>(
      [](CTYPE v) { return v; }, reduce_fun, in, dim_list, out_ix);
}

//
// Compute reduced out tensor size and dim
//

size_t compute_reduced_out_size(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    exec_aten::SizesType* sizes_arr);

size_t compute_reduced_out_size(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    exec_aten::SizesType* sizes_arr);

inline ssize_t compute_reduced_out_dim(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim) {
  return (
      keepdim                                ? in.dim()
          : dim.has_value() && in.dim() != 0 ? in.dim() - 1
                                             : 0);
}

inline ssize_t compute_reduced_out_dim(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim) {
  return (
      keepdim ? in.dim()
          : dim_list.has_value() && dim_list.value().size() != 0 &&
              in.dim() != 0
          ? in.dim() - dim_list.value().size()
          : 0);
}
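
// Illustrative values (a sketch, assuming a 2 x 3 x 4 input `in`):
//   compute_reduced_out_dim(in, 1, /*keepdim=*/false)       -> 2
//   compute_reduced_out_dim(in, 1, /*keepdim=*/true)        -> 3
//   compute_reduced_out_dim(in, {0, 2}, /*keepdim=*/false)  -> 1
// compute_reduced_out_size fills `sizes_arr` with the corresponding output
// sizes, e.g. {2, 4}, {2, 1, 4}, and {3} for the three cases above.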

//
// Resize out tensor of reduction op
//

Error resize_reduction_out(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    exec_aten::Tensor& out);

Error resize_reduction_out(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    exec_aten::Tensor& out);
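
// A typical call-site sketch inside a reduction operator (ET_KERNEL_CHECK and
// `ctx` come from the kernel runtime and are assumed available at the call
// site):
//
//   ET_KERNEL_CHECK(
//       ctx,
//       resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok,
//       InvalidArgument,
//       out);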

#ifndef USE_ATEN_LIB
bool check_reduction_args(
    const Tensor& in,
    const optional<ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out);

bool check_reduction_args_single_dim(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out,
    bool allow_empty_dim = false);

bool check_mean_dim_args(
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out);

bool check_amin_amax_args(
    const Tensor& in,
    ArrayRef<int64_t> dim_list,
    bool keepdim,
    Tensor& out);

bool check_argmin_argmax_args(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    Tensor& out);

bool check_min_max_args(
    const Tensor& in,
    int64_t dim,
    bool keepdim,
    Tensor& max,
    Tensor& max_indices);

bool check_prod_out_args(
    const Tensor& in,
    optional<ScalarType> dtype,
    Tensor& out);

#endif

} // namespace executor
} // namespace torch