1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
13 #include <cstring>
14 #include <tuple>
15
16 namespace torch {
17 namespace executor {
18 namespace {
19
/**
 * Invokes `fn(flat_ix)` once for every position `i` in the inclusive range
 * [`start`, `end`], where `flat_ix = base + i * stride`. Nothing is invoked
 * when `start > end`.
 */
template <typename Fn>
void apply_on_flat_ix_with_stride_and_base(
    const Fn& fn,
    const size_t stride,
    const size_t base,
    const size_t start,
    const size_t end) {
  size_t pos = start;
  while (pos <= end) {
    fn(base + pos * stride);
    ++pos;
  }
}
31
/**
 * Invokes `fn(flat_ix, i)` once for every position `i` in the inclusive range
 * [`start`, `end`], where `flat_ix = base + i * stride`. Nothing is invoked
 * when `start > end`.
 */
template <typename Fn>
void apply_on_flat_and_dim_ix_with_stride_and_base(
    const Fn& fn,
    const size_t stride,
    const size_t base,
    const size_t start,
    const size_t end) {
  size_t pos = start;
  while (pos <= end) {
    fn(base + pos * stride, pos);
    ++pos;
  }
}
43
44 template <typename Fn>
apply_on_flat_ix_with_dim_mask_and_base(const Fn & fn,const Tensor & in,bool * dim_mask,const size_t base,const size_t start,const size_t end)45 void apply_on_flat_ix_with_dim_mask_and_base(
46 const Fn& fn,
47 const Tensor& in,
48 bool* dim_mask,
49 const size_t base,
50 const size_t start,
51 const size_t end) {
52 // Compute innermost dim from dim list
53 size_t inner_dim = in.dim() - 1;
54 while (!dim_mask[inner_dim]) {
55 inner_dim--;
56 }
57
58 // Initialize array of indices per dimension. This array is used to maintain
59 // the per-dimension index of the element in `in` that is being reduced over
60 // Only the dims that are in the dim list are relevant.
61 size_t dim_index[kTensorDimensionLimit];
62 for (int64_t d = 0; d < in.dim(); d++) {
63 dim_index[d] = 0;
64 }
65
66 // Gather strides
67 const auto strides = in.strides();
68
69 // curr_index will always be index of the element from `in` we are currently
70 // reducing. Initialized to the first index from `in` that maps to `out_ix`
71 size_t curr_index = base;
72
73 size_t apply_fun_counter = 0;
74 while (true) {
75 // Apply reduction to current index
76 if (apply_fun_counter >= start && apply_fun_counter <= end) {
77 fn(curr_index);
78 }
79 apply_fun_counter += 1;
80 if (apply_fun_counter > end) {
81 return;
82 }
83
84 // Next index to reduce. Increase dim_index[inner_dim] by 1, and curr_index
85 // by strides[inner_dim].
86 dim_index[inner_dim]++;
87 curr_index += strides[inner_dim];
88
89 // Check if we have reached the end of the innermost dimension
90 if (dim_index[inner_dim] == in.size(inner_dim)) {
91 // If we reached the end, we need to update the indices in dim_index. We
92 // do this by resetting dim_index[inner_dim] to 0, and then incrementing
93 // the index of the next innermost dimension from the dim list by 1.
94 // If when we do this increment, we also reach the end of that dimension,
95 // we need to keep repeating that procedure.
96 // This is similar to doing the carry over when adding 1 to a number.
97
98 // curr_dim will be the dim from the dim list we are currently updating
99 int64_t curr_dim = inner_dim;
100
101 while (dim_index[curr_dim] == in.size(curr_dim)) {
102 if (curr_dim == 0) {
103 // Exit function if we've reached the end of the outermost dimension
104 return;
105 }
106 // Reset dim_index[curr_dim] to 0. We need to update curr_index
107 // accordingly. Reseting dim_index[curr_dim] from in.size(curr_dim)
108 // to 0 means we need to subtract in.size(curr_dim) * strides[curr_dim]
109 // from curr_index. However in.size(curr_dim) * strides[curr_dim] is
110 // equal to strides[curr_dim - 1]. Notice that curr_dim > 0 at this
111 // point in the execution
112 dim_index[curr_dim] = 0;
113 curr_index -= strides[curr_dim - 1];
114
115 // Decrease current dim
116 curr_dim--;
117 while (curr_dim >= 0) {
118 // Stop if curr_dim is in the dim list
119 if (dim_mask[curr_dim]) {
120 break;
121 }
122 // Keep decreasing if curr_dim is not in the dim list
123 curr_dim--;
124 }
125 // Exit function if curr_dim was decreased to -1. This means we have
126 // reduced over all the elements we needed to.
127 if (curr_dim < 0) {
128 return;
129 }
130
131 // At this point in the execution, curr_dim is the next innermost
132 // dimension. Increase dim_index[curr_dim] by 1 and update curr_index
133 // accordingly.
134 dim_index[curr_dim]++;
135 curr_index += strides[curr_dim];
136 }
137 }
138 }
139 }
140
141 } // namespace
142
143 //
144 // Helper Functions
145 //
146
/**
 * Validates `dim_list` against `in`. Declared here, defined elsewhere; the
 * dim-list iteration helpers in this header ET_CHECK its result before
 * walking a dim list.
 * NOTE(review): the exact rules (valid range, duplicates, empty list) are not
 * visible from this header -- confirm against the definition.
 */
ET_NODISCARD bool check_dim_list_is_valid(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list);

/**
 * NOTE(review): presumably reports whether `dim` is named by `dim_list`, with
 * `max_dim` used to interpret negative entries -- confirm against the
 * definition, which is not part of this header.
 */
bool check_dim_in_dim_list(
    const size_t dim,
    const size_t max_dim,
    const exec_aten::ArrayRef<int64_t>& dim_list);

/**
 * Number of elements of `in` that are reduced into a single output element
 * when reducing over `dim`. This header uses the result as the iteration
 * length of one reduction. Defined elsewhere.
 */
size_t get_reduced_dim_product(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim);

/**
 * Dim-list counterpart of get_reduced_dim_product(in, dim); used in this
 * header as the iteration length of one reduction over `dim_list`. Defined
 * elsewhere.
 */
size_t get_reduced_dim_product(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list);

/**
 * Number of output elements produced by reducing `in` over `dim`. This header
 * uses the result as the exclusive upper bound for a valid `out_ix`. Defined
 * elsewhere.
 */
size_t get_out_numel(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim);

/**
 * Dim-list counterpart of get_out_numel(in, dim); used in this header as the
 * exclusive upper bound for a valid `out_ix`. Defined elsewhere.
 */
size_t get_out_numel(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list);

/**
 * Flat index into `in` of the first element that maps to output element
 * `out_ix` when reducing over `dim`; this header uses it as the base index of
 * the reduction. Defined elsewhere.
 */
size_t get_init_index(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    const size_t out_ix);

/**
 * Dim-list counterpart of get_init_index(in, dim, out_ix); used in this
 * header as the base index of a reduction over `dim_list`. Defined elsewhere.
 */
size_t get_init_index(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    const size_t out_ix);
181
182 //
183 // Iteration Functions
184 //
185
186 /**
187 * Useful to reduce a tensor `in` over a given dimension `dim` using the
188 * reduce function `fn`, which should have the following signature:
189 * void fn(const size_t size, const size_t stride, const size_t base_ix)
190 * where `size` and `stride` are the size and stride of the dimension being
191 * reduced and `base_ix` is the index of the first element of the reduction.
192 */
193 template <typename Fn>
apply_over_dim(const Fn & fn,const exec_aten::Tensor & in,const exec_aten::optional<int64_t> & dim)194 void apply_over_dim(
195 const Fn& fn,
196 const exec_aten::Tensor& in,
197 const exec_aten::optional<int64_t>& dim) {
198 // If dim is null, apply fn over the entire tensor
199 if (!dim.has_value()) {
200 fn(in.numel(), 1, 0);
201 return;
202 }
203
204 if (in.dim() != 0) {
205 ET_CHECK_VALID_DIM(dim.value(), in.dim());
206 } else {
207 // Special handling for 0-D tensor; 0 or -1 is valid for PyTorch code
208 // `torch.mean(torch.tensor(2, dtype=float), dim=-1)`
209 ET_CHECK(dim.value() == 0 || dim.value() == -1);
210 fn(in.numel(), 1, 0);
211 return;
212 }
213
214 if (in.numel() == 0) {
215 return;
216 }
217
218 const size_t d = ET_NORMALIZE_IX(dim.value(), in.dim());
219
220 const size_t size = in.size(d);
221 const size_t stride = in.strides()[d];
222 const size_t outer_size = getLeadingDims(in, d);
223 const size_t outer_stride = size * stride;
224 // Loop through all outer dimensions
225 for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
226 size_t outer = outer_idx * outer_stride;
227 // Loop through all inner dimensions
228 for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
229 size_t base = outer + inner_idx;
230 fn(size, stride, base);
231 }
232 }
233 }
234
235 /**
236 * Useful to reduce a tensor `in` over a given dimension `dim` for the output
237 * element at index `out_ix` using the reduce function `fn`, which
238 * should have the following signature:
239 * `void fn(const size_t in_ix, const size_t dim_ix)`
240 * where `in_ix` is the flat index of each element from `in` that maps to
241 * `out_ix` and `dim_ix` is its index along `dim`.
242 */
243 template <typename Fn>
244 void apply_over_dim(
245 const Fn& fn,
246 const exec_aten::Tensor& in,
247 const exec_aten::optional<int64_t>& dim,
248 const size_t out_ix,
249 const int64_t start = 0,
250 const int64_t end = -1) {
251 if (dim.has_value()) {
252 if (in.dim() != 0) {
253 ET_CHECK_VALID_DIM(dim.value(), in.dim());
254 } else {
255 ET_CHECK(dim.value() == 0 || dim.value() == -1);
256 }
257 }
258 ET_CHECK_MSG(
259 out_ix < get_out_numel(in, dim),
260 "Out index %zd is out of bounds",
261 out_ix);
262
263 if (in.numel() == 0) {
264 return;
265 }
266
267 const size_t iter_length = get_reduced_dim_product(in, dim);
268 const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
269 const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
270 const size_t ustart = std::max(normalized_start, size_t(0));
271 const size_t uend = std::min(normalized_end, iter_length - 1);
272
273 // If dim is null, iterate over the entire tensor
274 if (!dim.has_value()) {
275 apply_on_flat_and_dim_ix_with_stride_and_base(
276 fn, /*stride=*/1, /*base=*/0, ustart, uend);
277 return;
278 }
279
280 // Compute the starting base index
281 const size_t base = get_init_index(in, dim, out_ix);
282
283 // Compute non-negative dimension value from dim value
284 const size_t d = ET_NORMALIZE_IX(dim.value(), in.dim());
285
286 if (in.dim() == 0) {
287 fn(base, ustart);
288 } else {
289 apply_on_flat_and_dim_ix_with_stride_and_base(
290 fn, in.strides()[d], base, ustart, uend);
291 }
292 }
293
294 /**
295 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
296 * for the output element at index `out_ix` using the reduce function
297 * `fn`, which should have the following signature:
298 * `void fn(const size_t in_ix)`
299 * where `in_ix` is the index of each element from `in` that maps to `out_ix`
300 */
301 template <typename Fn>
302 void apply_over_dim_list(
303 const Fn& fn,
304 const exec_aten::Tensor& in,
305 const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
306 const size_t out_ix,
307 const int64_t start = 0,
308 const int64_t end = -1) {
309 ET_CHECK(check_dim_list_is_valid(in, dim_list));
310 ET_CHECK_MSG(
311 out_ix < get_out_numel(in, dim_list),
312 "Out index %zd is out of bounds",
313 out_ix);
314
315 if (in.numel() == 0) {
316 return;
317 }
318
319 const size_t iter_length = get_reduced_dim_product(in, dim_list);
320 const size_t normalized_start = ET_NORMALIZE_IX(start, iter_length);
321 const size_t normalized_end = ET_NORMALIZE_IX(end, iter_length);
322 const size_t ustart = std::max(normalized_start, size_t(0));
323 const size_t uend = std::min(normalized_end, iter_length - 1);
324
325 // If dim_list is null or empty, or in is 0-D, iterate over the entire tensor
326 if (!dim_list.has_value() || dim_list.value().size() == 0 || in.dim() == 0) {
327 apply_on_flat_ix_with_stride_and_base(
328 fn, /*stride=*/1, /*base=*/0, ustart, uend);
329 return;
330 }
331
332 // Create is_in_dims to check whether each dimension is in the dim list
333 bool is_in_dim_list[kTensorDimensionLimit];
334 memset(is_in_dim_list, false, sizeof(is_in_dim_list));
335 for (const auto& d : dim_list.value()) {
336 const size_t non_neg_d = d < 0 ? d + in.dim() : d;
337 is_in_dim_list[non_neg_d] = true;
338 }
339
340 // Compute the starting base index
341 const size_t base = get_init_index(in, dim_list, out_ix);
342
343 apply_on_flat_ix_with_dim_mask_and_base(
344 fn, in, is_in_dim_list, base, ustart, uend);
345 }
346
347 //
348 // Reduce Functions
349 //
350
351 /**
352 * Useful to reduce a tensor `in` over a dimension `dim` for the output element
353 * at index `out_ix`, first applying the map `map_fun` to each element of `in`,
354 * which should have the signature: CTYPE_OUT map_fun(CTYPE_IN v)
355 * and then reducing using `reduce_fun`, which should have the signature:
356 * `CTYPE_OUT reduce_fun(CTYPE_OUT val, long ix, CTYPE_OUT acc_val, long
357 * acc_ix)`
358 *
359 * Common usage:
360 *
361 * CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
362 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
363 * out_data[out_ix] = map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
364 * [](CTYPE_IN v) {
365 * // map operation on `v`, outputs `val`
366 * },
367 * [](CTYPE_OUT val, long ix, CTYPE_OUT acc_val, long acc_ix) {
368 * // reduce operation on `acc_val` and `acc_ix` using `val` and `ix`,
369 * // outputs {`acc_val`, `acc_ix`} pair
370 * in,
371 * dim_list,
372 * out_ix);
373 * }
374 */
375 template <
376 typename CTYPE_IN,
377 typename CTYPE_OUT,
378 typename MapOp,
379 typename ReduceOp>
map_reduce_over_dim(const MapOp & map_fun,const ReduceOp & reduce_fun,const exec_aten::Tensor & in,const exec_aten::optional<int64_t> & dim,const size_t out_ix)380 std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
381 const MapOp& map_fun,
382 const ReduceOp& reduce_fun,
383 const exec_aten::Tensor& in,
384 const exec_aten::optional<int64_t>& dim,
385 const size_t out_ix) {
386 if (dim.has_value()) {
387 if (in.dim() != 0) {
388 ET_CHECK_VALID_DIM(dim.value(), in.dim());
389 } else {
390 ET_CHECK(dim.value() == 0 || dim.value() == -1);
391 }
392 }
393
394 ET_CHECK_MSG(
395 out_ix < get_out_numel(in, dim),
396 "Out index %zd is out of bounds",
397 out_ix);
398
399 ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
400
401 const size_t init_index = get_init_index(in, dim, out_ix);
402
403 const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
404 CTYPE_OUT acc_val = map_fun(in_data[init_index]);
405 long acc_ix = 0;
406
407 if (in.numel() == 1) {
408 return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
409 }
410
411 apply_over_dim(
412 [&acc_val, &acc_ix, reduce_fun, map_fun, in_data](
413 const size_t in_ix, const size_t dim_ix) {
414 std::tuple<CTYPE_OUT, long> res =
415 reduce_fun(map_fun(in_data[in_ix]), dim_ix, acc_val, acc_ix);
416 acc_val = std::get<0>(res);
417 acc_ix = std::get<1>(res);
418 },
419 in,
420 dim,
421 out_ix,
422 1,
423 -1);
424
425 return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
426 }
427
428 /**
429 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
430 * for the output element at index `out_ix`, first applying the map `map_fun`
431 * to each element of `in`, which should have the signature:
432 * `CTYPE_OUT map_fun(CTYPE_IN v)`
433 * and then reducing using `reduce_fun`, which should have the signature:
434 * `CTYPE_OUT reduce_fun(CTYPE_OUT v, CTYPE_OUT acc)`
435 *
436 * Common usage:
437 *
438 * CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
439 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
440 * out_data[out_ix] = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
441 * [](CTYPE_IN v) {
442 * // map operation on `v`, outputs `outv`
443 * },
444 * [](CTYPE_OUT outv, CTYPE_OUT acc) {
445 * // reduce operation on `acc` using `v`, outputs `acc`
446 * in,
447 * dim_list,
448 * out_ix);
449 * }
450 */
451 template <
452 typename CTYPE_IN,
453 typename CTYPE_OUT,
454 typename MapOp,
455 typename ReduceOp>
map_reduce_over_dim_list(const MapOp & map_fun,const ReduceOp & reduce_fun,const exec_aten::Tensor & in,const exec_aten::optional<exec_aten::ArrayRef<int64_t>> & dim_list,const size_t out_ix)456 CTYPE_OUT map_reduce_over_dim_list(
457 const MapOp& map_fun,
458 const ReduceOp& reduce_fun,
459 const exec_aten::Tensor& in,
460 const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
461 const size_t out_ix) {
462 ET_CHECK(check_dim_list_is_valid(in, dim_list));
463
464 ET_CHECK_MSG(
465 out_ix < get_out_numel(in, dim_list),
466 "Out index %zd is out of bounds",
467 out_ix);
468
469 ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
470
471 const size_t init_index = get_init_index(in, dim_list, out_ix);
472
473 const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
474 CTYPE_OUT acc_val = map_fun(in_data[init_index]);
475
476 if (in.numel() == 1) {
477 return acc_val;
478 }
479
480 apply_over_dim_list(
481 [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
482 acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
483 },
484 in,
485 dim_list,
486 out_ix,
487 1,
488 -1);
489
490 return acc_val;
491 }
492
493 /**
494 * Useful to reduce a tensor `in` over a dimension `dim` for the output element
495 * at index `out_ix` using the reduce function `reduce_fun`, which should have
496 * the following signature:
497 * `CTYPE reduce_fun(CTYPE val, long ix, CTYPE acc_val, long acc_ix)`
498 *
499 * Common usage:
500 *
501 * CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
502 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
503 * out_data[out_ix] = reduce_over_dim<CTYPE>(
504 * [](CTYPE val, long ix, CTYPE acc_val, long acc_ix) {
505 * // reduce operation on `acc_val` and `acc_ix` using `val` and `ix`,
506 * // outputs {`acc_val`, `acc_ix`} pair
507 * },
508 * in,
509 * dim_list,
510 * out_ix);
511 * }
512 */
513 template <typename CTYPE, typename ReduceOp>
reduce_over_dim(const ReduceOp & reduce_fun,const exec_aten::Tensor & in,const exec_aten::optional<int64_t> & dim,const size_t out_ix)514 std::tuple<CTYPE, long> reduce_over_dim(
515 const ReduceOp& reduce_fun,
516 const exec_aten::Tensor& in,
517 const exec_aten::optional<int64_t>& dim,
518 const size_t out_ix) {
519 return map_reduce_over_dim<CTYPE, CTYPE>(
520 [](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
521 }
522
523 /**
524 * Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
525 * for the output element at index `out_ix` using the reduce function
526 * `reduce_fun`, which should have the following signature:
527 * `CTYPE reduce_fun(CTYPE v, CTYPE acc)`
528 *
529 * Common usage:
530 *
531 * CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
532 * for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
533 * out_data[out_ix] = reduce_over_dim_list<CTYPE>(
534 * [](CTYPE v, CTYPE acc) {
535 * // reduce operation on `acc` using `v`, outputs `acc`
536 * },
537 * in,
538 * dim_list,
539 * out_ix);
540 * }
541 */
542 template <typename CTYPE, typename ReduceOp>
reduce_over_dim_list(const ReduceOp & reduce_fun,const exec_aten::Tensor & in,const exec_aten::optional<exec_aten::ArrayRef<int64_t>> & dim_list,const size_t out_ix)543 CTYPE reduce_over_dim_list(
544 const ReduceOp& reduce_fun,
545 const exec_aten::Tensor& in,
546 const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
547 const size_t out_ix) {
548 return map_reduce_over_dim_list<CTYPE, CTYPE>(
549 [](CTYPE v) { return v; }, reduce_fun, in, dim_list, out_ix);
550 }
551
552 //
553 // Compute reduced out tensor size and dim
554 //
555
/**
 * NOTE(review): presumably writes the sizes of the output of reducing `in`
 * over `dim` (honoring `keepdim`) into `sizes_arr` and returns the number of
 * entries written -- confirm against the definition, which is not part of
 * this header.
 */
size_t compute_reduced_out_size(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    exec_aten::SizesType* sizes_arr);

/**
 * Dim-list overload of compute_reduced_out_size.
 * NOTE(review): presumably the same contract as the single-dim overload, for
 * a reduction over every dimension in `dim_list` -- confirm against the
 * definition.
 */
size_t compute_reduced_out_size(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    exec_aten::SizesType* sizes_arr);
567
compute_reduced_out_dim(const exec_aten::Tensor & in,const exec_aten::optional<int64_t> & dim,bool keepdim)568 inline ssize_t compute_reduced_out_dim(
569 const exec_aten::Tensor& in,
570 const exec_aten::optional<int64_t>& dim,
571 bool keepdim) {
572 return (
573 keepdim ? in.dim()
574 : dim.has_value() && in.dim() != 0 ? in.dim() - 1
575 : 0);
576 }
577
compute_reduced_out_dim(const exec_aten::Tensor & in,const exec_aten::optional<exec_aten::ArrayRef<int64_t>> & dim_list,bool keepdim)578 inline ssize_t compute_reduced_out_dim(
579 const exec_aten::Tensor& in,
580 const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
581 bool keepdim) {
582 return (
583 keepdim ? in.dim()
584 : dim_list.has_value() && dim_list.value().size() != 0 &&
585 in.dim() != 0
586
587 ? in.dim() - dim_list.value().size()
588 : 0);
589 }
590
591 //
592 // Resize out tensor of reduction op
593 //
594
/**
 * NOTE(review): presumably resizes `out` to the shape produced by reducing
 * `in` over `dim` (honoring `keepdim`) and reports failure through the
 * returned Error -- confirm against the definition, which is not part of
 * this header.
 */
Error resize_reduction_out(
    const exec_aten::Tensor& in,
    const exec_aten::optional<int64_t>& dim,
    bool keepdim,
    exec_aten::Tensor& out);

/**
 * Dim-list overload of resize_reduction_out.
 * NOTE(review): presumably the same contract as the single-dim overload, for
 * a reduction over every dimension in `dim_list` -- confirm against the
 * definition.
 */
Error resize_reduction_out(
    const exec_aten::Tensor& in,
    const exec_aten::optional<exec_aten::ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    exec_aten::Tensor& out);
606
607 #ifndef USE_ATEN_LIB
// Argument checkers for reduction kernels. They are declared only when
// USE_ATEN_LIB is not defined, and their definitions are not part of this
// header.
// NOTE(review): each presumably returns true when the given arguments are
// acceptable for the corresponding operator and false otherwise -- confirm
// the exact conditions against the definitions.

// NOTE(review): presumably validates a reduction of `in` over `dim_list`
// into `out` -- confirm.
bool check_reduction_args(
    const Tensor& in,
    const optional<ArrayRef<int64_t>>& dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out);

// NOTE(review): presumably the single-dim variant of check_reduction_args;
// `allow_empty_dim` defaults to false and its exact effect is not visible
// here -- confirm.
bool check_reduction_args_single_dim(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out,
    bool allow_empty_dim = false);

// NOTE(review): presumably validates the arguments of a mean-over-dims
// operator -- confirm.
bool check_mean_dim_args(
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out);

// NOTE(review): presumably validates the arguments of the amin/amax
// operators -- confirm.
bool check_amin_amax_args(
    const Tensor& in,
    ArrayRef<int64_t> dim_list,
    bool keepdim,
    Tensor& out);

// NOTE(review): presumably validates the arguments of the argmin/argmax
// operators -- confirm.
bool check_argmin_argmax_args(
    const Tensor& in,
    optional<int64_t> dim,
    bool keepdim,
    Tensor& out);

// NOTE(review): presumably validates the arguments of the min/max-over-dim
// operators, which produce both values (`max`) and indices (`max_indices`)
// -- confirm.
bool check_min_max_args(
    const Tensor& in,
    int64_t dim,
    bool keepdim,
    Tensor& max,
    Tensor& max_indices);

// NOTE(review): presumably validates the arguments of the prod operator --
// confirm.
bool check_prod_out_args(
    const Tensor& in,
    optional<ScalarType> dtype,
    Tensor& out);
653
654 #endif
655
656 } // namespace executor
657 } // namespace torch
658