1 #pragma once
2
3 #include <string>
4 #include <variant>
5
6 #include <ATen/core/Reduction.h>
7 #include <c10/util/Exception.h>
8 #include <torch/csrc/Export.h>
9
// Declares an empty tag type `torch::enumtype::k<name>` together with an
// `extern` constant `torch::k<name>` of that type. The out-of-line definition
// of the constant is expected to come from TORCH_ENUM_DEFINE(name) in a single
// source file.
#define TORCH_ENUM_DECLARE(name)                                         \
  namespace torch::enumtype {                                            \
  /*                                                                     \
    NOTE: Every tag struct gets a user-provided default constructor;     \
    without one, Clang 3.8 rejects the `const` object declared below:    \
    ```                                                                  \
    error: default initialization of an object of const type 'const     \
    enumtype::Enum1' without a user-provided default constructor         \
    ```                                                                  \
  */                                                                     \
  struct k##name {                                                       \
    k##name() {}                                                         \
  };                                                                     \
  } /* namespace torch::enumtype */                                      \
  namespace torch {                                                      \
  TORCH_API extern const enumtype::k##name k##name;                      \
  }
27
// Emits the definition of the constant `torch::k<name>` that
// TORCH_ENUM_DECLARE(name) declares `extern`. Brace-initializing the object
// value-initializes it, which invokes the tag type's user-provided default
// constructor (the same constructor plain default-initialization would run).
#define TORCH_ENUM_DEFINE(name)              \
  namespace torch {                          \
  const enumtype::k##name k##name{};         \
  }
32
// Expands, inside a visitor struct, to a const call operator that maps the tag
// type `enumtype::k<name>` to the string "k<name>". The tag argument only
// selects the overload; its value is never read.
#define TORCH_ENUM_PRETTY_PRINT(name)                          \
  std::string operator()(                                      \
      [[maybe_unused]] const enumtype::k##name& v) const {     \
    return std::string("k" #name);                             \
  }
38
39 // NOTE: Backstory on why we need the following two macros:
40 //
41 // Consider the following options class:
42 //
43 // ```
// struct TORCH_API SomeOptions {
//   typedef std::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>
//       reduction_t;
//   SomeOptions(reduction_t reduction = torch::kMean) : reduction_(reduction) {}
//
//   TORCH_ARG(reduction_t, reduction);
// };
51 // ```
52 //
53 // and the functional that uses it:
54 //
55 // ```
56 // Tensor some_functional(
57 // const Tensor& input,
58 // SomeOptions options = {}) {
59 // ...
60 // }
61 // ```
62 //
63 // Normally, we would expect this to work:
64 //
65 // `F::some_functional(input, torch::kNone)`
66 //
67 // However, it throws the following error instead:
68 //
69 // ```
70 // error: could not convert `torch::kNone` from `const torch::enumtype::kNone`
71 // to `torch::nn::SomeOptions`
72 // ```
73 //
74 // To get around this problem, we explicitly provide the following constructors
75 // for `SomeOptions`:
76 //
77 // ```
78 // SomeOptions(torch::enumtype::kNone reduction) : reduction_(torch::kNone) {}
79 // SomeOptions(torch::enumtype::kMean reduction) : reduction_(torch::kMean) {}
80 // SomeOptions(torch::enumtype::kSum reduction) : reduction_(torch::kSum) {}
81 // ```
82 //
83 // so that the conversion from `torch::kNone` to `SomeOptions` would work.
84 //
// Note that we also provide a default constructor (`SomeOptions() = default;`),
// so that `SomeOptions options = {}` can work.
// Expands, inside the body of an options class, to a defaulted default
// constructor plus one implicit converting constructor per listed enum tag.
// Each converting constructor sets the `ARG_NAME##_` member to the matching
// `torch::TYPEn` constant, so a bare tag such as `torch::kMean` converts to
// OPTIONS_NAME. The constructors are intentionally non-explicit; implicit
// conversion is the point.
// The tag parameter is never read; only its type selects the overload. It is
// marked [[maybe_unused]] to avoid -Wunused-parameter, matching the tag
// parameter in TORCH_ENUM_PRETTY_PRINT.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG3(                          \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3)                  \
  OPTIONS_NAME() = default;                                       \
  OPTIONS_NAME([[maybe_unused]] torch::enumtype::TYPE1 ARG_NAME)  \
      : ARG_NAME##_(torch::TYPE1) {}                              \
  OPTIONS_NAME([[maybe_unused]] torch::enumtype::TYPE2 ARG_NAME)  \
      : ARG_NAME##_(torch::TYPE2) {}                              \
  OPTIONS_NAME([[maybe_unused]] torch::enumtype::TYPE3 ARG_NAME)  \
      : ARG_NAME##_(torch::TYPE3) {}
93
// Four-tag variant of TORCH_OPTIONS_CTOR_VARIANT_ARG3. It expands, inside the
// body of an options class, to a defaulted default constructor plus one
// implicit converting constructor per listed enum tag. Each converting
// constructor sets the `ARG_NAME##_` member to the matching `torch::TYPEn`
// constant.
// The tag parameter is never read; only its type selects the overload. It is
// marked [[maybe_unused]] to avoid -Wunused-parameter, matching the tag
// parameter in TORCH_ENUM_PRETTY_PRINT.
#define TORCH_OPTIONS_CTOR_VARIANT_ARG4(                          \
    OPTIONS_NAME, ARG_NAME, TYPE1, TYPE2, TYPE3, TYPE4)           \
  OPTIONS_NAME() = default;                                       \
  OPTIONS_NAME([[maybe_unused]] torch::enumtype::TYPE1 ARG_NAME)  \
      : ARG_NAME##_(torch::TYPE1) {}                              \
  OPTIONS_NAME([[maybe_unused]] torch::enumtype::TYPE2 ARG_NAME)  \
      : ARG_NAME##_(torch::TYPE2) {}                              \
  OPTIONS_NAME([[maybe_unused]] torch::enumtype::TYPE3 ARG_NAME)  \
      : ARG_NAME##_(torch::TYPE3) {}                              \
  OPTIONS_NAME([[maybe_unused]] torch::enumtype::TYPE4 ARG_NAME)  \
      : ARG_NAME##_(torch::TYPE4) {}
101
102 TORCH_ENUM_DECLARE(Linear)
TORCH_ENUM_DECLARE(Conv1D)103 TORCH_ENUM_DECLARE(Conv1D)
104 TORCH_ENUM_DECLARE(Conv2D)
105 TORCH_ENUM_DECLARE(Conv3D)
106 TORCH_ENUM_DECLARE(ConvTranspose1D)
107 TORCH_ENUM_DECLARE(ConvTranspose2D)
108 TORCH_ENUM_DECLARE(ConvTranspose3D)
109 TORCH_ENUM_DECLARE(Sigmoid)
110 TORCH_ENUM_DECLARE(Tanh)
111 TORCH_ENUM_DECLARE(ReLU)
112 TORCH_ENUM_DECLARE(GELU)
113 TORCH_ENUM_DECLARE(SiLU)
114 TORCH_ENUM_DECLARE(Mish)
115 TORCH_ENUM_DECLARE(LeakyReLU)
116 TORCH_ENUM_DECLARE(FanIn)
117 TORCH_ENUM_DECLARE(FanOut)
118 TORCH_ENUM_DECLARE(Constant)
119 TORCH_ENUM_DECLARE(Reflect)
120 TORCH_ENUM_DECLARE(Replicate)
121 TORCH_ENUM_DECLARE(Circular)
122 TORCH_ENUM_DECLARE(Nearest)
123 TORCH_ENUM_DECLARE(Bilinear)
124 TORCH_ENUM_DECLARE(Bicubic)
125 TORCH_ENUM_DECLARE(Trilinear)
126 TORCH_ENUM_DECLARE(Area)
127 TORCH_ENUM_DECLARE(NearestExact)
128 TORCH_ENUM_DECLARE(Sum)
129 TORCH_ENUM_DECLARE(Mean)
130 TORCH_ENUM_DECLARE(Max)
131 TORCH_ENUM_DECLARE(None)
132 TORCH_ENUM_DECLARE(BatchMean)
133 TORCH_ENUM_DECLARE(Zeros)
134 TORCH_ENUM_DECLARE(Border)
135 TORCH_ENUM_DECLARE(Reflection)
136 TORCH_ENUM_DECLARE(RNN_TANH)
137 TORCH_ENUM_DECLARE(RNN_RELU)
138 TORCH_ENUM_DECLARE(LSTM)
139 TORCH_ENUM_DECLARE(GRU)
140 TORCH_ENUM_DECLARE(Valid)
141 TORCH_ENUM_DECLARE(Same)
142
143 namespace torch {
144 namespace enumtype {
145
146 struct _compute_enum_name {
147 TORCH_ENUM_PRETTY_PRINT(Linear)
148 TORCH_ENUM_PRETTY_PRINT(Conv1D)
149 TORCH_ENUM_PRETTY_PRINT(Conv2D)
150 TORCH_ENUM_PRETTY_PRINT(Conv3D)
151 TORCH_ENUM_PRETTY_PRINT(ConvTranspose1D)
152 TORCH_ENUM_PRETTY_PRINT(ConvTranspose2D)
153 TORCH_ENUM_PRETTY_PRINT(ConvTranspose3D)
154 TORCH_ENUM_PRETTY_PRINT(Sigmoid)
155 TORCH_ENUM_PRETTY_PRINT(Tanh)
156 TORCH_ENUM_PRETTY_PRINT(ReLU)
157 TORCH_ENUM_PRETTY_PRINT(GELU)
158 TORCH_ENUM_PRETTY_PRINT(SiLU)
159 TORCH_ENUM_PRETTY_PRINT(Mish)
160 TORCH_ENUM_PRETTY_PRINT(LeakyReLU)
161 TORCH_ENUM_PRETTY_PRINT(FanIn)
162 TORCH_ENUM_PRETTY_PRINT(FanOut)
163 TORCH_ENUM_PRETTY_PRINT(Constant)
164 TORCH_ENUM_PRETTY_PRINT(Reflect)
165 TORCH_ENUM_PRETTY_PRINT(Replicate)
166 TORCH_ENUM_PRETTY_PRINT(Circular)
167 TORCH_ENUM_PRETTY_PRINT(Nearest)
168 TORCH_ENUM_PRETTY_PRINT(Bilinear)
169 TORCH_ENUM_PRETTY_PRINT(Bicubic)
170 TORCH_ENUM_PRETTY_PRINT(Trilinear)
171 TORCH_ENUM_PRETTY_PRINT(Area)
172 TORCH_ENUM_PRETTY_PRINT(NearestExact)
173 TORCH_ENUM_PRETTY_PRINT(Sum)
174 TORCH_ENUM_PRETTY_PRINT(Mean)
175 TORCH_ENUM_PRETTY_PRINT(Max)
176 TORCH_ENUM_PRETTY_PRINT(None)
177 TORCH_ENUM_PRETTY_PRINT(BatchMean)
178 TORCH_ENUM_PRETTY_PRINT(Zeros)
179 TORCH_ENUM_PRETTY_PRINT(Border)
180 TORCH_ENUM_PRETTY_PRINT(Reflection)
181 TORCH_ENUM_PRETTY_PRINT(RNN_TANH)
182 TORCH_ENUM_PRETTY_PRINT(RNN_RELU)
183 TORCH_ENUM_PRETTY_PRINT(LSTM)
184 TORCH_ENUM_PRETTY_PRINT(GRU)
185 TORCH_ENUM_PRETTY_PRINT(Valid)
186 TORCH_ENUM_PRETTY_PRINT(Same)
187 };
188
189 template <typename V>
190 std::string get_enum_name(V variant_enum) {
191 return std::visit(enumtype::_compute_enum_name{}, variant_enum);
192 }
193
194 template <typename V>
195 at::Reduction::Reduction reduction_get_enum(V variant_enum) {
196 if (std::holds_alternative<enumtype::kNone>(variant_enum)) {
197 return at::Reduction::None;
198 } else if (std::holds_alternative<enumtype::kMean>(variant_enum)) {
199 return at::Reduction::Mean;
200 } else if (std::holds_alternative<enumtype::kSum>(variant_enum)) {
201 return at::Reduction::Sum;
202 } else {
203 TORCH_CHECK(
204 false,
205 get_enum_name(variant_enum),
206 " is not a valid value for reduction");
207 return at::Reduction::END;
208 }
209 }
210
211 } // namespace enumtype
212 } // namespace torch
213