#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>

#include <ATen/core/Tensor.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/im2col.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose2d_native.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#endif

#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

namespace at {
namespace {
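// Validates kernel/stride/dilation/output_padding values and the shapes of
// input, weight, bias and (optionally) grad_output for the transposed
// convolution. Shared by the forward meta function and the backward kernels;
// grad_output may be undefined, and weight may be undefined when
// weight_nullable is true.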
static inline void slow_conv_transpose2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_height,
    int kernel_width,
    int stride_height,
    int stride_width,
    int pad_height,
    int pad_width,
    int output_padding_height,
    int output_padding_width,
    int dilation_height,
    int dilation_width,
    bool weight_nullable) {
  TORCH_CHECK(
      kernel_width > 0 && kernel_height > 0,
      "kernel size should be greater than zero, but got kernel_height: ",
      kernel_height,
      " kernel_width: ",
      kernel_width);
  TORCH_CHECK(
      stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_width < stride_width ||
       output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation, but got output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
        "non-empty 2D or 4D weight tensor expected, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  TORCH_CHECK(
      input.numel() != 0 && (ndim == 3 || ndim == 4),
      "non-empty 3D or 4D input tensor expected but got a tensor with size ",
      input.sizes());

  int64_t input_height = input.size(dimh);
  int64_t input_width = input.size(dimw);
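  // Transposed convolution output size per spatial dim:
  //   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding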
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  if (output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_height,
        " x ",
        input_width,
        "). "
        "Calculated output size per channel: (",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (weight.defined()) {
    int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}
} // namespace

namespace meta {
TORCH_META_FUNC(slow_conv_transpose2d)
(const Tensor& input,
 const Tensor& weight,
 IntArrayRef kernel_size,
 OptionalTensorRef bias_opt,
 IntArrayRef stride,
 IntArrayRef padding,
 IntArrayRef output_padding,
 IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected output_padding equals to 2, but got size ",
      output_padding.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  slow_conv_transpose2d_shape_check(
      input,
      Tensor(),
      weight,
      bias_opt.getTensorRef(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      false);

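  // For transposed convolution the weight is laid out as
  // (n_input_plane, n_output_plane, kH, kW), so the number of output
  // channels comes from dim 1 of the weight.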
  int n_output_plane = weight.size(1);

  bool use_channels_last = native::thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  Tensor input_ = input.contiguous(memory_format);

  if (input_.dim() == 3) {
    input_.resize_({1, input_.size(0), input_.size(1), input_.size(2)});
  }

  int64_t input_height = input_.size(2);
  int64_t input_width = input_.size(3);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input_.size(0);

  // Resize output
  TensorOptions options(input.options());
  set_output_raw_strided(
      0,
      {batch_size, n_output_plane, output_height, output_width},
      {},
      options.memory_format(memory_format));
}
} // namespace meta

namespace native {

namespace {
void slow_conv_transpose2d_out_cpu_template(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  int n_input_plane = weight.size(0);
  int n_output_plane = weight.size(1);

  bool use_channels_last = thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  Tensor input_ = input.contiguous(memory_format);
  Tensor weight_ = weight.contiguous(memory_format);
  Tensor bias_ = bias.defined() ? bias.contiguous() : Tensor();

  bool is_batch = false;
  if (input_.dim() == 3) {
    // Force batch
    is_batch = true;
    input_.resize_({1, input.size(0), input.size(1), input.size(2)});
  }

  int64_t input_height = input_.size(2);
  int64_t input_width = input_.size(3);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input_.size(0);

  // Create temporary columns
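  // One column buffer per sample: it holds the per-sample GEMM result of
  // shape (n_output_plane * kH * kW) x (input_height * input_width) before
  // col2im scatters it into the spatial output. Channels-last swaps the two
  // trailing dims.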
  Tensor columns = at::empty({0}, input.options());
  if (use_channels_last) {
    columns.resize_({batch_size, input_height * input_width, kernel_height * kernel_width * n_output_plane});
  } else {
    columns.resize_({batch_size, n_output_plane * kernel_height * kernel_width, input_height * input_width});
  }
  columns.zero_();

  // Materialize if COW, since we cannot do so during parallel_for
  output.mutable_data_ptr();

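  // For each sample, a single GEMM multiplies the weight with that sample's
  // input to form the column buffer; col2im then accumulates the columns into
  // the output feature map. Samples are processed in parallel over the batch.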
  AT_DISPATCH_FLOATING_TYPES_AND3(at::ScalarType::Long, at::ScalarType::BFloat16,
      at::ScalarType::Half, input.scalar_type(), "slow_conv_transpose2d_out_cpu", [&] {

    at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
      // For each elt in batch, do:
      for (const auto elt : c10::irange(begin, end)) {
        // Matrix multiply per output:
        Tensor input_n = input_.select(0, elt);
        Tensor output_n = output.select(0, elt);
        Tensor columns_n = columns.select(0, elt);

        if (use_channels_last) {
          int64_t m = kernel_height * kernel_width * n_output_plane;
          int64_t n = input_height * input_width;
          int64_t k = n_input_plane;

          // column-major matrices
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::NoTranspose,
              m,
              n,
              k,
              static_cast<scalar_t>(1),
              weight_.const_data_ptr<scalar_t>(),
              m,
              input_n.const_data_ptr<scalar_t>(),
              k,
              static_cast<scalar_t>(0),
              columns_n.mutable_data_ptr<scalar_t>(),
              m);
        } else {
          int64_t m = input_height * input_width;
          int64_t n = n_output_plane * kernel_height * kernel_width;
          int64_t k = n_input_plane;

          // column-major matrices
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::Transpose,
              m,
              n,
              k,
              static_cast<scalar_t>(1),
              input_n.const_data_ptr<scalar_t>(),
              m,
              weight_.const_data_ptr<scalar_t>(),
              n,
              static_cast<scalar_t>(0),
              columns_n.mutable_data_ptr<scalar_t>(),
              m);
        }

        // Unpack columns back into output:
        col2im<scalar_t>(
            columns_n.const_data_ptr<scalar_t>(),
            n_output_plane,
            output_height,
            output_width,
            input_height,
            input_width,
            kernel_height,
            kernel_width,
            pad_height,
            pad_width,
            stride_height,
            stride_width,
            dilation_height,
            dilation_width,
            output_n.data_ptr<scalar_t>(),
            use_channels_last);
      }
    });
  });

  if (bias.defined()) {
    output.add_(bias_.reshape({-1, 1, 1}));
  }

  // Resize output
  if (is_batch) {
    output.resize_({n_output_plane, output_height, output_width});
  }
}

static void slow_conv_transpose2d_backward_out_cpu_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected output_padding equals to 2, but got size ",
      output_padding.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  int64_t n_input_plane = weight_.size(0);
  int64_t n_output_plane = weight_.size(1);

  bool use_channels_last = thnn_conv_use_channels_last(input_, weight_);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      false);

  Tensor input = input_.contiguous(memory_format);
  Tensor grad_output = grad_output_.contiguous(memory_format);
  Tensor weight = weight_.contiguous(memory_format);

  bool is_batch = false;
  if (input.dim() == 3) {
    // Force batch
    is_batch = true;
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  grad_input.resize_({batch_size, n_input_plane, input_height, input_width}, memory_format);
  grad_input.zero_();

  // Create temporary columns
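  // Columns are only needed when im2col is not the identity: for a 1x1
  // kernel with unit stride, unit dilation and no padding, the GEMM below can
  // read grad_output directly.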
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);

  Tensor grad_columns = at::empty({0}, input.options());
  if (need_columns) {
    if (use_channels_last) {
      grad_columns.resize_({input_height * input_width, kernel_height * kernel_width * n_output_plane});
    } else {
      grad_columns.resize_({n_output_plane * kernel_height * kernel_width, input_height * input_width});
    }
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      grad_output.scalar_type(), "slow_conv_transpose2d_backward_out_cpu", [&] {
    // Helpers
    Tensor grad_input_n = Tensor();
    Tensor grad_output_n = Tensor();

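    // The gradient w.r.t. the input of a transposed convolution is a regular
    // convolution of grad_output with the weight: im2col(grad_output)
    // followed by a GEMM against the weight yields grad_input.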
    // For each elt in batch, do:
    for (const auto elt : c10::irange(batch_size)) {
      // Matrix multiply per sample:
      grad_input_n = grad_input.select(0, elt);
      grad_output_n = grad_output.select(0, elt);

      if (need_columns) {
        // Extract columns:
        im2col<scalar_t>(
            grad_output_n.const_data_ptr<scalar_t>(),
            n_output_plane,
            output_height,
            output_width,
            input_height,
            input_width,
            kernel_height,
            kernel_width,
            pad_height,
            pad_width,
            stride_height,
            stride_width,
            dilation_height,
            dilation_width,
            grad_columns.data_ptr<scalar_t>(),
            use_channels_last);
      }

      auto gemm_in_ptr = need_columns ? grad_columns.const_data_ptr<scalar_t>()
          : grad_output_n.const_data_ptr<scalar_t>();

      if (use_channels_last) {
        int64_t m = n_input_plane;
        int64_t n = input_height * input_width;
        int64_t k = n_output_plane * kernel_height * kernel_width;

        // column-major matrices
        cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            m,
            n,
            k,
            static_cast<scalar_t>(1),
            weight.const_data_ptr<scalar_t>(),
            k,
            gemm_in_ptr,
            k,
            static_cast<scalar_t>(0),
            grad_input_n.mutable_data_ptr<scalar_t>(),
            m);

      } else {
        int64_t m = input_height * input_width;
        int64_t n = n_input_plane;
        int64_t k = n_output_plane * kernel_height * kernel_width;

        // column-major matrices
        cpublas::gemm(
            TransposeType::NoTranspose,
            TransposeType::NoTranspose,
            m,
            n,
            k,
            static_cast<scalar_t>(1),
            gemm_in_ptr,
            m,
            weight.const_data_ptr<scalar_t>(),
            k,
            static_cast<scalar_t>(0),
            grad_input_n.mutable_data_ptr<scalar_t>(),
            m);
      }
    }

    // Resize output
    if (is_batch) {
      grad_input.resize_({n_input_plane, input_height, input_width});
    }
  });
}

void slow_conv_transpose2d_acc_grad_parameters_cpu(
    const Tensor& input_,
    const Tensor& weight_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected output_padding equals to 2, but got size ",
      output_padding.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  bool use_channels_last = thnn_conv_use_channels_last(input_, weight_);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      true);

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  Tensor input = input_.contiguous(memory_format);
  Tensor grad_output = grad_output_.contiguous(memory_format);
  TORCH_CHECK(grad_weight.is_contiguous(memory_format), "grad_weight needs to be contiguous");

  if (input.dim() == 3) {
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize temporary columns
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);

  Tensor columns = at::empty({0}, input.options());
  if (need_columns) {
    if (use_channels_last) {
      columns.resize_({input_height * input_width, kernel_height * kernel_width * n_output_plane});
    } else {
      columns.resize_({n_output_plane * kernel_height * kernel_width, input_height * input_width});
    }
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(), "slow_conv_transpose2d_acc_grad_parameters_cpu", [&] {
    // Helpers
    Tensor input_n = Tensor();
    Tensor grad_output_n = Tensor();

    scalar_t scale = static_cast<scalar_t>(scale_);

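    // grad_weight is accumulated across the batch:
    //   grad_weight += scale * im2col(grad_output_n) * input_n^T
    // which is why the GEMMs below run with beta = 1.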
    // For each elt in batch, do:
    for (const auto elt : c10::irange(batch_size)) {
      // Matrix multiply per output:
      grad_output_n = grad_output.select(0, elt);

      // Do Weight:
      if (grad_weight.defined()) {
        // Matrix multiply per output:
        input_n = input.select(0, elt);

        if (need_columns) {
          // Extract columns:
          im2col<scalar_t>(
              grad_output_n.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_height,
              output_width,
              input_height,
              input_width,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              columns.data_ptr<scalar_t>(),
              use_channels_last);
        }

        auto gemm_in_ptr = need_columns ? columns.const_data_ptr<scalar_t>()
            : grad_output_n.const_data_ptr<scalar_t>();

        if (use_channels_last) {
          int64_t m = kernel_height * kernel_width * n_output_plane;
          int64_t n = n_input_plane;
          int64_t k = input_height * input_width;

          // column-major matrices
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::Transpose,
              m,
              n,
              k,
              static_cast<scalar_t>(scale),
              gemm_in_ptr,
              m,
              input_n.const_data_ptr<scalar_t>(),
              n,
              static_cast<scalar_t>(1),
              grad_weight.mutable_data_ptr<scalar_t>(),
              m);
        } else {
          int64_t m = n_output_plane * kernel_height * kernel_width;
          int64_t n = n_input_plane;
          int64_t k = input_height * input_width;

          // column-major matrices
          cpublas::gemm(
              TransposeType::Transpose,
              TransposeType::NoTranspose,
              m,
              n,
              k,
              static_cast<scalar_t>(scale),
              gemm_in_ptr,
              k,
              input_n.const_data_ptr<scalar_t>(),
              k,
              static_cast<scalar_t>(1),
              grad_weight.mutable_data_ptr<scalar_t>(),
              m);
        }
      }
    }
  });
}

} // namespace

TORCH_IMPL_FUNC(slow_conv_transpose2d_structured_cpu)
(const Tensor& input,
 const Tensor& weight,
 IntArrayRef kernel_size,
 OptionalTensorRef bias_opt,
 IntArrayRef stride,
 IntArrayRef padding,
 IntArrayRef output_padding,
 IntArrayRef dilation,
 const Tensor& output) {
  const Tensor& bias = bias_opt.getTensorRef();

  slow_conv_transpose2d_out_cpu_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);
}

static std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose2d_backward_out_cpu_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

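  // grad_bias is simply grad_output summed over the batch and spatial
  // dimensions (N, H, W).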
  if (grad_bias.defined()) {
    at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes(), weight.suggest_memory_format());
    grad_weight.zero_();
    slow_conv_transpose2d_acc_grad_parameters_cpu(
        input,
        weight,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu);

} // namespace native
} // namespace at