xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/conv.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/functional/conv.h>
2 #include <torch/nn/functional/padding.h>
3 #include <torch/nn/modules/conv.h>
4 
5 #include <c10/util/irange.h>
6 #include <torch/enum.h>
7 #include <torch/expanding_array.h>
8 #include <torch/nn/init.h>
9 #include <torch/types.h>
10 #include <torch/utils.h>
11 
12 #include <cmath>
13 #include <cstdint>
14 #include <functional>
15 #include <utility>
16 #include <vector>
17 
18 namespace F = torch::nn::functional;
19 
_get_pad_mode_from_conv_padding_mode(torch::nn::detail::conv_padding_mode_t conv_padding_mode)20 static F::PadFuncOptions::mode_t _get_pad_mode_from_conv_padding_mode(
21     torch::nn::detail::conv_padding_mode_t conv_padding_mode) {
22   F::PadFuncOptions::mode_t pad_mode;
23   if (std::holds_alternative<torch::enumtype::kReflect>(conv_padding_mode)) {
24     pad_mode = torch::kReflect;
25   } else if (std::holds_alternative<torch::enumtype::kReplicate>(
26                  conv_padding_mode)) {
27     pad_mode = torch::kReplicate;
28   } else if (std::holds_alternative<torch::enumtype::kCircular>(
29                  conv_padding_mode)) {
30     pad_mode = torch::kCircular;
31   } else {
32     TORCH_CHECK(
33         false,
34         "Unsupported conv padding mode: ",
35         torch::enumtype::get_enum_name(conv_padding_mode));
36   }
37   return pad_mode;
38 }
39 
40 namespace torch {
41 namespace nn {
// Constructs a 1-D convolution module. Translates the user-facing
// Conv1dOptions into the internal ConvNdOptions<1> consumed by the shared
// ConvNdImpl base; a regular (non-transposed) convolution is fixed to
// transposed=false with an output_padding of 0.
Conv1dImpl::Conv1dImpl(Conv1dOptions options_)
    : ConvNdImpl(detail::ConvNdOptions<1>(
                     /*in_channels=*/options_.in_channels(),
                     /*out_channels=*/options_.out_channels(),
                     /*kernel_size=*/options_.kernel_size())
                     .stride(options_.stride())
                     .padding(options_.padding())
                     .dilation(options_.dilation())
                     .transposed(false)
                     .output_padding(0)
                     .groups(options_.groups())
                     .bias(options_.bias())
                     .padding_mode(options_.padding_mode())) {}
55 
// Applies the 1-D convolution. conv1d itself only supports zero padding, so
// for any other padding mode the input is explicitly padded first via F::pad
// and the convolution then runs with padding disabled.
Tensor Conv1dImpl::forward(const Tensor& input) {
  const bool zero_padded =
      std::get_if<enumtype::kZeros>(&options.padding_mode()) != nullptr;
  if (zero_padded) {
    return F::detail::conv1d(
        input,
        weight,
        bias,
        options.stride(),
        options.padding(),
        options.dilation(),
        options.groups());
  }
  // Non-"zeros" mode: pre-pad manually, then convolve with zero padding.
  const auto padded_input = F::pad(
      input,
      F::PadFuncOptions(_reversed_padding_repeated_twice)
          .mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode())));
  return F::detail::conv1d(
      padded_input,
      weight,
      bias,
      options.stride(),
      /*padding=*/0,
      options.dilation(),
      options.groups());
}
80 
// Constructs a 2-D convolution module. Translates the user-facing
// Conv2dOptions into the internal ConvNdOptions<2> consumed by the shared
// ConvNdImpl base; a regular (non-transposed) convolution is fixed to
// transposed=false with an output_padding of 0.
Conv2dImpl::Conv2dImpl(Conv2dOptions options_)
    : ConvNdImpl(detail::ConvNdOptions<2>(
                     /*in_channels=*/options_.in_channels(),
                     /*out_channels=*/options_.out_channels(),
                     /*kernel_size=*/options_.kernel_size())
                     .stride(options_.stride())
                     .padding(options_.padding())
                     .dilation(options_.dilation())
                     .transposed(false)
                     .output_padding(0)
                     .groups(options_.groups())
                     .bias(options_.bias())
                     .padding_mode(options_.padding_mode())) {}
94 
// Runs the 2-D convolution of `input` against an explicit `weight` tensor
// (used by forward(), and overridable weights in subclasses/tests). conv2d
// only supports zero padding, so other padding modes are implemented by
// pre-padding the input with F::pad and convolving with padding disabled.
Tensor Conv2dImpl::_conv_forward(const Tensor& input, const Tensor& weight) {
  const bool zero_padded =
      std::get_if<enumtype::kZeros>(&options.padding_mode()) != nullptr;
  if (zero_padded) {
    return F::detail::conv2d(
        input,
        weight,
        bias,
        options.stride(),
        options.padding(),
        options.dilation(),
        options.groups());
  }
  // Non-"zeros" mode: pre-pad manually, then convolve with zero padding.
  const auto padded_input = F::pad(
      input,
      F::PadFuncOptions(_reversed_padding_repeated_twice)
          .mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode())));
  return F::detail::conv2d(
      padded_input,
      weight,
      bias,
      options.stride(),
      /*padding=*/0,
      options.dilation(),
      options.groups());
}
119 
// Applies the 2-D convolution using the module's own weight tensor;
// all padding-mode handling lives in _conv_forward.
Tensor Conv2dImpl::forward(const Tensor& input) {
  return _conv_forward(input, weight);
}
123 
// Constructs a 3-D convolution module. Translates the user-facing
// Conv3dOptions into the internal ConvNdOptions<3> consumed by the shared
// ConvNdImpl base; a regular (non-transposed) convolution is fixed to
// transposed=false with an output_padding of 0.
Conv3dImpl::Conv3dImpl(Conv3dOptions options_)
    : ConvNdImpl(detail::ConvNdOptions<3>(
                     /*in_channels=*/options_.in_channels(),
                     /*out_channels=*/options_.out_channels(),
                     /*kernel_size=*/options_.kernel_size())
                     .stride(options_.stride())
                     .padding(options_.padding())
                     .dilation(options_.dilation())
                     .transposed(false)
                     .output_padding(0)
                     .groups(options_.groups())
                     .bias(options_.bias())
                     .padding_mode(options_.padding_mode())) {}
137 
// Applies the 3-D convolution. conv3d itself only supports zero padding, so
// for any other padding mode the input is explicitly padded first via F::pad
// and the convolution then runs with padding disabled.
Tensor Conv3dImpl::forward(const Tensor& input) {
  const bool zero_padded =
      std::get_if<enumtype::kZeros>(&options.padding_mode()) != nullptr;
  if (zero_padded) {
    return F::detail::conv3d(
        input,
        weight,
        bias,
        options.stride(),
        options.padding(),
        options.dilation(),
        options.groups());
  }
  // Non-"zeros" mode: pre-pad manually, then convolve with zero padding.
  const auto padded_input = F::pad(
      input,
      F::PadFuncOptions(_reversed_padding_repeated_twice)
          .mode(_get_pad_mode_from_conv_padding_mode(options.padding_mode())));
  return F::detail::conv3d(
      padded_input,
      weight,
      bias,
      options.stride(),
      /*padding=*/0,
      options.dilation(),
      options.groups());
}
162 
// Explicitly instantiate the shared N-d convolution base for the three
// supported spatial dimensionalities so their definitions live in this TU.
template class ConvNdImpl<1, Conv1dImpl>;
template class ConvNdImpl<2, Conv2dImpl>;
template class ConvNdImpl<3, Conv3dImpl>;
166 
167 // ============================================================================
168 
169 template <size_t D, typename Derived>
_output_padding(const Tensor & input,const std::optional<at::IntArrayRef> & output_size,const ExpandingArray<D> & stride,const ExpandingArray<D> & padding,const ExpandingArray<D> & kernel_size)170 std::vector<int64_t> ConvTransposeNdImpl<D, Derived>::_output_padding(
171     const Tensor& input,
172     const std::optional<at::IntArrayRef>& output_size,
173     const ExpandingArray<D>& stride,
174     const ExpandingArray<D>& padding,
175     const ExpandingArray<D>& kernel_size) {
176   std::vector<int64_t> ret;
177   std::optional<at::IntArrayRef> output_size_ = output_size;
178 
179   if (output_size_ == std::nullopt) {
180     ret = at::IntArrayRef(this->options.output_padding()).vec();
181   } else {
182     auto k = input.dim() - 2;
183     if (output_size_.value().size() == static_cast<size_t>(k + 2)) {
184       output_size_ = output_size_.value().slice(2);
185     }
186     if (output_size_.value().size() != static_cast<size_t>(k)) {
187       TORCH_CHECK(
188           false,
189           "output_size must have ",
190           k,
191           " or ",
192           k + 2,
193           " elements (got ",
194           output_size_.value().size(),
195           ")");
196     }
197 
198     std::vector<int64_t> min_sizes;
199     std::vector<int64_t> max_sizes;
200     for (const auto d : c10::irange(k)) {
201       int64_t dim_size =
202           ((input.sizes()[d + 2] - 1) * (*stride)[d] - 2 * (*padding)[d] +
203            (*kernel_size)[d]);
204       min_sizes.push_back(dim_size);
205       max_sizes.push_back(min_sizes[d] + (*stride)[d] - 1);
206     }
207 
208     for (const auto i : c10::irange(output_size_.value().size())) {
209       int64_t size = output_size_.value()[i];
210       int64_t min_size = min_sizes[i];
211       int64_t max_size = max_sizes[i];
212       if (size < min_size || size > max_size) {
213         TORCH_CHECK(
214             false,
215             "requested an output size of ",
216             output_size_.value(),
217             ", but valid sizes range "
218             "from ",
219             min_sizes,
220             " to ",
221             max_sizes,
222             " (for an input of ",
223             input.sizes().slice(2),
224             ")");
225       }
226     }
227 
228     for (const auto d : c10::irange(k)) {
229       ret.push_back(output_size_.value()[d] - min_sizes[d]);
230     }
231   }
232   return ret;
233 }
234 
// Constructs a 1-D transposed convolution module. Translates the user-facing
// ConvTranspose1dOptions into the internal ConvNdOptions<1> consumed by the
// shared ConvTransposeNdImpl base; transposed=true and the user-specified
// output_padding distinguish it from a regular convolution.
ConvTranspose1dImpl::ConvTranspose1dImpl(ConvTranspose1dOptions options_)
    : ConvTransposeNdImpl(detail::ConvNdOptions<1>(
                              /*in_channels=*/options_.in_channels(),
                              /*out_channels=*/options_.out_channels(),
                              /*kernel_size=*/options_.kernel_size())
                              .stride(options_.stride())
                              .padding(options_.padding())
                              .dilation(options_.dilation())
                              .transposed(true)
                              .output_padding(options_.output_padding())
                              .groups(options_.groups())
                              .bias(options_.bias())
                              .padding_mode(options_.padding_mode())) {}
248 
// Applies the 1-D transposed convolution. An optional `output_size` resolves
// the output-shape ambiguity inherent to strided transposed convolutions;
// only the "zeros" padding mode is supported.
Tensor ConvTranspose1dImpl::forward(
    const Tensor& input,
    const std::optional<at::IntArrayRef>& output_size) {
  TORCH_CHECK(
      std::get_if<enumtype::kZeros>(&options.padding_mode()) != nullptr,
      "Only `zeros` padding mode is supported for ConvTranspose1d");

  const auto& pad_sizes = padding();
  const auto out_pad = _output_padding(
      input, output_size, options.stride(), pad_sizes, options.kernel_size());

  return F::detail::conv_transpose1d(
      input,
      weight,
      bias,
      options.stride(),
      pad_sizes,
      out_pad,
      options.groups(),
      options.dilation());
}
271 
// Constructs a 2-D transposed convolution module. Translates the user-facing
// ConvTranspose2dOptions into the internal ConvNdOptions<2> consumed by the
// shared ConvTransposeNdImpl base; transposed=true and the user-specified
// output_padding distinguish it from a regular convolution.
ConvTranspose2dImpl::ConvTranspose2dImpl(ConvTranspose2dOptions options_)
    : ConvTransposeNdImpl(detail::ConvNdOptions<2>(
                              /*in_channels=*/options_.in_channels(),
                              /*out_channels=*/options_.out_channels(),
                              /*kernel_size=*/options_.kernel_size())
                              .stride(options_.stride())
                              .padding(options_.padding())
                              .dilation(options_.dilation())
                              .transposed(true)
                              .output_padding(options_.output_padding())
                              .groups(options_.groups())
                              .bias(options_.bias())
                              .padding_mode(options_.padding_mode())) {}
285 
// Applies the 2-D transposed convolution. An optional `output_size` resolves
// the output-shape ambiguity inherent to strided transposed convolutions;
// only the "zeros" padding mode is supported.
Tensor ConvTranspose2dImpl::forward(
    const Tensor& input,
    const std::optional<at::IntArrayRef>& output_size) {
  TORCH_CHECK(
      std::get_if<enumtype::kZeros>(&options.padding_mode()) != nullptr,
      "Only `zeros` padding mode is supported for ConvTranspose2d");

  const auto& pad_sizes = padding();
  const auto out_pad = _output_padding(
      input, output_size, options.stride(), pad_sizes, options.kernel_size());

  return F::detail::conv_transpose2d(
      input,
      weight,
      bias,
      options.stride(),
      pad_sizes,
      out_pad,
      options.groups(),
      options.dilation());
}
308 
// Constructs a 3-D transposed convolution module. Translates the user-facing
// ConvTranspose3dOptions into the internal ConvNdOptions<3> consumed by the
// shared ConvTransposeNdImpl base; transposed=true and the user-specified
// output_padding distinguish it from a regular convolution.
ConvTranspose3dImpl::ConvTranspose3dImpl(ConvTranspose3dOptions options_)
    : ConvTransposeNdImpl(detail::ConvNdOptions<3>(
                              /*in_channels=*/options_.in_channels(),
                              /*out_channels=*/options_.out_channels(),
                              /*kernel_size=*/options_.kernel_size())
                              .stride(options_.stride())
                              .padding(options_.padding())
                              .dilation(options_.dilation())
                              .transposed(true)
                              .output_padding(options_.output_padding())
                              .groups(options_.groups())
                              .bias(options_.bias())
                              .padding_mode(options_.padding_mode())) {}
322 
// Applies the 3-D transposed convolution. An optional `output_size` resolves
// the output-shape ambiguity inherent to strided transposed convolutions;
// only the "zeros" padding mode is supported.
Tensor ConvTranspose3dImpl::forward(
    const Tensor& input,
    const std::optional<at::IntArrayRef>& output_size) {
  TORCH_CHECK(
      std::get_if<enumtype::kZeros>(&options.padding_mode()) != nullptr,
      "Only `zeros` padding mode is supported for ConvTranspose3d");

  const auto& pad_sizes = padding();
  const auto out_pad = _output_padding(
      input, output_size, options.stride(), pad_sizes, options.kernel_size());

  return F::detail::conv_transpose3d(
      input,
      weight,
      bias,
      options.stride(),
      pad_sizes,
      out_pad,
      options.groups(),
      options.dilation());
}
345 
// Explicitly instantiate the shared N-d transposed-convolution base for the
// three supported spatial dimensionalities so their definitions live in this
// TU.
template class ConvTransposeNdImpl<1, ConvTranspose1dImpl>;
template class ConvTransposeNdImpl<2, ConvTranspose2dImpl>;
template class ConvTransposeNdImpl<3, ConvTranspose3dImpl>;
349 
350 } // namespace nn
351 } // namespace torch
352