#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/im2col.cuh>

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>

#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/native/ConvUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose2d_native.h>
#endif

namespace at::native {
namespace {

static inline void slow_conv_transpose2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_height,
    int kernel_width,
    int stride_height,
    int stride_width,
    int pad_height,
    int pad_width,
    int output_padding_height,
    int output_padding_width,
    int dilation_height,
    int dilation_width,
    bool weight_nullable) {
  TORCH_CHECK(
      kernel_width > 0 && kernel_height > 0,
      "kernel size should be greater than zero, but got kernel_height: ",
      kernel_height,
      " kernel_width: ",
      kernel_width);
  TORCH_CHECK(
      stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_width < stride_width ||
       output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation, ",
      "but got output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
        "non-empty 2D or 4D weight tensor expected, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  TORCH_CHECK(
      input.numel() != 0 && (ndim == 3 || ndim == 4),
      "non-empty 3D or 4D input tensor expected but got a tensor with size ",
      input.sizes());

  int64_t input_height = input.size(dimh);
  int64_t input_width = input.size(dimw);
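  // Transposed-convolution output size (the inverse of the conv2d size
  // formula):
  //   out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1
  //         + output_padding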
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  if (output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_height,
        " x ",
        input_width,
        "). Calculated output spatial size per channel: (",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (weight.defined()) {
    int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}

void slow_conv_transpose2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2},
      weight_arg{weight, "weight", 3}, bias_arg{bias, "bias", 4};

  checkAllSameGPU(
      __func__,
      {input_arg, output_arg, weight_arg, bias_arg});

  int n_input_plane = weight.size(0);
  int n_output_plane = weight.size(1);

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  Tensor input_ = input.contiguous();
  Tensor weight_ = weight.contiguous();

  Tensor bias_ = Tensor();

  if (bias.defined()) {
    bias_ = bias.contiguous();
  }

  bool is_batch = false;
  if (input_.dim() == 3) {
    // Force batch
    is_batch = true;
    input_.resize_({1, input_.size(0), input_.size(1), input_.size(2)});
  }

  int64_t input_height = input_.size(2);
  int64_t input_width = input_.size(3);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input_.size(0);

  // Create temporary columns
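  // columns_ has one row per (output channel, kernel offset) pair and one
  // column per input spatial location; it is the intermediate consumed by
  // col2im below.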
  Tensor columns_ = at::empty({n_output_plane * kernel_width * kernel_height,
      input_height * input_width}, input_.options());

  // Define a buffer of ones, for bias accumulation
  Tensor ones_ = bias.defined() ? at::ones({output_height, output_width}, input_.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input_.scalar_type(), "slow_conv_transpose2d_out_cuda", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;

    // Helpers
    Tensor input_n;
    Tensor output_n;

    // For each elt in batch, do:
    for (int elt = 0; elt < batch_size; elt++) {
      // Matrix multiply per output:
      input_n = input_.select(0, elt);
      output_n = output.select(0, elt);

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = weight_.size(1) * weight_.size(2) * weight_.size(3);
      int64_t n = input_height * input_width;
      int64_t k = weight_.size(0);

      // Do GEMM (note: this is a bit confusing because gemm assumes
      // column-major matrices)
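      // Effectively columns_ = weight_.view({k, m}).t() @ input_n.view({k, n});
      // the 'n'/'t' flags compensate for cuBLAS expecting column-major storage.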
      at::cuda::blas::gemm<scalar_t>(
          'n',
          't',
          n,
          m,
          k,
          1,
          input_n.const_data_ptr<scalar_t>(),
          n,
          weight_.const_data_ptr<scalar_t>(),
          m,
          0,
          columns_.mutable_data_ptr<scalar_t>(),
          n);

      // Unpack columns back into the output image (col2im accumulates
      // overlapping patches):
      col2im<scalar_t, accscalar_t>(
          at::cuda::getCurrentCUDAStream(),
          columns_.const_data_ptr<scalar_t>(),
          n_output_plane,
          output_height,
          output_width,
          input_height,
          input_width,
          kernel_height,
          kernel_width,
          pad_height,
          pad_width,
          stride_height,
          stride_width,
          dilation_height,
          dilation_width,
          output_n.mutable_data_ptr<scalar_t>());

      // Do Bias after:
      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m_ = n_output_plane;
      int64_t n_ = output_height * output_width;
      int64_t k_ = 1;
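      // Rank-1 update: output_n.view({m_, n_}) += bias_ (m_ x 1) @ ones_ (1 x n_),
      // i.e. each output channel gets its bias added at every spatial location.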

      // Do GEMM (note: this is a bit confusing because gemm assumes
      // column-major matrices)
      if (bias.defined()) {
        at::cuda::blas::gemm<scalar_t>(
            't',
            'n',
            n_,
            m_,
            k_,
            1,
            ones_.const_data_ptr<scalar_t>(),
            k_,
            bias_.const_data_ptr<scalar_t>(),
            k_,
            1,
            output_n.mutable_data_ptr<scalar_t>(),
            n_);
      }
    }

    // Resize output
    if (is_batch) {
      output.resize_({n_output_plane, output_height, output_width});
      input_.resize_({n_input_plane, input_height, input_width});
    }
  }); // end of dispatch
}

static void slow_conv_transpose2d_backward_out_cuda_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected that kernel_size has 2 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected that dilation has 2 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected that padding has 2 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected that stride has 2 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected that output_padding has 2 elements, but got size ",
      output_padding.size());

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      weight_arg{weight_, "weight", 3},
      grad_input_arg{grad_input, "grad_input", 4};

  checkAllSameGPU(
      __func__,
      {input_arg,
       grad_output_arg,
       weight_arg,
       grad_input_arg});

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      false);

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();
  Tensor weight = weight_.contiguous();

  bool is_batch = false;
  if (input.dim() == 3) {
    // Force batch
    is_batch = true;
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  grad_input.resize_({batch_size, n_input_plane, input_height, input_width});

  // Create temporary columns
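  // When the kernel is 1x1 with unit stride, no padding, and unit dilation,
  // im2col is the identity mapping, so grad_output can feed the GEMM directly
  // and the column buffer is skipped.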
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);
  Tensor grad_columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height,
      input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      grad_output.scalar_type(), "slow_conv_transpose2d_backward_out_cuda", [&] {
    // Helpers
    Tensor grad_input_n = Tensor();
    Tensor grad_output_n = Tensor();

    // For each elt in batch, do:
    for (int elt = 0; elt < batch_size; elt++) {
      // Matrix multiply per sample:
      grad_input_n = grad_input.select(0, elt);
      grad_output_n = grad_output.select(0, elt);

      if (need_columns) {
        im2col<scalar_t>(
            at::cuda::getCurrentCUDAStream(),
            grad_output_n.const_data_ptr<scalar_t>(),
            n_output_plane,
            output_height,
            output_width,
            input_height,
            input_width,
            kernel_height,
            kernel_width,
            pad_height,
            pad_width,
            stride_height,
            stride_width,
            dilation_height,
            dilation_width,
            grad_columns.mutable_data_ptr<scalar_t>());
      }

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      int64_t m = weight.size(0);
      int64_t n = input_height * input_width;
      int64_t k = weight.size(1) * weight.size(2) * weight.size(3);

      // Do GEMM (note: this is a bit confusing because gemm assumes
      // column-major matrices)
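      // Effectively grad_input_n.view({m, n}) = weight.view({m, k}) @ columns (k x n):
      // the backward w.r.t. the input of a transposed convolution is an
      // ordinary convolution of grad_output with the same weights.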
      auto gemm_in_ptr = need_columns ? grad_columns.const_data_ptr<scalar_t>()
          : grad_output_n.const_data_ptr<scalar_t>();
      at::cuda::blas::gemm<scalar_t>(
          'n',
          'n',
          n,
          m,
          k,
          1,
          gemm_in_ptr,
          n,
          weight.const_data_ptr<scalar_t>(),
          k,
          0,
          grad_input_n.mutable_data_ptr<scalar_t>(),
          n);
    }

    // Resize output
    if (is_batch) {
      grad_output.resize_({n_output_plane, output_height, output_width});
      input.resize_({n_input_plane, input_height, input_width});
      grad_input.resize_({n_input_plane, input_height, input_width});
    }
  }); // end of dispatch
}

void slow_conv_transpose2d_acc_grad_parameters_cuda_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected that kernel_size has 2 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected that dilation has 2 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected that padding has 2 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected that stride has 2 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected that output_padding has 2 elements, but got size ",
      output_padding.size());

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      grad_weight_arg{grad_weight, "grad_weight", 3},
      grad_bias_arg{grad_bias, "grad_bias", 4};

  checkAllSameGPU(
      __func__,
      {input_arg,
       grad_output_arg,
       grad_weight_arg,
       grad_bias_arg});

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      true);

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  int64_t n_output_plane;
  if (grad_weight.defined()) {
    n_output_plane = grad_weight.size(1);
  } else if (grad_bias.defined()) {
    n_output_plane = grad_bias.size(0);
  } else {
    return;
  }

  if (grad_weight.defined()) {
    TORCH_CHECK(
        grad_weight.is_contiguous(), "grad_weight needs to be contiguous");
  }

  if (grad_bias.defined()) {
    TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous");
  }

  bool is_batch = false;
  if (input.dim() == 3) {
    // Force batch
    is_batch = true;
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Create temporary columns
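  // Same 1x1 / unit-stride shortcut as in the backward-input path above:
  // when im2col would be the identity, grad_output is used directly.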
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);
  Tensor columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height,
      input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "slow_conv_transpose2d_acc_grad_parameters_cuda", [&] {
    // Helpers
    Tensor input_n = Tensor();
    Tensor grad_output_n = Tensor();

    scalar_t scale = static_cast<scalar_t>(scale_);

    // For each elt in batch, do:
    for (int elt = 0; elt < batch_size; elt++) {
      // Matrix multiply per output:
      grad_output_n = grad_output.select(0, elt);

      // Do Weight:
      if (grad_weight.defined()) {
        // Matrix multiply per output:
        input_n = input.select(0, elt);

        if (need_columns) {
          // Extract columns:
          im2col<scalar_t>(
              at::cuda::getCurrentCUDAStream(),
              grad_output_n.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_height,
              output_width,
              input_height,
              input_width,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              columns.mutable_data_ptr<scalar_t>());
        }

        // M,N,K are dims of matrix A and B
        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
        int64_t n = n_output_plane * kernel_height * kernel_width;
        int64_t m = input_n.size(0); // n_input_plane
        int64_t k = input_height * input_width;

        // Do GEMM (note: this is a bit confusing because gemm assumes
        // column-major matrices)
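        // Effectively grad_weight.view({m, n}) += scale * input_n.view({m, k}) @ columns.t(),
        // accumulating this batch element's contribution (beta == 1).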
        auto gemm_in_ptr = need_columns ? columns.const_data_ptr<scalar_t>()
            : grad_output_n.const_data_ptr<scalar_t>();
        at::cuda::blas::gemm<scalar_t>(
            't',
            'n',
            n,
            m,
            k,
            scale,
            gemm_in_ptr,
            k,
            input_n.const_data_ptr<scalar_t>(),
            k,
            1,
            grad_weight.mutable_data_ptr<scalar_t>(),
            n);
      }
    }

    if (grad_bias.defined()) {
      at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
    }

    // Resize
    if (is_batch) {
      grad_output.resize_({n_output_plane, output_height, output_width});
      input.resize_({input.size(1), input_height, input_width});
    }
  }); // end of dispatch
}
} // namespace

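// Structured kernel entry point: shape checking and allocation of `output`
// are expected to be handled by the corresponding meta function, so this
// function only runs the CUDA implementation.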
TORCH_IMPL_FUNC(slow_conv_transpose2d_structured_cuda)
(const Tensor& input,
 const Tensor& weight,
 IntArrayRef kernel_size,
 OptionalTensorRef bias_opt,
 IntArrayRef stride,
 IntArrayRef padding,
 IntArrayRef output_padding,
 IntArrayRef dilation,
 const Tensor& output) {
  const Tensor& bias = bias_opt.getTensorRef();

  slow_conv_transpose2d_out_cuda_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);
}

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv_transpose2d_backward_out_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  if (grad_input.defined()) {
    slow_conv_transpose2d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose2d_acc_grad_parameters_cuda_template(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor&, Tensor&, Tensor&>(
      grad_input, grad_weight, grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
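  // output_mask[i] tells whether the i-th gradient (input, weight, bias)
  // is actually needed by the caller; unneeded ones stay undefined.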
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose2d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose2d_acc_grad_parameters_cuda_template(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_CUDA_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cuda);

} // namespace at::native