1 #pragma once
2
3 #include <c10/core/Scalar.h>
4
5 namespace at::native {
6
/// Kinds of reduction that the string-to-enum helpers in this header can
/// produce. Enumerator order (and therefore the implicit underlying values
/// 0..4) is part of the definition and must not be reordered casually.
enum class ReductionType {
  MAX,   // selected by "max" / "amax"
  MEAN,  // selected by "mean"
  MIN,   // selected by "min" / "amin"
  SUM,   // selected by "sum" (or legacy "add")
  PROD   // selected by "prod" (or legacy "multiply")
};
8
get_reduction_enum(const c10::string_view & reduce)9 inline ReductionType get_reduction_enum(const c10::string_view& reduce) {
10 if (reduce == "max" || reduce == "amax") {
11 return ReductionType::MAX;
12 } else if (reduce == "mean") {
13 return ReductionType::MEAN;
14 } else if (reduce == "min" || reduce == "amin") {
15 return ReductionType::MIN;
16 } else if (reduce == "sum") {
17 return ReductionType::SUM;
18 } else if (reduce == "prod") {
19 return ReductionType::PROD;
20 } else {
21 TORCH_CHECK(false, "reduce argument must be either sum, prod, mean, amax or amin, got ", reduce);
22 }
23 }
24
25 // used for `scatter_reduce`, old options for BC.
get_operator_enum(const c10::string_view reduce,bool use_new_options)26 inline ReductionType get_operator_enum(const c10::string_view reduce, bool use_new_options) {
27 if (use_new_options) {
28 return get_reduction_enum(reduce);
29 } else {
30 if (reduce == "add") {
31 return ReductionType::SUM;
32 } else if (reduce == "multiply") {
33 return ReductionType::PROD;
34 } else {
35 TORCH_CHECK(false, "reduce argument must be either add or multiply.")
36 }
37 }
38 }
39
40 } // at::native
41