#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <ATen/div_rtn.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/Unfold2d.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_slow_conv2d_backward_native.h>
#include <ATen/ops/_slow_conv2d_forward.h>
#include <ATen/ops/_slow_conv2d_forward_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/thnn_conv2d_native.h>
#endif

namespace at::native {

namespace {

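// compute_columns2d performs the im2col step: for each image in the batch it
// unfolds kernel_height x kernel_width patches of the (padded) input into a
// "columns" buffer so the convolution can be expressed as one GEMM per frame.
// The buffer is laid out as {batch, oh*ow, kh*kw*ic} for channels-last and
// {batch, ic*kh*kw, oh*ow} for contiguous inputs. For a 1x1 kernel with unit
// stride and no padding the unfold is the identity, so the input is merely
// viewed (or restrided, for channels-last) instead of being copied.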
static Tensor compute_columns2d(
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef kernel_size,
    bool is_channels_last) {
  const int64_t kernel_height = kernel_size[0];
  const int64_t kernel_width = kernel_size[1];
  const int64_t pad_height = padding[0];
  const int64_t pad_width = padding[1];
  const int64_t stride_height = stride[0];
  const int64_t stride_width = stride[1];
  const int64_t batch_size = input.size(0);
  const int64_t n_input_plane = input.size(1);
  const int64_t input_height = input.size(2);
  const int64_t input_width = input.size(3);
  const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width = (input_width + 2 * pad_width - kernel_width) / stride_width + 1;

  Tensor columns;
  if ((kernel_height == 1) && (stride_height == 1) && (pad_height == 0) &&
      (kernel_width == 1) && (stride_width == 1) && (pad_width == 0)) {
    // Columns are just a view on the input for the 1x1 kernel special case.
    if (is_channels_last) {
      columns = input.as_strided({batch_size, output_height * output_width, n_input_plane},
          {output_height * output_width * n_input_plane, n_input_plane, 1}).detach();
    } else {
      columns = input.view({batch_size, n_input_plane, output_height * output_width}).detach();
    }
  } else {
    int64_t row = is_channels_last ?
        output_height * output_width : n_input_plane * kernel_height * kernel_width;
    int64_t col = is_channels_last ?
        kernel_height * kernel_width * n_input_plane : output_height * output_width;
    columns = at::empty({batch_size, row, col}, input.options());
    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu", [&]{
      auto input_a = input.accessor<const scalar_t, 4>();
      auto columns_a = columns.accessor<scalar_t, 3>();

      at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
        for (const auto t : c10::irange(start, end)) {
          auto input_t = input_a[t];
          auto columns_t = columns_a[t];
          unfolded2d_copy_stub(
              kCPU,
              c10::CppTypeToScalarType<scalar_t>::value,
              columns_t.data(),
              input_t.data(),
              kernel_height,
              kernel_width,
              stride_height,
              stride_width,
              pad_height,
              pad_width,
              n_input_plane,
              input_height,
              input_width,
              output_height,
              output_width,
              is_channels_last);
        }
      });
    });
  }

  return columns.contiguous();
}

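// Validates the arguments shared by the forward and backward kernels: positive
// kernel and stride sizes, a non-empty 2D or 4D weight (unless weight_optional),
// a 4D input whose spatial dimensions are non-zero, and, when grad_output is
// given, that its channel and spatial sizes match the expected output shape.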
static inline void slow_conv2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_height,
    int64_t pad_width,
    bool weight_optional) {
  TORCH_CHECK(
      kernel_width > 0 && kernel_height > 0,
      "kernel size should be greater than zero, but got kernel_height: ",
      kernel_height,
      " kernel_width: ",
      kernel_width);
  TORCH_CHECK(
      stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);

  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() > 0 && (weight.dim() == 2 || weight.dim() == 4),
        "non-empty 2D or 4D weight tensor expected, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(0));
    }
  } else {
    TORCH_CHECK(weight_optional, "weight tensor is undefined");
  }

  const int64_t ndim = input.dim();
  const int64_t dim_planes = 1;
  const int64_t dim_height = 2;
  const int64_t dim_width = 3;

  // Allow for empty batch size and channel size but not other dimensions
  TORCH_CHECK(ndim == 4, "Expected 4D input tensor, but got: ", input.sizes());
  for (const auto dim : c10::irange(2, ndim)) {
    TORCH_CHECK(input.size(dim) != 0,
        "Expected non-zero size for input dimension ", dim,
        ", but got input shape: ", input.sizes(), ". Only the batch and channel dimensions support size 0.");
  }

  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);

  const int64_t exact_input_height = input_height + 2 * pad_height;
  const int64_t exact_input_width = input_width + 2 * pad_width;

  TORCH_CHECK(
      exact_input_height >= kernel_height && exact_input_width >= kernel_width,
      "Calculated padded input size per channel: (",
      exact_input_height,
      " x ",
      exact_input_width,
      "). ",
      "Kernel size: (",
      kernel_height,
      " x ",
      kernel_width,
      "). Kernel size can't be greater than actual input size");

  const int64_t output_height =
      div_rtn<int64_t>(exact_input_height - kernel_height, stride_height) + 1;
  const int64_t output_width =
      div_rtn<int64_t>(exact_input_width - kernel_width, stride_width) + 1;
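  // Worked example of the arithmetic above: input_height = 7, pad_height = 0,
  // kernel_height = 3, stride_height = 2 gives exact_input_height = 7 and
  // output_height = div_rtn(7 - 3, 2) + 1 = 3.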

  TORCH_CHECK(
      output_width >= 1 && output_height >= 1,
      "Given input size per channel: (",
      input_height,
      " x ",
      input_width,
      "). "
      "Calculated output size per channel: (",
      output_height,
      " x ",
      output_width,
      "). Output size is too small");

  if (weight.defined()) {
    int64_t n_input_plane = weight.size(1);
    if (weight.dim() == 2) {
      n_input_plane /= (kernel_height * kernel_width);
    }
    if (input.size(1) != 0) {
      check_dim_size(input, ndim, dim_planes, n_input_plane);
    }
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      int64_t n_output_plane = weight.size(0);
      check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
    } else if (bias.defined()) {
      TORCH_CHECK(bias.numel() > 0, "non-empty bias tensor expected");
      const int64_t n_output_plane = bias.dim() == 0 ? 1 : bias.size(0);
      check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dim_height, output_height);
    check_dim_size(grad_output, ndim, dim_width, output_width);
  }
}

static inline Tensor view_weight_2d(const Tensor& weight_,
    at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) {
  Tensor weight = weight_.contiguous(memory_format);
  if (weight.dim() == 4) {
    const int64_t s1 = weight.size(0);
    const int64_t s2 = weight.size(1) * weight.size(2) * weight.size(3);
    return memory_format == at::MemoryFormat::ChannelsLast
        ? weight.as_strided({s1, s2}, {s2, 1}) // CL: view as {oc, kh*kw*ic}
        : weight.view({s1, s2}); // CF: view as {oc, ic*kh*kw}
  } else {
    return weight;
  }
}

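// Computes one output frame as a single matrix multiply over the unfolded
// input: conceptually output(oc, oh*ow) = weight(oc, :) . finput(:, oh*ow).
// beta is 1 when a bias has already been broadcast into `output`, so the GEMM
// accumulates into it rather than overwriting it, and 0 otherwise.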
template <typename scalar_t>
static void slow_conv2d_update_output_frame(
    TensorAccessor<const scalar_t, 3> input,
    TensorAccessor<scalar_t, 3> output,
    TensorAccessor<const scalar_t, 2> weight,
    bool has_bias,
    TensorAccessor<scalar_t, 2> finput,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_height,
    int64_t pad_width,
    int64_t n_input_plane,
    int64_t input_height,
    int64_t input_width,
    int64_t n_output_plane,
    int64_t output_height,
    int64_t output_width,
    bool is_channels_last) {
  const int beta = has_bias ? 1 : 0;

  // Compute out = weight * input
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  if (is_channels_last) {
    const int64_t m = n_output_plane;
    const int64_t n = output_height * output_width;
    const int64_t k = n_input_plane * kernel_height * kernel_width;

    const int64_t lda = k;
    const int64_t ldb = k;
    const int64_t ldc = m;

    at::native::cpublas::gemm(
        TransposeType::Transpose,
        TransposeType::NoTranspose,
        m, n, k,
        static_cast<scalar_t>(1),
        weight.data(), lda,
        finput.data(), ldb,
        static_cast<scalar_t>(beta),
        output.data(), ldc);
  } else {
    const int64_t m = output_height * output_width;
    const int64_t n = n_output_plane;
    const int64_t k = n_input_plane * kernel_height * kernel_width;

    const int64_t lda = m;
    const int64_t ldb = k;
    const int64_t ldc = m;

    at::native::cpublas::gemm(
        TransposeType::NoTranspose,
        TransposeType::NoTranspose,
        m, n, k,
        static_cast<scalar_t>(1),
        finput.data(), lda,
        weight.data(), ldb,
        static_cast<scalar_t>(beta),
        output.data(), ldc);
  }
}

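// Backward w.r.t. the input for one frame: a GEMM first computes the unfolded
// gradient fgrad_input = weight.T * grad_output, laid out to match the columns
// buffer of the forward pass, and then unfolded2d_acc_stub performs the col2im
// step, accumulating overlapping patches back into grad_input.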
template <typename scalar_t>
void slow_conv2d_backward_update_grad_input_frame(
    TensorAccessor<scalar_t, 3> grad_input,
    TensorAccessor<const scalar_t, 3> grad_output,
    TensorAccessor<const scalar_t, 2> weight,
    scalar_t *fgrad_input,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_height,
    int64_t pad_width,
    bool is_channels_last) {
  // Compute fgrad_input = weight.T * grad_output.reshape({grad_output.shape(0), -1})
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  if (is_channels_last) {
    const int64_t m = weight.size(1);
    const int64_t n = grad_output.size(1) * grad_output.size(2);
    const int64_t k = weight.size(0);

    const int64_t lda = m;
    const int64_t ldb = k;
    const int64_t ldc = m;

    at::native::cpublas::gemm(
        TransposeType::NoTranspose,
        TransposeType::NoTranspose,
        m, n, k,
        static_cast<scalar_t>(1),
        weight.data(), lda,
        grad_output.data(), ldb,
        static_cast<scalar_t>(0),
        fgrad_input, ldc);
  } else {
    const int64_t m = grad_output.size(1) * grad_output.size(2);
    const int64_t n = weight.size(1);
    const int64_t k = weight.size(0);

    const int64_t lda = m;
    const int64_t ldb = n;
    const int64_t ldc = m;

    at::native::cpublas::gemm(
        TransposeType::NoTranspose,
        TransposeType::Transpose,
        m, n, k,
        static_cast<scalar_t>(1),
        grad_output.data(), lda,
        weight.data(), ldb,
        static_cast<scalar_t>(0),
        fgrad_input, ldc);
  }

  unfolded2d_acc_stub(
      kCPU,
      c10::CppTypeToScalarType<scalar_t>::value,
      fgrad_input,
      grad_input.data(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      grad_input.size(0),
      grad_input.size(1),
      grad_input.size(2),
      grad_output.size(1),
      grad_output.size(2),
      is_channels_last);
}

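// Driver for the input-gradient computation: checks shapes, makes grad_output
// and the input contiguous in the chosen memory format, zero-fills grad_input,
// and then processes the batch in parallel. Each parallel chunk allocates its
// own fgrad_input scratch buffer of
// n_input_plane * kernel_height * kernel_width * output_height * output_width
// elements.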
void slow_conv2d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output_,
    const Tensor& input_,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  const int64_t kernel_height = kernel_size[0];
  const int64_t kernel_width = kernel_size[1];
  const int64_t pad_height = padding[0];
  const int64_t pad_width = padding[1];
  const int64_t stride_height = stride[0];
  const int64_t stride_width = stride[1];

  bool use_channels_last = thnn_conv_use_channels_last(input_, weight_);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  const Tensor weight = view_weight_2d(weight_, memory_format);
  slow_conv2d_shape_check(
      input_,
      grad_output_,
      weight,
      Tensor(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      false);

  const Tensor input = input_.contiguous(memory_format);

  // Compute shape of columnized data excluding batch dim.
  const int64_t batch_size = input.size(0);
  const int64_t n_input_plane = input.size(1);
  const int64_t input_height = input.size(2);
  const int64_t input_width = input.size(3);
  const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width = (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
  const int64_t fgrad_input_size = n_input_plane * kernel_height * kernel_width * output_height * output_width;

  const Tensor grad_output = grad_output_.contiguous(memory_format);
  grad_input.resize_as_(input, memory_format);
  grad_input.zero_();
  TORCH_CHECK(grad_input.is_contiguous(memory_format), "slow_conv2d: grad_input must be contiguous");

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu_grad_input", [&] {
        auto grad_output_a = grad_output.accessor<const scalar_t, 4>();
        auto grad_input_a = grad_input.accessor<scalar_t, 4>();
        auto weight_a = weight.accessor<const scalar_t, 2>();

        at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
          auto fgrad_input = std::make_unique<scalar_t[]>(fgrad_input_size);
          for (const auto t : c10::irange(start, end)) {
            auto grad_input_t = grad_input_a[t];
            auto grad_output_t = grad_output_a[t];
            slow_conv2d_backward_update_grad_input_frame(
                grad_input_t,
                grad_output_t,
                weight_a,
                fgrad_input.get(),
                kernel_height,
                kernel_width,
                stride_height,
                stride_width,
                pad_height,
                pad_width,
                use_channels_last);
          }
        });
      });
}

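// Backward w.r.t. the weight for one frame: accumulates
// grad_weight += grad_output.reshape({n_output_plane, -1}) * finput.T
// using beta = 1 in the GEMM. Every frame in the batch accumulates into the
// same grad_weight buffer, so the per-frame loop in the caller below runs
// serially rather than via at::parallel_for.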
template <typename scalar_t>
void slow_conv2d_backward_weight_frame(
    TensorAccessor<scalar_t, 2> grad_weight,
    TensorAccessor<const scalar_t, 3> grad_output,
    TensorAccessor<const scalar_t, 2> finput,
    bool is_channels_last) {
  // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  if (is_channels_last) {
    const int64_t m = finput.size(1);
    const int64_t n = grad_output.size(0);
    const int64_t k = grad_output.size(1) * grad_output.size(2);

    const int64_t lda = m;
    const int64_t ldb = n;
    const int64_t ldc = m;

    at::native::cpublas::gemm(
        TransposeType::NoTranspose,
        TransposeType::Transpose,
        m, n, k,
        static_cast<scalar_t>(1),
        finput.data(), lda,
        grad_output.data(), ldb,
        static_cast<scalar_t>(1),
        grad_weight.data(), ldc);
  } else {
    const int64_t m = finput.size(0);
    const int64_t n = grad_output.size(0);
    const int64_t k = grad_output.size(1) * grad_output.size(2);

    const int64_t lda = k;
    const int64_t ldb = k;
    const int64_t ldc = m;

    at::native::cpublas::gemm(
        TransposeType::Transpose,
        TransposeType::NoTranspose,
        m, n, k,
        static_cast<scalar_t>(1),
        finput.data(), lda,
        grad_output.data(), ldb,
        static_cast<scalar_t>(1),
        grad_weight.data(), ldc);
  }
}

static void slow_conv2d_backward_weight_out_cpu_template(
    Tensor& grad_weight,
    const Tensor& input,
    const Tensor& grad_output_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding) {
  const int64_t kernel_height = kernel_size[0];
  const int64_t kernel_width = kernel_size[1];
  const int64_t pad_height = padding[0];
  const int64_t pad_width = padding[1];
  const int64_t stride_height = stride[0];
  const int64_t stride_width = stride[1];

  bool use_channels_last = thnn_conv_use_channels_last(input, grad_weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  TORCH_CHECK(grad_weight.is_contiguous(memory_format), "slow_conv2d: grad_weight must be contiguous");
  Tensor grad_weight_2d = view_weight_2d(grad_weight, memory_format);

  slow_conv2d_shape_check(
      input,
      grad_output_,
      grad_weight_2d,
      {},
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      true);

  auto grad_output = grad_output_.contiguous(memory_format);
  Tensor finput = compute_columns2d(input, padding, stride, kernel_size, use_channels_last);

  const int64_t batch_size = input.size(0);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu_grad_weight", [&] {
        auto grad_output_a = grad_output.accessor<const scalar_t, 4>();
        auto grad_weight_2d_a = grad_weight_2d.accessor<scalar_t, 2>();
        auto finput_a = finput.accessor<const scalar_t, 3>();

        for (const auto t : c10::irange(batch_size)) {
          auto grad_output_t = grad_output_a[t];
          auto finput_t = finput_a[t];

          slow_conv2d_backward_weight_frame(
              grad_weight_2d_a, grad_output_t, finput_t, use_channels_last);
        }
      });
}

} // namespace

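// Forward kernel: unfold the input with compute_columns2d, broadcast the bias
// (if any) into the output, and then run one GEMM per frame via
// slow_conv2d_update_output_frame, parallelized over the batch.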
Tensor& slow_conv2d_forward_out_cpu(
    const Tensor& self,
    const Tensor& weight_,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]

  TORCH_CHECK(kernel_size.size() == 2, "2D kernel_size expected");
  TORCH_CHECK(stride.size() == 2, "2D stride expected");
  TORCH_CHECK(padding.size() == 2, "2D padding expected");

  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  const int64_t kernel_height = kernel_size[0];
  const int64_t kernel_width = kernel_size[1];
  const int64_t pad_height = padding[0];
  const int64_t pad_width = padding[1];
  const int64_t stride_height = stride[0];
  const int64_t stride_width = stride[1];

  bool use_channels_last = thnn_conv_use_channels_last(self, weight_);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  const Tensor weight_2d = view_weight_2d(weight_, memory_format);

  slow_conv2d_shape_check(
      self,
      Tensor(),
      weight_2d,
      bias,
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      false);

  const Tensor input = self.contiguous(memory_format);
  const int64_t batch_size = input.size(0);
  const int64_t n_input_plane = input.size(1);
  const int64_t input_height = input.size(2);
  const int64_t input_width = input.size(3);
  const int64_t n_output_plane = weight_2d.size(0);
  const int64_t output_height = (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width = (input_width + 2 * pad_width - kernel_width) / stride_width + 1;

  Tensor finput = compute_columns2d(input, padding, stride, kernel_size, use_channels_last);
  output.resize_({batch_size, n_output_plane, output_height, output_width}, memory_format);
  if (bias.defined()) {
    output.copy_(bias.reshape({-1, 1, 1}));
  }
  TORCH_CHECK(output.is_contiguous(memory_format), "slow_conv2d output tensor must be contiguous");

  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "slow_conv2d_cpu", [&]{
    auto input_a = input.accessor<const scalar_t, 4>();
    auto output_a = output.accessor<scalar_t, 4>();
    auto finput_a = finput.accessor<scalar_t, 3>();
    auto weight_2d_a = weight_2d.accessor<const scalar_t, 2>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto t : c10::irange(start, end)) {
        auto input_t = input_a[t];
        auto output_t = output_a[t];
        auto finput_t = finput_a[t];
        slow_conv2d_update_output_frame(
            input_t,
            output_t,
            weight_2d_a,
            bias.defined(),
            finput_t,
            kernel_height,
            kernel_width,
            stride_height,
            stride_width,
            pad_height,
            pad_width,
            n_input_plane,
            input_height,
            input_width,
            n_output_plane,
            output_height,
            output_width,
            use_channels_last);
      }
    });
  });

  return output;
}

Tensor slow_conv2d_forward_cpu(
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto output = at::empty({0}, self.options());
  at::native::slow_conv2d_forward_out_cpu(
      self,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output);

  return output;
}

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_backward_out_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  if (grad_input.defined()) {
    slow_conv2d_backward_out_cpu_template(
        grad_input,
        grad_output,
        self,
        weight,
        kernel_size,
        stride,
        padding);
  }

  if (grad_bias.defined()) {
    at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes(), weight.suggest_memory_format());
    grad_weight.zero_();
    slow_conv2d_backward_weight_out_cpu_template(
        grad_weight,
        self,
        grad_output,
        kernel_size,
        stride,
        padding);
  }

  return std::tuple<Tensor&, Tensor&, Tensor&>(
      grad_input, grad_weight, grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> slow_conv2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  }

  at::native::slow_conv2d_backward_out_cpu(
      grad_output,
      self,
      weight,
      kernel_size,
      stride,
      padding,
      grad_input,
      grad_weight,
      grad_bias);

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

Tensor & thnn_conv2d_out(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt, IntArrayRef stride, IntArrayRef padding, Tensor & output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return at::_slow_conv2d_forward_out(output, self, weight, kernel_size, bias, stride, padding);
}

Tensor thnn_conv2d(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt, IntArrayRef stride, IntArrayRef padding) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return at::_slow_conv2d_forward(self, weight, kernel_size, bias, stride, padding);
}
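
// Illustrative usage sketch (not part of this file, and the exact overload set
// depends on the ATen build): from C++ these kernels are typically reached
// through the dispatcher, e.g.
//
//   at::Tensor input = at::randn({1, 3, 8, 8});    // NCHW, float, CPU
//   at::Tensor weight = at::randn({4, 3, 3, 3});   // (out_c, in_c, kH, kW)
//   at::Tensor bias = at::randn({4});
//   at::Tensor out = at::thnn_conv2d(
//       input, weight, /*kernel_size=*/{3, 3}, bias,
//       /*stride=*/{1, 1}, /*padding=*/{1, 1});    // shape {1, 4, 8, 8}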

} // namespace at::native