// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/enum.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <string>
4 #include <variant>
5 
6 #include <ATen/core/Reduction.h>
7 #include <c10/util/Exception.h>
8 #include <torch/csrc/Export.h>
9 
// Declares the empty tag type `torch::enumtype::k<name>` and the exported
// constant `torch::k<name>` of that type. The constant's single definition is
// emitted by TORCH_ENUM_DEFINE(name).
//
// NOTE: The tag struct must have a *user-provided* default constructor,
// otherwise Clang 3.8 rejects the `const` object with:
//   error: default initialization of an object of const type 'const
//   enumtype::Enum1' without a user-provided default constructor
// A `= default` constructor would not count as user-provided here, which is
// why the constructor has an explicit empty body.
#define TORCH_ENUM_DECLARE(name)                      \
  namespace torch::enumtype {                         \
  struct k##name {                                    \
    k##name() {}                                      \
  };                                                  \
  } /* namespace torch::enumtype */                   \
  namespace torch {                                   \
  TORCH_API extern const enumtype::k##name k##name;   \
  }
27 
// Emits the out-of-line definition of the constant `torch::k<name>` that
// TORCH_ENUM_DECLARE(name) declared `extern`. Because that declaration gives
// the constant external linkage, expand this in only one translation unit per
// name.
#define TORCH_ENUM_DEFINE(name)          \
  namespace torch {                      \
  const enumtype::k##name k##name{};     \
  }
32 
// Expands to a const call operator that maps the tag `enumtype::k<name>` to
// the string "k<name>". It is meant to be expanded inside a visitor struct
// such as `_compute_enum_name`, so that `std::visit` can name whichever tag a
// variant holds. The tag argument carries no data and is only used for
// overload selection.
//
// The name is assembled at compile time by string-literal concatenation
// ("k" #name), rather than by building a temporary std::string and appending.
#define TORCH_ENUM_PRETTY_PRINT(name)                                         \
  std::string operator()(const enumtype::k##name& v [[maybe_unused]]) const { \
    return std::string("k" #name);                                            \
  }
38 
39 // NOTE: Backstory on why we need the following two macros:
40 //
41 // Consider the following options class:
42 //
43 // ```
44 // struct TORCH_API SomeOptions {
//   typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
//       reduction_t;
//   SomeOptions(reduction_t reduction = torch::kMean) : reduction_(reduction) {}
48 //
49 //   TORCH_ARG(reduction_t, reduction);
50 // };
51 // ```
52 //
53 // and the functional that uses it:
54 //
55 // ```
56 // Tensor some_functional(
57 //     const Tensor& input,
58 //     SomeOptions options = {}) {
59 //   ...
60 // }
61 // ```
62 //
63 // Normally, we would expect this to work:
64 //
65 // `F::some_functional(input, torch::kNone)`
66 //
// However, it fails to compile with the following error instead:
68 //
69 // ```
70 // error: could not convert `torch::kNone` from `const torch::enumtype::kNone`
71 // to `torch::nn::SomeOptions`
72 // ```
73 //
74 // To get around this problem, we explicitly provide the following constructors
75 // for `SomeOptions`:
76 //
77 // ```
78 // SomeOptions(torch::enumtype::kNone reduction) : reduction_(torch::kNone) {}
79 // SomeOptions(torch::enumtype::kMean reduction) : reduction_(torch::kMean) {}
80 // SomeOptions(torch::enumtype::kSum reduction) : reduction_(torch::kSum) {}
81 // ```
82 //
83 // so that the conversion from `torch::kNone` to `SomeOptions` would work.
84 //
// Note that we also provide the defaulted default constructor
// `SomeOptions() = default;`, so that `SomeOptions options = {}` can work.
// Expands, inside an options struct OPTIONS_NAME that has a member
// `ARG_NAME##_`, to:
//   - a defaulted default constructor, so `OPTIONS_NAME options = {}` works;
//   - one non-explicit constructor per tag type TYPE1..TYPE3, each storing the
//     matching constant `torch::TYPEn` into `ARG_NAME##_`. This lets a bare
//     constant such as `torch::kNone` convert implicitly to OPTIONS_NAME.
// The tag parameter only selects the overload (the tag structs declared by
// TORCH_ENUM_DECLARE are empty), so it is marked [[maybe_unused]] to keep
// -Wunused-parameter quiet in every options struct that expands this macro.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG3(                         \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3)                 \
  OPTIONS_NAME() = default;                                      \
  OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME [[maybe_unused]]) \
      : ARG_NAME##_(torch::TYPE1) {}                             \
  OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME [[maybe_unused]]) \
      : ARG_NAME##_(torch::TYPE2) {}                             \
  OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME [[maybe_unused]]) \
      : ARG_NAME##_(torch::TYPE3) {}
93 
// Four-alternative counterpart of TORCH_OPTIONS_CTOR_VARIANT_ARG3. Inside an
// options struct OPTIONS_NAME with a member `ARG_NAME##_`, it expands to:
//   - a defaulted default constructor, so `OPTIONS_NAME options = {}` works;
//   - one non-explicit constructor per tag type TYPE1..TYPE4, each storing the
//     matching constant `torch::TYPEn` into `ARG_NAME##_`.
// The tag parameter only selects the overload, so it is marked
// [[maybe_unused]] to keep -Wunused-parameter quiet in every user.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG4(                         \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3, TYPE4)          \
  OPTIONS_NAME() = default;                                      \
  OPTIONS_NAME(torch::enumtype::TYPE1 ARG_NAME [[maybe_unused]]) \
      : ARG_NAME##_(torch::TYPE1) {}                             \
  OPTIONS_NAME(torch::enumtype::TYPE2 ARG_NAME [[maybe_unused]]) \
      : ARG_NAME##_(torch::TYPE2) {}                             \
  OPTIONS_NAME(torch::enumtype::TYPE3 ARG_NAME [[maybe_unused]]) \
      : ARG_NAME##_(torch::TYPE3) {}                             \
  OPTIONS_NAME(torch::enumtype::TYPE4 ARG_NAME [[maybe_unused]]) \
      : ARG_NAME##_(torch::TYPE4) {}
101 
102 TORCH_ENUM_DECLARE(Linear)
TORCH_ENUM_DECLARE(Conv1D)103 TORCH_ENUM_DECLARE(Conv1D)
104 TORCH_ENUM_DECLARE(Conv2D)
105 TORCH_ENUM_DECLARE(Conv3D)
106 TORCH_ENUM_DECLARE(ConvTranspose1D)
107 TORCH_ENUM_DECLARE(ConvTranspose2D)
108 TORCH_ENUM_DECLARE(ConvTranspose3D)
109 TORCH_ENUM_DECLARE(Sigmoid)
110 TORCH_ENUM_DECLARE(Tanh)
111 TORCH_ENUM_DECLARE(ReLU)
112 TORCH_ENUM_DECLARE(GELU)
113 TORCH_ENUM_DECLARE(SiLU)
114 TORCH_ENUM_DECLARE(Mish)
115 TORCH_ENUM_DECLARE(LeakyReLU)
116 TORCH_ENUM_DECLARE(FanIn)
117 TORCH_ENUM_DECLARE(FanOut)
118 TORCH_ENUM_DECLARE(Constant)
119 TORCH_ENUM_DECLARE(Reflect)
120 TORCH_ENUM_DECLARE(Replicate)
121 TORCH_ENUM_DECLARE(Circular)
122 TORCH_ENUM_DECLARE(Nearest)
123 TORCH_ENUM_DECLARE(Bilinear)
124 TORCH_ENUM_DECLARE(Bicubic)
125 TORCH_ENUM_DECLARE(Trilinear)
126 TORCH_ENUM_DECLARE(Area)
127 TORCH_ENUM_DECLARE(NearestExact)
128 TORCH_ENUM_DECLARE(Sum)
129 TORCH_ENUM_DECLARE(Mean)
130 TORCH_ENUM_DECLARE(Max)
131 TORCH_ENUM_DECLARE(None)
132 TORCH_ENUM_DECLARE(BatchMean)
133 TORCH_ENUM_DECLARE(Zeros)
134 TORCH_ENUM_DECLARE(Border)
135 TORCH_ENUM_DECLARE(Reflection)
136 TORCH_ENUM_DECLARE(RNN_TANH)
137 TORCH_ENUM_DECLARE(RNN_RELU)
138 TORCH_ENUM_DECLARE(LSTM)
139 TORCH_ENUM_DECLARE(GRU)
140 TORCH_ENUM_DECLARE(Valid)
141 TORCH_ENUM_DECLARE(Same)
142 
143 namespace torch {
144 namespace enumtype {
145 
146 struct _compute_enum_name {
147   TORCH_ENUM_PRETTY_PRINT(Linear)
148   TORCH_ENUM_PRETTY_PRINT(Conv1D)
149   TORCH_ENUM_PRETTY_PRINT(Conv2D)
150   TORCH_ENUM_PRETTY_PRINT(Conv3D)
151   TORCH_ENUM_PRETTY_PRINT(ConvTranspose1D)
152   TORCH_ENUM_PRETTY_PRINT(ConvTranspose2D)
153   TORCH_ENUM_PRETTY_PRINT(ConvTranspose3D)
154   TORCH_ENUM_PRETTY_PRINT(Sigmoid)
155   TORCH_ENUM_PRETTY_PRINT(Tanh)
156   TORCH_ENUM_PRETTY_PRINT(ReLU)
157   TORCH_ENUM_PRETTY_PRINT(GELU)
158   TORCH_ENUM_PRETTY_PRINT(SiLU)
159   TORCH_ENUM_PRETTY_PRINT(Mish)
160   TORCH_ENUM_PRETTY_PRINT(LeakyReLU)
161   TORCH_ENUM_PRETTY_PRINT(FanIn)
162   TORCH_ENUM_PRETTY_PRINT(FanOut)
163   TORCH_ENUM_PRETTY_PRINT(Constant)
164   TORCH_ENUM_PRETTY_PRINT(Reflect)
165   TORCH_ENUM_PRETTY_PRINT(Replicate)
166   TORCH_ENUM_PRETTY_PRINT(Circular)
167   TORCH_ENUM_PRETTY_PRINT(Nearest)
168   TORCH_ENUM_PRETTY_PRINT(Bilinear)
169   TORCH_ENUM_PRETTY_PRINT(Bicubic)
170   TORCH_ENUM_PRETTY_PRINT(Trilinear)
171   TORCH_ENUM_PRETTY_PRINT(Area)
172   TORCH_ENUM_PRETTY_PRINT(NearestExact)
173   TORCH_ENUM_PRETTY_PRINT(Sum)
174   TORCH_ENUM_PRETTY_PRINT(Mean)
175   TORCH_ENUM_PRETTY_PRINT(Max)
176   TORCH_ENUM_PRETTY_PRINT(None)
177   TORCH_ENUM_PRETTY_PRINT(BatchMean)
178   TORCH_ENUM_PRETTY_PRINT(Zeros)
179   TORCH_ENUM_PRETTY_PRINT(Border)
180   TORCH_ENUM_PRETTY_PRINT(Reflection)
181   TORCH_ENUM_PRETTY_PRINT(RNN_TANH)
182   TORCH_ENUM_PRETTY_PRINT(RNN_RELU)
183   TORCH_ENUM_PRETTY_PRINT(LSTM)
184   TORCH_ENUM_PRETTY_PRINT(GRU)
185   TORCH_ENUM_PRETTY_PRINT(Valid)
186   TORCH_ENUM_PRETTY_PRINT(Same)
187 };
188 
189 template <typename V>
190 std::string get_enum_name(V variant_enum) {
191   return std::visit(enumtype::_compute_enum_name{}, variant_enum);
192 }
193 
194 template <typename V>
195 at::Reduction::Reduction reduction_get_enum(V variant_enum) {
196   if (std::holds_alternative<enumtype::kNone>(variant_enum)) {
197     return at::Reduction::None;
198   } else if (std::holds_alternative<enumtype::kMean>(variant_enum)) {
199     return at::Reduction::Mean;
200   } else if (std::holds_alternative<enumtype::kSum>(variant_enum)) {
201     return at::Reduction::Sum;
202   } else {
203     TORCH_CHECK(
204         false,
205         get_enum_name(variant_enum),
206         " is not a valid value for reduction");
207     return at::Reduction::END;
208   }
209 }
210 
211 } // namespace enumtype
212 } // namespace torch
213