xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ConvolutionMM2d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorUtils.h>
6 #include <ATen/div_rtn.h>
7 #include <ATen/native/ConvUtils.h>
8 #include <ATen/native/CPUBlas.h>
9 #include <ATen/native/Unfold2d.h>
10 #include <c10/util/irange.h>
11 
12 #ifndef AT_PER_OPERATOR_HEADERS
13 #include <ATen/Functions.h>
14 #include <ATen/NativeFunctions.h>
15 #else
16 #include <ATen/ops/_slow_conv2d_backward_native.h>
17 #include <ATen/ops/_slow_conv2d_forward.h>
18 #include <ATen/ops/_slow_conv2d_forward_native.h>
19 #include <ATen/ops/empty.h>
20 #include <ATen/ops/sum.h>
21 #include <ATen/ops/thnn_conv2d_native.h>
22 #endif
23 
24 namespace at::native {
25 
26 namespace {
27 
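// compute_columns2d materializes the im2col ("unfolded") view of the input so
// that the convolution below can be evaluated as one GEMM per sample. For the
// contiguous layout the per-sample columns are {C*kH*kW, outH*outW}; for
// channels-last they are {outH*outW, kH*kW*C}. The 1x1 / stride-1 / no-padding
// case needs no copy: the unfolded tensor is just a view of the (already
// contiguous) input.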
28 static Tensor compute_columns2d(
29     const Tensor& input,
30     IntArrayRef padding,
31     IntArrayRef stride,
32     IntArrayRef kernel_size,
33     bool is_channels_last) {
34   const int64_t kernel_height = kernel_size[0];
35   const int64_t kernel_width = kernel_size[1];
36   const int64_t pad_height = padding[0];
37   const int64_t pad_width = padding[1];
38   const int64_t stride_height = stride[0];
39   const int64_t stride_width = stride[1];
40   const int64_t batch_size = input.size(0);
41   const int64_t n_input_plane = input.size(1);
42   const int64_t input_height = input.size(2);
43   const int64_t input_width = input.size(3);
44   const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
45   const int64_t output_width =  (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
46 
47   Tensor columns;
48   if ((kernel_height == 1) && (stride_height == 1) && (pad_height == 0) &&
49       (kernel_width == 1) && (stride_width == 1) && (pad_width == 0)) {
50     // Columns are just a view on the input for the 1x1 kernel special case.
51     if (is_channels_last) {
52       columns = input.as_strided({batch_size, output_height * output_width, n_input_plane},
53           {output_height * output_width * n_input_plane, n_input_plane, 1}).detach();
54     } else {
55       columns = input.view({batch_size, n_input_plane, output_height * output_width}).detach();
56     }
57   } else {
58     int64_t row = is_channels_last ?
59         output_height * output_width : n_input_plane * kernel_height * kernel_width;
60     int64_t col = is_channels_last ?
61         kernel_height * kernel_width * n_input_plane : output_height * output_width;
62     columns = at::empty({batch_size, row, col}, input.options());
63     AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu", [&]{
64       auto input_a = input.accessor<const scalar_t, 4>();
65       auto columns_a = columns.accessor<scalar_t, 3>();
66 
67       at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
68         for (const auto t : c10::irange(start, end)) {
69           auto input_t = input_a[t];
70           auto columns_t = columns_a[t];
71           unfolded2d_copy_stub(
72               kCPU,
73               c10::CppTypeToScalarType<scalar_t>::value,
74               columns_t.data(),
75               input_t.data(),
76               kernel_height,
77               kernel_width,
78               stride_height,
79               stride_width,
80               pad_height,
81               pad_width,
82               n_input_plane,
83               input_height,
84               input_width,
85               output_height,
86               output_width,
87               is_channels_last);
88         }
89       });
90     });
91   }
92 
93   return columns.contiguous();
94 }
95 
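// Validates kernel and stride parameters plus the shapes of input, weight,
// bias and (when defined) grad_output, raising via TORCH_CHECK on mismatch.
// Zero-sized batch or channel dimensions are allowed; zero spatial sizes are not.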
96 static inline void slow_conv2d_shape_check(
97     const Tensor& input,
98     const Tensor& grad_output,
99     const Tensor& weight,
100     const Tensor& bias,
101     int64_t kernel_height,
102     int64_t kernel_width,
103     int64_t stride_height,
104     int64_t stride_width,
105     int64_t pad_height,
106     int64_t pad_width,
107     bool weight_optional) {
108   TORCH_CHECK(
109       kernel_width > 0 && kernel_height > 0,
110       "kernel size should be greater than zero, but got kernel_height: ",
111       kernel_height,
112       " kernel_width: ",
113       kernel_width);
114   TORCH_CHECK(
115       stride_width > 0 && stride_height > 0,
116       "stride should be greater than zero, but got stride_height: ",
117       stride_height,
118       " stride_width: ",
119       stride_width);
120 
121   if (weight.defined()) {
122     TORCH_CHECK(
123         weight.numel() > 0 && (weight.dim() == 2 || weight.dim() == 4),
124         "non-empty 2D or 4D weight tensor expected, but got: ",
125         weight.sizes());
126     if (bias.defined()) {
127       check_dim_size(bias, 1, 0, weight.size(0));
128     }
129   } else {
130     TORCH_CHECK(weight_optional, "weight tensor is undefined");
131   }
132 
133   const int64_t ndim = input.dim();
134   const int64_t dim_planes = 1;
135   const int64_t dim_height = 2;
136   const int64_t dim_width = 3;
137 
138   // Allow for empty batch size and channel size but not other dimensions
139   TORCH_CHECK(ndim == 4, "Expected 4D input tensor, but got: ", input.sizes());
140   for (const auto dim : c10::irange(2, ndim)) {
141     TORCH_CHECK(input.size(dim) != 0,
142                 "Expected non-zero size for input dimension ", dim,
143                 ", but got input shape: ", input.sizes(), ". Only the batch and channel dimensions support size 0.");
144   }
145 
146   const int64_t input_height = input.size(dim_height);
147   const int64_t input_width = input.size(dim_width);
148 
149   const int64_t exact_input_height = input_height + 2 * pad_height;
150   const int64_t exact_input_width = input_width + 2 * pad_width;
151 
152   TORCH_CHECK(
153       exact_input_height >= kernel_height && exact_input_width >= kernel_width,
154       "Calculated padded input size per channel: (",
155       exact_input_height,
156       " x ",
157       exact_input_width,
158       "). ",
159       "Kernel size: (",
160       kernel_height,
161       " x ",
162       kernel_width,
163       "). Kernel size can't be greater than actual input size");
164 
165   const int64_t output_height =
166       div_rtn<int64_t>(exact_input_height - kernel_height, stride_height) + 1;
167   const int64_t output_width =
168       div_rtn<int64_t>(exact_input_width - kernel_width, stride_width) + 1;
169 
170   TORCH_CHECK(
171       output_width >= 1 && output_height >= 1,
172       "Given input size per channel: (",
173       input_height,
174       " x ",
175       input_width,
176       "). "
177       "Calculated output size per channel: (",
178       output_height,
179       " x ",
180       output_width,
181       "). Output size is too small");
182 
183   if (weight.defined()) {
184     int64_t n_input_plane = weight.size(1);
185     if (weight.dim() == 2) {
186       n_input_plane /= (kernel_height * kernel_width);
187     }
188     if (input.size(1) != 0) {
189       check_dim_size(input, ndim, dim_planes, n_input_plane);
190     }
191   }
192 
193   if (grad_output.defined()) {
194     if (weight.defined()) {
195       int64_t n_output_plane = weight.size(0);
196       check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
197     } else if (bias.defined()) {
198       TORCH_CHECK(bias.numel() > 0, "non-empty bias tensor expected");
199       const int64_t n_output_plane = bias.dim() == 0 ? 1 : bias.size(0);
200       check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
201     }
202     check_dim_size(grad_output, ndim, dim_height, output_height);
203     check_dim_size(grad_output, ndim, dim_width, output_width);
204   }
205 }
206 
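// Collapses a 4D weight {outC, inC, kH, kW} into the 2D matrix used by the
// GEMM calls: {outC, inC*kH*kW} for the contiguous layout, {outC, kH*kW*inC}
// for channels-last, matching the column layout from compute_columns2d.
// 2D weights are returned unchanged.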
207 static inline Tensor view_weight_2d(const Tensor& weight_,
208     at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) {
209   Tensor weight = weight_.contiguous(memory_format);
210   if (weight.dim() == 4) {
211     const int64_t s1 = weight.size(0);
212     const int64_t s2 = weight.size(1) * weight.size(2) * weight.size(3);
213     return memory_format == at::MemoryFormat::ChannelsLast
214         ? weight.as_strided({s1, s2}, {s2, 1}) // CL: view as {oc, kh*kw*ic}
215         : weight.view({s1, s2}); // CF: view as {oc, ic*kh*kw}
216   } else {
217     return weight;
218   }
219 }
220 
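// Forward pass for a single sample: the convolution is a single GEMM over the
// unfolded input columns (finput), output = weight * columns.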
221 template <typename scalar_t>
222 static void slow_conv2d_update_output_frame(
223     TensorAccessor<const scalar_t, 3> input,
224     TensorAccessor<scalar_t, 3> output,
225     TensorAccessor<const scalar_t, 2> weight,
226     bool has_bias,
227     TensorAccessor<scalar_t, 2> finput,
228     int64_t kernel_height,
229     int64_t kernel_width,
230     int64_t stride_height,
231     int64_t stride_width,
232     int64_t pad_height,
233     int64_t pad_width,
234     int64_t n_input_plane,
235     int64_t input_height,
236     int64_t input_width,
237     int64_t n_output_plane,
238     int64_t output_height,
239     int64_t output_width,
240     bool is_channels_last) {
241   const int beta = has_bias ? 1 : 0;
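  // GEMM computes C = alpha*A*B + beta*C. When a bias is present the caller has
  // already pre-filled `output` with the bias, so beta = 1 accumulates onto it;
  // otherwise beta = 0 overwrites whatever is in the output buffer.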
242 
243   // Compute out = weight * input
244   // Note gemm expects fortran order, so all 3 matrices are transposed.
245   // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
246   if (is_channels_last) {
247     const int64_t m = n_output_plane;
248     const int64_t n = output_height * output_width;
249     const int64_t k = n_input_plane * kernel_height * kernel_width;
250 
251     const int64_t lda = k;
252     const int64_t ldb = k;
253     const int64_t ldc = m;
254 
255     at::native::cpublas::gemm(
256         TransposeType::Transpose,
257         TransposeType::NoTranspose,
258         m, n, k,
259         static_cast<scalar_t>(1),
260         weight.data(), lda,
261         finput.data(), ldb,
262         static_cast<scalar_t>(beta),
263         output.data(), ldc);
264   } else {
265     const int64_t m = output_height * output_width;
266     const int64_t n = n_output_plane;
267     const int64_t k = n_input_plane * kernel_height * kernel_width;
268 
269     const int64_t lda = m;
270     const int64_t ldb = k;
271     const int64_t ldc = m;
272 
273     at::native::cpublas::gemm(
274         TransposeType::NoTranspose,
275         TransposeType::NoTranspose,
276         m, n, k,
277         static_cast<scalar_t>(1),
278         finput.data(), lda,
279         weight.data(), ldb,
280         static_cast<scalar_t>(beta),
281         output.data(), ldc);
282   }
283 }
284 
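// Input gradient for a single sample: a GEMM first produces the gradient of the
// unfolded columns (fgrad_input = weight^T * grad_output), then
// unfolded2d_acc_stub accumulates those columns back into grad_input (col2im).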
285 template <typename scalar_t>
286 void slow_conv2d_backward_update_grad_input_frame(
287     TensorAccessor<scalar_t, 3> grad_input,
288     TensorAccessor<const scalar_t, 3> grad_output,
289     TensorAccessor<const scalar_t, 2> weight,
290     scalar_t *fgrad_input,
291     int64_t kernel_height,
292     int64_t kernel_width,
293     int64_t stride_height,
294     int64_t stride_width,
295     int64_t pad_height,
296     int64_t pad_width,
297     bool is_channels_last) {
298   // Compute fgrad_input = weight.T * grad_output.reshape({grad_output.shape(0), -1})
299   // Note gemm expects fortran order, so all 3 matrices are transposed.
300   // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
301   if (is_channels_last) {
302     const int64_t m = weight.size(1);
303     const int64_t n = grad_output.size(1) * grad_output.size(2);
304     const int64_t k = weight.size(0);
305 
306     const int64_t lda = m;
307     const int64_t ldb = k;
308     const int64_t ldc = m;
309 
310     at::native::cpublas::gemm(
311         TransposeType::NoTranspose,
312         TransposeType::NoTranspose,
313         m, n, k,
314         static_cast<scalar_t>(1),
315         weight.data(), lda,
316         grad_output.data(), ldb,
317         static_cast<scalar_t>(0),
318         fgrad_input, ldc);
319   } else {
320     const int64_t m = grad_output.size(1) * grad_output.size(2);
321     const int64_t n = weight.size(1);
322     const int64_t k = weight.size(0);
323 
324     const int64_t lda = m;
325     const int64_t ldb = n;
326     const int64_t ldc = m;
327 
328     at::native::cpublas::gemm(
329         TransposeType::NoTranspose,
330         TransposeType::Transpose,
331         m, n, k,
332         static_cast<scalar_t>(1),
333         grad_output.data(), lda,
334         weight.data(), ldb,
335         static_cast<scalar_t>(0),
336         fgrad_input, ldc);
337   }
338 
339   unfolded2d_acc_stub(
340       kCPU,
341       c10::CppTypeToScalarType<scalar_t>::value,
342       fgrad_input,
343       grad_input.data(),
344       kernel_height,
345       kernel_width,
346       stride_height,
347       stride_width,
348       pad_height,
349       pad_width,
350       grad_input.size(0),
351       grad_input.size(1),
352       grad_input.size(2),
353       grad_output.size(1),
354       grad_output.size(2),
355       is_channels_last);
356 }
357 
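// Driver for grad_input: checks shapes, makes grad_output contiguous in the
// chosen memory format, then runs the per-sample kernel in parallel over the
// batch with one fgrad_input scratch buffer per worker chunk.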
358 void slow_conv2d_backward_out_cpu_template(
359     Tensor& grad_input,
360     const Tensor& grad_output_,
361     const Tensor& input_,
362     const Tensor& weight_,
363     IntArrayRef kernel_size,
364     IntArrayRef stride,
365     IntArrayRef padding) {
366   const int64_t kernel_height = kernel_size[0];
367   const int64_t kernel_width = kernel_size[1];
368   const int64_t pad_height = padding[0];
369   const int64_t pad_width = padding[1];
370   const int64_t stride_height = stride[0];
371   const int64_t stride_width = stride[1];
372 
373   bool use_channels_last = thnn_conv_use_channels_last(input_, weight_);
374   auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;
375 
376   const Tensor weight = view_weight_2d(weight_, memory_format);
377   slow_conv2d_shape_check(
378       input_,
379       grad_output_,
380       weight,
381       Tensor(),
382       kernel_height,
383       kernel_width,
384       stride_height,
385       stride_width,
386       pad_height,
387       pad_width,
388       false);
389 
390   const Tensor input = input_.contiguous(memory_format);
391 
392   // Compute shape of columnized data excluding batch dim.
393   const int64_t batch_size = input.size(0);
394   const int64_t n_input_plane = input.size(1);
395   const int64_t input_height = input.size(2);
396   const int64_t input_width = input.size(3);
397   const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
398   const int64_t output_width = (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
399   const int64_t fgrad_input_size = n_input_plane * kernel_height * kernel_width * output_height * output_width;
400 
401   const Tensor grad_output = grad_output_.contiguous(memory_format);
402   grad_input.resize_as_(input, memory_format);
403   grad_input.zero_();
404   TORCH_CHECK(grad_input.is_contiguous(memory_format), "slow_conv2d: grad_input must be contiguous");
405 
406   AT_DISPATCH_FLOATING_TYPES_AND2(
407       kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu_grad_input", [&] {
408     auto grad_output_a = grad_output.accessor<const scalar_t, 4>();
409     auto grad_input_a = grad_input.accessor<scalar_t, 4>();
410     auto weight_a = weight.accessor<const scalar_t, 2>();
411 
412     at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
413       auto fgrad_input = std::make_unique<scalar_t[]>(fgrad_input_size);
414       for (const auto t : c10::irange(start, end)) {
415         auto grad_input_t = grad_input_a[t];
416         auto grad_output_t = grad_output_a[t];
417         slow_conv2d_backward_update_grad_input_frame(
418             grad_input_t,
419             grad_output_t,
420             weight_a,
421             fgrad_input.get(),
422             kernel_height,
423             kernel_width,
424             stride_height,
425             stride_width,
426             pad_height,
427             pad_width,
428             use_channels_last);
429       }
430     });
431   });
432 }
433 
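// Weight gradient contribution of a single sample:
// grad_weight += grad_output (as a matrix) * columns^T. beta = 1 in the GEMM,
// so the contributions from all samples in the batch accumulate.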
434 template <typename scalar_t>
435 void slow_conv2d_backward_weight_frame(
436     TensorAccessor<scalar_t, 2> grad_weight,
437     TensorAccessor<const scalar_t, 3> grad_output,
438     TensorAccessor<const scalar_t, 2> finput,
439     bool is_channels_last) {
440   // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
441   // Note gemm expects fortran order, so all 3 matrices are transposed.
442   // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
443   if (is_channels_last) {
444     const int64_t m = finput.size(1);
445     const int64_t n = grad_output.size(0);
446     const int64_t k = grad_output.size(1) * grad_output.size(2);
447 
448     const int64_t lda = m;
449     const int64_t ldb = n;
450     const int64_t ldc = m;
451 
452     at::native::cpublas::gemm(
453         TransposeType::NoTranspose,
454         TransposeType::Transpose,
455         m, n, k,
456         static_cast<scalar_t>(1),
457         finput.data(), lda,
458         grad_output.data(), ldb,
459         static_cast<scalar_t>(1),
460         grad_weight.data(), ldc);
461   } else {
462     const int64_t m = finput.size(0);
463     const int64_t n = grad_output.size(0);
464     const int64_t k = grad_output.size(1) * grad_output.size(2);
465 
466     const int64_t lda = k;
467     const int64_t ldb = k;
468     const int64_t ldc = m;
469 
470     at::native::cpublas::gemm(
471         TransposeType::Transpose,
472         TransposeType::NoTranspose,
473         m, n, k,
474         static_cast<scalar_t>(1),
475         finput.data(), lda,
476         grad_output.data(), ldb,
477         static_cast<scalar_t>(1),
478         grad_weight.data(), ldc);
479   }
480 }
481 
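// Driver for grad_weight: unfolds the input and accumulates one GEMM per
// sample. The batch loop is serial because every sample writes into the same
// grad_weight matrix.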
482 static void slow_conv2d_backward_weight_out_cpu_template(
483     Tensor& grad_weight,
484     const Tensor& input,
485     const Tensor& grad_output_,
486     IntArrayRef kernel_size,
487     IntArrayRef stride,
488     IntArrayRef padding) {
489   const int64_t kernel_height = kernel_size[0];
490   const int64_t kernel_width = kernel_size[1];
491   const int64_t pad_height = padding[0];
492   const int64_t pad_width = padding[1];
493   const int64_t stride_height = stride[0];
494   const int64_t stride_width = stride[1];
495 
496   bool use_channels_last = thnn_conv_use_channels_last(input, grad_weight);
497   auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;
498 
499   TORCH_CHECK(grad_weight.is_contiguous(memory_format), "slow_conv2d: grad_weight must be contiguous");
500   Tensor grad_weight_2d = view_weight_2d(grad_weight, memory_format);
501 
502   slow_conv2d_shape_check(
503       input,
504       grad_output_,
505       grad_weight_2d,
506       {},
507       kernel_height,
508       kernel_width,
509       stride_height,
510       stride_width,
511       pad_height,
512       pad_width,
513       true);
514 
515   auto grad_output = grad_output_.contiguous(memory_format);
516   Tensor finput = compute_columns2d(input, padding, stride, kernel_size, use_channels_last);
517 
518   const int64_t batch_size = input.size(0);
519 
520   AT_DISPATCH_FLOATING_TYPES_AND2(
521       kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu_grad_weight", [&] {
522     auto grad_output_a = grad_output.accessor<const scalar_t, 4>();
523     auto grad_weight_2d_a = grad_weight_2d.accessor<scalar_t, 2>();
524     auto finput_a = finput.accessor<const scalar_t, 3>();
525 
526     for (const auto t : c10::irange(batch_size)) {
527       auto grad_output_t = grad_output_a[t];
528       auto finput_t = finput_a[t];
529 
530       slow_conv2d_backward_weight_frame(
531           grad_weight_2d_a, grad_output_t, finput_t, use_channels_last);
532     }
533   });
534 }
535 
536 } // namespace
537 
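// Forward entry point (out variant): unfolds the input, pre-fills the output
// with the bias (if any), then evaluates one GEMM per sample in parallel over
// the batch. Channels-last is used when thnn_conv_use_channels_last says so.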
538 Tensor& slow_conv2d_forward_out_cpu(
539     const Tensor& self,
540     const Tensor& weight_,
541     IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
542     IntArrayRef stride,
543     IntArrayRef padding,
544     Tensor& output) {
545   // See [Note: hacky wrapper removal for optional tensor]
546 
547   TORCH_CHECK(kernel_size.size() == 2, "2D kernel_size expected");
548   TORCH_CHECK(stride.size() == 2, "2D stride expected");
549   TORCH_CHECK(padding.size() == 2, "2D padding expected");
550 
551   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
552   const Tensor& bias = *bias_maybe_owned;
553 
554   const int64_t kernel_height = kernel_size[0];
555   const int64_t kernel_width = kernel_size[1];
556   const int64_t pad_height = padding[0];
557   const int64_t pad_width = padding[1];
558   const int64_t stride_height = stride[0];
559   const int64_t stride_width = stride[1];
560 
561   bool use_channels_last = thnn_conv_use_channels_last(self, weight_);
562   auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;
563 
564   const Tensor weight_2d = view_weight_2d(weight_, memory_format);
565 
566   slow_conv2d_shape_check(
567       self,
568       Tensor(),
569       weight_2d,
570       bias,
571       kernel_height,
572       kernel_width,
573       stride_height,
574       stride_width,
575       pad_height,
576       pad_width,
577       false);
578 
579   const Tensor input = self.contiguous(memory_format);
580   const int64_t batch_size = input.size(0);
581   const int64_t n_input_plane = input.size(1);
582   const int64_t input_height = input.size(2);
583   const int64_t input_width = input.size(3);
584   const int64_t n_output_plane = weight_2d.size(0);
585   const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
586   const int64_t output_width = (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
587 
588   Tensor finput = compute_columns2d(input, padding, stride, kernel_size, use_channels_last);
589   output.resize_({batch_size, n_output_plane, output_height, output_width}, memory_format);
590   if (bias.defined()) {
591     output.copy_(bias.reshape({-1, 1, 1}));
592   }
593   TORCH_CHECK(output.is_contiguous(memory_format), "slow_conv2d output tensor must be contiguous");
594 
595   AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu", [&]{
596     auto input_a = input.accessor<const scalar_t, 4>();
597     auto output_a = output.accessor<scalar_t, 4>();
598     auto finput_a = finput.accessor<scalar_t, 3>();
599     auto weight_2d_a = weight_2d.accessor<const scalar_t, 2>();
600 
601     at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
602       for (const auto t : c10::irange(start, end)) {
603         auto input_t = input_a[t];
604         auto output_t = output_a[t];
605         auto finput_t = finput_a[t];
606         slow_conv2d_update_output_frame(
607             input_t,
608             output_t,
609             weight_2d_a,
610             bias.defined(),
611             finput_t,
612             kernel_height,
613             kernel_width,
614             stride_height,
615             stride_width,
616             pad_height,
617             pad_width,
618             n_input_plane,
619             input_height,
620             input_width,
621             n_output_plane,
622             output_height,
623             output_width,
624             use_channels_last);
625       }
626     });
627   });
628 
629   return output;
630 }
631 
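// Allocating variant: creates an empty output and forwards to the out variant.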
632 Tensor slow_conv2d_forward_cpu(
633     const Tensor& self,
634     const Tensor& weight,
635     IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
636     IntArrayRef stride,
637     IntArrayRef padding) {
638   // See [Note: hacky wrapper removal for optional tensor]
639   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
640   const Tensor& bias = *bias_maybe_owned;
641 
642   auto output = at::empty({0}, self.options());
643   at::native::slow_conv2d_forward_out_cpu(
644       self,
645       weight,
646       kernel_size,
647       bias,
648       stride,
649       padding,
650       output);
651 
652   return output;
653 }
654 
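// Backward entry point (out variant): each gradient is computed only when the
// corresponding output tensor is defined. grad_bias is simply grad_output
// summed over the batch and spatial dimensions.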
655 std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_backward_out_cpu(
656     const Tensor& grad_output,
657     const Tensor& self,
658     const Tensor& weight,
659     IntArrayRef kernel_size,
660     IntArrayRef stride,
661     IntArrayRef padding,
662     Tensor& grad_input,
663     Tensor& grad_weight,
664     Tensor& grad_bias) {
665   if (grad_input.defined()) {
666     slow_conv2d_backward_out_cpu_template(
667         grad_input,
668         grad_output,
669         self,
670         weight,
671         kernel_size,
672         stride,
673         padding);
674   }
675 
676   if (grad_bias.defined()) {
677     at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
678   }
679 
680   if (grad_weight.defined()) {
681     grad_weight.resize_(weight.sizes(), weight.suggest_memory_format());
682     grad_weight.zero_();
683     slow_conv2d_backward_weight_out_cpu_template(
684         grad_weight,
685         self,
686         grad_output,
687         kernel_size,
688         stride,
689         padding);
690   }
691 
692   return std::tuple<Tensor&, Tensor&, Tensor&>(
693       grad_input, grad_weight, grad_bias);
694 }
695 
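// Allocating variant of the backward: creates empty gradients according to
// output_mask and forwards to the out variant above.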
696 std::tuple<Tensor, Tensor, Tensor> slow_conv2d_backward_cpu(
697     const Tensor& grad_output,
698     const Tensor& self,
699     const Tensor& weight,
700     IntArrayRef kernel_size,
701     IntArrayRef stride,
702     IntArrayRef padding,
703     std::array<bool, 3> output_mask) {
704   Tensor grad_input;
705   Tensor grad_weight;
706   Tensor grad_bias;
707 
708   if (output_mask[0]) {
709     grad_input = at::empty({0}, grad_output.options());
710   }
711 
712   if (output_mask[1]) {
713     grad_weight = at::empty({0}, grad_output.options());
714   }
715 
716   if (output_mask[2]) {
717     grad_bias = at::empty({0}, grad_output.options());
718   }
719 
720   at::native::slow_conv2d_backward_out_cpu(
721       grad_output,
722       self,
723       weight,
724       kernel_size,
725       stride,
726       padding,
727       grad_input,
728       grad_weight,
729       grad_bias);
730 
731   return std::make_tuple(grad_input, grad_weight, grad_bias);
732 }
733 
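// thnn_conv2d{,_out} are thin public wrappers that unpack the optional bias
// and forward to _slow_conv2d_forward{,_out}.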
734 Tensor & thnn_conv2d_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output) {
735   // See [Note: hacky wrapper removal for optional tensor]
736   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
737   const Tensor& bias = *bias_maybe_owned;
738 
739   return at::_slow_conv2d_forward_out(output, self, weight, kernel_size, bias, stride, padding);
740 }
741 
742 Tensor thnn_conv2d(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt, IntArrayRef stride, IntArrayRef padding) {
743   // See [Note: hacky wrapper removal for optional tensor]
744   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
745   const Tensor& bias = *bias_maybe_owned;
746 
747   return at::_slow_conv2d_forward(self, weight, kernel_size, bias, stride, padding);
748 }
749 
750 } // namespace at::native
751