xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ReductionType.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Scalar.h>
4 
5 namespace at::native {
6 
// Kinds of reduction that the string-based `reduce` options in this header
// resolve to. Enumerator order is kept stable so the underlying integer
// values do not change.
enum class ReductionType {
  MAX,   // chosen by "max" / "amax"
  MEAN,  // chosen by "mean"
  MIN,   // chosen by "min" / "amin"
  SUM,   // chosen by "sum" (or legacy "add")
  PROD   // chosen by "prod" (or legacy "multiply")
};
8 
get_reduction_enum(const c10::string_view & reduce)9 inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
10   if (reduce == "max" || reduce == "amax") {
11     return ReductionType::MAX;
12   } else if (reduce == "mean") {
13     return ReductionType::MEAN;
14   } else if (reduce == "min" || reduce == "amin") {
15     return ReductionType::MIN;
16   } else if (reduce == "sum") {
17     return ReductionType::SUM;
18   } else if (reduce == "prod") {
19     return ReductionType::PROD;
20   } else {
21     TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
22   }
23 }
24 
25 // used for `scatter_reduce`, old options for BC.
get_operator_enum(const c10::string_view reduce,bool use_new_options)26 inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
27   if (use_new_options) {
28     return get_reduction_enum(reduce);
29   } else {
30     if (reduce == "add") {
31       return ReductionType::SUM;
32     } else if (reduce == "multiply") {
33       return ReductionType::PROD;
34     } else {
35       TORCH_CHECK(false, "reduce argument must be either add or multiply.")
36     }
37   }
38 }
39 
40 } // at::native
41