#include <torch/nn/functional/conv.h>
#include <torch/nn/functional/padding.h>
#include <torch/nn/modules/conv.h>

#include <c10/util/irange.h>
#include <torch/enum.h>
#include <torch/expanding_array.h>
#include <torch/nn/init.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <cmath>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

namespace F = torch::nn::functional;

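// Maps a module-level conv padding mode to the corresponding F::pad mode.
// kZeros never reaches this helper: zero padding is handled by the
// convolution op itself, so the branches below cover only the explicit
// padding modes.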
static F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(
    torch::nn::detail::conv_padding_mode_t conv_padding_mode) {
  F::PadFuncOptions::mode_t pad_mode;
  if (std::holds_alternative<torch::enumtype::kReflect>(conv_padding_mode)) {
    pad_mode = torch::kReflect;
  } else if (std::holds_alternative<torch::enumtype::kReplicate>(
                 conv_padding_mode)) {
    pad_mode = torch::kReplicate;
  } else if (std::holds_alternative<torch::enumtype::kCircular>(
                 conv_padding_mode)) {
    pad_mode = torch::kCircular;
  } else {
    TORCH_CHECK(
        false,
        "Unsupported conv padding mode: ",
        torch::enumtype::get_enum_name(conv_padding_mode));
  }
  return pad_mode;
}

namespace torch {
namespace nn {
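// Each user-facing ConvNd module translates its typed options into the
// internal detail::ConvNdOptions consumed by the shared ConvNdImpl base,
// which owns parameter creation and reset().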
Conv1dImpl::Conv1dImpl(Conv1dOptions options_)
    : ConvNdImpl(detail::ConvNdOptions<1>(
                     /*in_channels=*/options_.in_channels(),
                     /*out_channels=*/options_.out_channels(),
                     /*kernel_size=*/options_.kernel_size())
                     .stride(options_.stride())
                     .padding(options_.padding())
                     .dilation(options_.dilation())
                     .transposed(false)
                     .output_padding(0)
                     .groups(options_.groups())
                     .bias(options_.bias())
                     .padding_mode(options_.padding_mode())) {}

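// For padding modes other than kZeros the input is padded explicitly with
// F::pad and the convolution then runs with zero padding; for kZeros the
// padding is fused into the conv1d call itself. A minimal usage sketch
// (values below are illustrative):
//
//   torch::nn::Conv1d conv(
//       torch::nn::Conv1dOptions(/*in_channels=*/16, /*out_channels=*/33,
//                                /*kernel_size=*/3)
//           .padding(1)
//           .padding_mode(torch::kReflect));
//   auto y = conv(torch::randn({20, 16, 50}));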
Tensor Conv1dImpl::forward(const Tensor& input) {
  if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
    return F::detail::conv1d(
        F::pad(
            input,
            F::PadFuncOptions(_reversed_padding_repeated_twice)
                .mode(_get_pad_mode_from_conv_padding_mode(
                    options.padding_mode()))),
        weight,
        bias,
        options.stride(),
        /*padding=*/0,
        options.dilation(),
        options.groups());
  }
  return F::detail::conv1d(
      input,
      weight,
      bias,
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}

Conv2dImpl::Conv2dImpl(Conv2dOptions options_)
    : ConvNdImpl(detail::ConvNdOptions<2>(
                     /*in_channels=*/options_.in_channels(),
                     /*out_channels=*/options_.out_channels(),
                     /*kernel_size=*/options_.kernel_size())
                     .stride(options_.stride())
                     .padding(options_.padding())
                     .dilation(options_.dilation())
                     .transposed(false)
                     .output_padding(0)
                     .groups(options_.groups())
                     .bias(options_.bias())
                     .padding_mode(options_.padding_mode())) {}

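// Like forward(), but takes the weight explicitly so the padding-aware
// convolution path can be reused with a weight tensor other than the
// module's own parameter.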
Tensor Conv2dImpl::_conv_forward(const Tensor& input, const Tensor& weight) {
  if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
    return F::detail::conv2d(
        F::pad(
            input,
            F::PadFuncOptions(_reversed_padding_repeated_twice)
                .mode(_get_pad_mode_from_conv_padding_mode(
                    options.padding_mode()))),
        weight,
        bias,
        options.stride(),
        /*padding=*/0,
        options.dilation(),
        options.groups());
  }
  return F::detail::conv2d(
      input,
      weight,
      bias,
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}

Tensor Conv2dImpl::forward(const Tensor& input) {
  return _conv_forward(input, weight);
}

Conv3dImpl::Conv3dImpl(Conv3dOptions options_)
    : ConvNdImpl(detail::ConvNdOptions<3>(
                     /*in_channels=*/options_.in_channels(),
                     /*out_channels=*/options_.out_channels(),
                     /*kernel_size=*/options_.kernel_size())
                     .stride(options_.stride())
                     .padding(options_.padding())
                     .dilation(options_.dilation())
                     .transposed(false)
                     .output_padding(0)
                     .groups(options_.groups())
                     .bias(options_.bias())
                     .padding_mode(options_.padding_mode())) {}

Tensor Conv3dImpl::forward(const Tensor& input) {
  if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
    return F::detail::conv3d(
        F::pad(
            input,
            F::PadFuncOptions(_reversed_padding_repeated_twice)
                .mode(_get_pad_mode_from_conv_padding_mode(
                    options.padding_mode()))),
        weight,
        bias,
        options.stride(),
        /*padding=*/0,
        options.dilation(),
        options.groups());
  }
  return F::detail::conv3d(
      input,
      weight,
      bias,
      options.stride(),
      options.padding(),
      options.dilation(),
      options.groups());
}

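// Explicit instantiations of the shared base for the supported
// dimensionalities.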
template class ConvNdImpl<1, Conv1dImpl>;
template class ConvNdImpl<2, Conv2dImpl>;
template class ConvNdImpl<3, Conv3dImpl>;

// ============================================================================

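// Computes the per-dimension output_padding for a transposed convolution.
// If no output_size is requested, the configured output_padding is returned
// unchanged. Otherwise, for each spatial dim (note this helper ignores
// dilation), the smallest producible output size is
//   min_size = (input_size - 1) * stride - 2 * padding + kernel_size
// and every size up to min_size + stride - 1 is reachable, so
//   output_padding = requested_size - min_size.
// For example, input_size=4, stride=2, padding=1, kernel_size=3 gives
// min_size=7; requesting an output size of 8 yields output_padding=1.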
template <size_t D, typename Derived>
std::vector<int64_t> ConvTransposeNdImpl<D, Derived>::_output_padding(
    const Tensor& input,
    const std::optional<at::IntArrayRef>& output_size,
    const ExpandingArray<D>& stride,
    const ExpandingArray<D>& padding,
    const ExpandingArray<D>& kernel_size) {
  std::vector<int64_t> ret;
  std::optional<at::IntArrayRef> output_size_ = output_size;

  if (output_size_ == std::nullopt) {
    ret = at::IntArrayRef(this->options.output_padding()).vec();
  } else {
    auto k = input.dim() - 2;
    if (output_size_.value().size() == static_cast<size_t>(k + 2)) {
      output_size_ = output_size_.value().slice(2);
    }
    if (output_size_.value().size() != static_cast<size_t>(k)) {
      TORCH_CHECK(
          false,
          "output_size must have ",
          k,
          " or ",
          k + 2,
          " elements (got ",
          output_size_.value().size(),
          ")");
    }

    std::vector<int64_t> min_sizes;
    std::vector<int64_t> max_sizes;
    for (const auto d : c10::irange(k)) {
      int64_t dim_size =
          ((input.sizes()[d + 2] - 1) * (*stride)[d] - 2 * (*padding)[d] +
           (*kernel_size)[d]);
      min_sizes.push_back(dim_size);
      max_sizes.push_back(min_sizes[d] + (*stride)[d] - 1);
    }

    for (const auto i : c10::irange(output_size_.value().size())) {
      int64_t size = output_size_.value()[i];
      int64_t min_size = min_sizes[i];
      int64_t max_size = max_sizes[i];
      if (size < min_size || size > max_size) {
        TORCH_CHECK(
            false,
            "requested an output size of ",
            output_size_.value(),
            ", but valid sizes range from ",
            min_sizes,
            " to ",
            max_sizes,
            " (for an input of ",
            input.sizes().slice(2),
            ")");
      }
    }

    for (const auto d : c10::irange(k)) {
      ret.push_back(output_size_.value()[d] - min_sizes[d]);
    }
  }
  return ret;
}

ConvTranspose1dImpl::ConvTranspose1dImpl(ConvTranspose1dOptions options_)
    : ConvTransposeNdImpl(detail::ConvNdOptions<1>(
                              /*in_channels=*/options_.in_channels(),
                              /*out_channels=*/options_.out_channels(),
                              /*kernel_size=*/options_.kernel_size())
                              .stride(options_.stride())
                              .padding(options_.padding())
                              .dilation(options_.dilation())
                              .transposed(true)
                              .output_padding(options_.output_padding())
                              .groups(options_.groups())
                              .bias(options_.bias())
                              .padding_mode(options_.padding_mode())) {}

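// Only kZeros is supported here: the explicit F::pad trick used by the
// forward convolutions has no transposed counterpart. A minimal usage
// sketch (values below are illustrative):
//
//   torch::nn::ConvTranspose1d deconv(
//       torch::nn::ConvTranspose1dOptions(16, 33, 3).stride(2));
//   auto y = deconv(torch::randn({20, 16, 50}));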
Tensor ConvTranspose1dImpl::forward(
    const Tensor& input,
    const std::optional<at::IntArrayRef>& output_size) {
  if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
    TORCH_CHECK(
        false, "Only `zeros` padding mode is supported for ConvTranspose1d");
  }

  const auto& pad = padding();
  std::vector<int64_t> output_padding = _output_padding(
      input, output_size, options.stride(), pad, options.kernel_size());

  return F::detail::conv_transpose1d(
      input,
      weight,
      bias,
      options.stride(),
      pad,
      output_padding,
      options.groups(),
      options.dilation());
}

ConvTranspose2dImpl::ConvTranspose2dImpl(ConvTranspose2dOptions options_)
    : ConvTransposeNdImpl(detail::ConvNdOptions<2>(
                              /*in_channels=*/options_.in_channels(),
                              /*out_channels=*/options_.out_channels(),
                              /*kernel_size=*/options_.kernel_size())
                              .stride(options_.stride())
                              .padding(options_.padding())
                              .dilation(options_.dilation())
                              .transposed(true)
                              .output_padding(options_.output_padding())
                              .groups(options_.groups())
                              .bias(options_.bias())
                              .padding_mode(options_.padding_mode())) {}

Tensor ConvTranspose2dImpl::forward(
    const Tensor& input,
    const std::optional<at::IntArrayRef>& output_size) {
  if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
    TORCH_CHECK(
        false, "Only `zeros` padding mode is supported for ConvTranspose2d");
  }

  const auto& pad = padding();
  std::vector<int64_t> output_padding = _output_padding(
      input, output_size, options.stride(), pad, options.kernel_size());

  return F::detail::conv_transpose2d(
      input,
      weight,
      bias,
      options.stride(),
      pad,
      output_padding,
      options.groups(),
      options.dilation());
}

ConvTranspose3dImpl::ConvTranspose3dImpl(ConvTranspose3dOptions options_)
    : ConvTransposeNdImpl(detail::ConvNdOptions<3>(
                              /*in_channels=*/options_.in_channels(),
                              /*out_channels=*/options_.out_channels(),
                              /*kernel_size=*/options_.kernel_size())
                              .stride(options_.stride())
                              .padding(options_.padding())
                              .dilation(options_.dilation())
                              .transposed(true)
                              .output_padding(options_.output_padding())
                              .groups(options_.groups())
                              .bias(options_.bias())
                              .padding_mode(options_.padding_mode())) {}

Tensor ConvTranspose3dImpl::forward(
    const Tensor& input,
    const std::optional<at::IntArrayRef>& output_size) {
  if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
    TORCH_CHECK(
        false, "Only `zeros` padding mode is supported for ConvTranspose3d");
  }

  const auto& pad = padding();
  std::vector<int64_t> output_padding = _output_padding(
      input, output_size, options.stride(), pad, options.kernel_size());

  return F::detail::conv_transpose3d(
      input,
      weight,
      bias,
      options.stride(),
      pad,
      output_padding,
      options.groups(),
      options.dilation());
}

template class ConvTransposeNdImpl<1, ConvTranspose1dImpl>;
template class ConvTransposeNdImpl<2, ConvTranspose2dImpl>;
template class ConvTransposeNdImpl<3, ConvTranspose3dImpl>;

} // namespace nn
} // namespace torch