#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/im2col.cuh>

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>

#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/native/ConvUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose2d_native.h>
#endif

namespace at::native {
namespace {

static inline void slow_conv_transpose2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_height,
    int kernel_width,
    int stride_height,
    int stride_width,
    int pad_height,
    int pad_width,
    int output_padding_height,
    int output_padding_width,
    int dilation_height,
    int dilation_width,
    bool weight_nullable) {
  TORCH_CHECK(
      kernel_width > 0 && kernel_height > 0,
      "kernel size should be greater than zero, but got kernel_height: ",
      kernel_height,
      " kernel_width: ",
      kernel_width);
  TORCH_CHECK(
      stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_width < stride_width ||
       output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation, ",
      "but got output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
        "non-empty 2D or 4D weight tensor expected, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  TORCH_CHECK(
      input.numel() != 0 && (ndim == 3 || ndim == 4),
      "non-empty 3D or 4D input tensor expected but got a tensor with size ",
      input.sizes());

  int64_t input_height = input.size(dimh);
  int64_t input_width = input.size(dimw);
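  // Transposed convolution inverts the usual convolution shape relation:
  //   out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1
  //         + output_padding
  // Illustrative example (hypothetical sizes, not from any call site):
  // in = 8, stride = 2, pad = 1, kernel = 3, dilation = 1,
  // output_padding = 1 gives out = (8 - 1) * 2 - 2 + 3 + 1 = 16.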
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  if (output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_height,
        " x ",
        input_width,
        "). Calculated output spatial size per channel: (",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (weight.defined()) {
    int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}

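// Forward pass: for each batch element, one GEMM computes
//   columns = weight^T @ input,
// col2im then scatter-adds the column buffer into the output image,
// and a final rank-1 GEMM adds the bias. This is the classic
// "gemm + col2im" formulation of transposed convolution.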
void slow_conv_transpose2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2},
      weight_arg{weight, "weight", 3}, bias_arg{bias, "bias", 4};

  checkAllSameGPU(
      __func__,
      {input_arg, output_arg, weight_arg, bias_arg});

  int n_input_plane = weight.size(0);
  int n_output_plane = weight.size(1);

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  Tensor input_ = input.contiguous();
  Tensor weight_ = weight.contiguous();

  Tensor bias_ = Tensor();

  if (bias.defined()) {
    bias_ = bias.contiguous();
  }

  bool is_batch = false;
  if (input_.dim() == 3) {
    // Force batch
    is_batch = true;
    input_.resize_({1, input_.size(0), input_.size(1), input_.size(2)});
  }

  int64_t input_height = input_.size(2);
  int64_t input_width = input_.size(3);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input_.size(0);

  // Create temporary columns
  Tensor columns_ = at::empty({n_output_plane * kernel_width * kernel_height,
      input_height * input_width}, input_.options());

  // Define a buffer of ones, for bias accumulation
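  // The ones buffer lets the bias addition be phrased as a rank-1 GEMM
  // (outer product of bias and ones) accumulated into the output, so it
  // reuses cuBLAS instead of a dedicated broadcast kernel.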
  Tensor ones_ = bias.defined() ? at::ones({output_height, output_width}, input_.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input_.scalar_type(), "slow_conv_transpose2d_out_cuda", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        // Helpers
        Tensor input_n;
        Tensor output_n;

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per output:
          input_n = input_.select(0, elt);
          output_n = output.select(0, elt);

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m = weight_.size(1) * weight_.size(2) * weight_.size(3);
          int64_t n = input_height * input_width;
          int64_t k = weight_.size(0);

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
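          // In row-major terms this computes
          //   columns_ [m x n] = weight_^T [m x k] @ input_n [k x n]
          // with m = out_planes * kh * kw, n = in_h * in_w, and
          // k = n_input_plane; the 'n'/'t' flags and swapped operand
          // order express the transpose through cuBLAS's column-major
          // view.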
          at::cuda::blas::gemm<scalar_t>(
              'n',
              't',
              n,
              m,
              k,
              1,
              input_n.const_data_ptr<scalar_t>(),
              n,
              weight_.const_data_ptr<scalar_t>(),
              m,
              0,
              columns_.mutable_data_ptr<scalar_t>(),
              n);

          // Unpack columns back into input:
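          // col2im scatter-adds each kernel-sized column into the
          // output image; overlapping patches accumulate, which is the
          // adjoint of the im2col gather used by ordinary convolution.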
          col2im<scalar_t, accscalar_t>(
              at::cuda::getCurrentCUDAStream(),
              columns_.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_height,
              output_width,
              input_height,
              input_width,
              kernel_height,
              kernel_width,
              pad_height,
              pad_width,
              stride_height,
              stride_width,
              dilation_height,
              dilation_width,
              output_n.mutable_data_ptr<scalar_t>());

          // Do Bias after:
          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m_ = n_output_plane;
          int64_t n_ = output_height * output_width;
          int64_t k_ = 1;

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
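          // With k_ = 1 this degenerates to a rank-1 update:
          //   output_n [m_ x n_] += bias_ [m_ x 1] @ ones_^T [1 x n_],
          // i.e. bias[c] is added at every spatial position of channel
          // c, and beta = 1 preserves the col2im result already there.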
          if (bias.defined()) {
            at::cuda::blas::gemm<scalar_t>(
                't',
                'n',
                n_,
                m_,
                k_,
                1,
                ones_.const_data_ptr<scalar_t>(),
                k_,
                bias_.const_data_ptr<scalar_t>(),
                k_,
                1,
                output_n.mutable_data_ptr<scalar_t>(),
                n_);
          }
        }

        // Resize output
        if (is_batch) {
          output.resize_({n_output_plane, output_height, output_width});
          input_.resize_({n_input_plane, input_height, input_width});
        }
      }); // end of dispatch
}

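// Backward pass w.r.t. the input: gather grad_output into a column
// buffer with im2col, then one GEMM per batch element computes
//   grad_input = weight @ im2col(grad_output),
// the adjoint of the forward "gemm + col2im" pipeline.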
static void slow_conv_transpose2d_backward_out_cuda_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "Expected kernel_size to have 2 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "Expected dilation to have 2 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "Expected padding to have 2 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "Expected stride to have 2 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "Expected output_padding to have 2 elements, but got size ",
      output_padding.size());

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      weight_arg{weight_, "weight", 3},
      grad_input_arg{grad_input, "grad_input", 4};

  checkAllSameGPU(
      __func__,
      {input_arg,
       grad_output_arg,
       weight_arg,
       grad_input_arg});

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      false);

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();
  Tensor weight = weight_.contiguous();

  bool is_batch = false;
  if (input.dim() == 3) {
    // Force batch
    is_batch = true;
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  grad_input.resize_({batch_size, n_input_plane, input_height, input_width});

  // Create temporary columns
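  // Fast path: for a 1x1 kernel with unit stride/dilation and no
  // padding, im2col is the identity, so grad_output feeds the GEMM
  // directly and no column buffer is allocated.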
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);
  Tensor grad_columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height,
      input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      grad_output.scalar_type(), "slow_conv_transpose2d_backward_out_cuda", [&] {
        // Helpers
        Tensor grad_input_n = Tensor();
        Tensor grad_output_n = Tensor();

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per sample:
          grad_input_n = grad_input.select(0, elt);
          grad_output_n = grad_output.select(0, elt);

          if (need_columns) {
            im2col<scalar_t>(
                at::cuda::getCurrentCUDAStream(),
                grad_output_n.const_data_ptr<scalar_t>(),
                n_output_plane,
                output_height,
                output_width,
                input_height,
                input_width,
                kernel_height,
                kernel_width,
                pad_height,
                pad_width,
                stride_height,
                stride_width,
                dilation_height,
                dilation_width,
                grad_columns.mutable_data_ptr<scalar_t>());
          }

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          int64_t m = weight.size(0);
          int64_t n = input_height * input_width;
          int64_t k = weight.size(1) * weight.size(2) * weight.size(3);

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
          auto gemm_in_ptr = need_columns ? grad_columns.const_data_ptr<scalar_t>()
              : grad_output_n.const_data_ptr<scalar_t>();
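          // In row-major terms this computes
          //   grad_input_n [m x n] = weight [m x k] @ columns [k x n]
          // with m = n_input_plane, n = in_h * in_w, and
          // k = out_planes * kh * kw.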
          at::cuda::blas::gemm<scalar_t>(
              'n',
              'n',
              n,
              m,
              k,
              1,
              gemm_in_ptr,
              n,
              weight.const_data_ptr<scalar_t>(),
              k,
              0,
              grad_input_n.mutable_data_ptr<scalar_t>(),
              n);
        }

        // Resize output
        if (is_batch) {
          grad_output.resize_({n_output_plane, output_height, output_width});
          input.resize_({n_input_plane, input_height, input_width});
          grad_input.resize_({n_input_plane, input_height, input_width});
        }
      }); // end of dispatch
}

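// Parameter gradients: per batch element, grad_weight accumulates
//   grad_weight += scale * input_n @ im2col(grad_output_n)^T,
// while grad_bias is grad_output summed over the batch and spatial
// dimensions.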
void slow_conv_transpose2d_acc_grad_parameters_cuda_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "Expected kernel_size to have 2 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "Expected dilation to have 2 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "Expected padding to have 2 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "Expected stride to have 2 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "Expected output_padding to have 2 elements, but got size ",
      output_padding.size());

  TensorArg input_arg{input_, "input", 1},
      grad_output_arg{grad_output_, "grad_output", 2},
      grad_weight_arg{grad_weight, "grad_weight", 3},
      grad_bias_arg{grad_bias, "grad_bias", 4};

  checkAllSameGPU(
      __func__,
      {input_arg,
       grad_output_arg,
       grad_weight_arg,
       grad_bias_arg});

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      true);

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  int64_t n_output_plane;
  if (grad_weight.defined()) {
    n_output_plane = grad_weight.size(1);
  } else if (grad_bias.defined()) {
    n_output_plane = grad_bias.size(0);
  } else {
    return;
  }

  if (grad_weight.defined()) {
    TORCH_CHECK(
        grad_weight.is_contiguous(), "grad_weight needs to be contiguous");
  }

  if (grad_bias.defined()) {
    TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous");
  }

  bool is_batch = false;
  if (input.dim() == 3) {
    // Force batch
    is_batch = true;
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Create temporary columns
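  // Same 1x1 / unit-stride fast path as in the backward pass above:
  // when im2col would be the identity, grad_output is used directly.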
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);
  Tensor columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height,
      input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
      input.scalar_type(), "slow_conv_transpose2d_acc_grad_parameters_cuda", [&] {
        // Helpers
        Tensor input_n = Tensor();
        Tensor grad_output_n = Tensor();

        scalar_t scale = static_cast<scalar_t>(scale_);

        // For each elt in batch, do:
        for (int elt = 0; elt < batch_size; elt++) {
          // Matrix multiply per output:
          grad_output_n = grad_output.select(0, elt);

          // Do Weight:
          if (grad_weight.defined()) {
            // Matrix multiply per output:
            input_n = input.select(0, elt);

            if (need_columns) {
              // Extract columns:
              im2col<scalar_t>(
                  at::cuda::getCurrentCUDAStream(),
                  grad_output_n.const_data_ptr<scalar_t>(),
                  n_output_plane,
                  output_height,
                  output_width,
                  input_height,
                  input_width,
                  kernel_height,
                  kernel_width,
                  pad_height,
                  pad_width,
                  stride_height,
                  stride_width,
                  dilation_height,
                  dilation_width,
                  columns.mutable_data_ptr<scalar_t>());
            }

            // M,N,K are dims of matrix A and B
            // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
            int64_t n = n_output_plane * kernel_height * kernel_width;
            int64_t m = input_n.size(0); // n_input_plane
            int64_t k = input_height * input_width;

            // Do GEMM (note: this is a bit confusing because gemm assumes
            // column-major matrices)
            auto gemm_in_ptr = need_columns ? columns.const_data_ptr<scalar_t>()
                : grad_output_n.const_data_ptr<scalar_t>();
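            // In row-major terms this accumulates (beta = 1)
            //   grad_weight [m x n] += scale * input_n [m x k]
            //                                @ columns^T [k x n]
            // with m = n_input_plane, n = out_planes * kh * kw, and
            // k = in_h * in_w.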
            at::cuda::blas::gemm<scalar_t>(
                't',
                'n',
                n,
                m,
                k,
                scale,
                gemm_in_ptr,
                k,
                input_n.const_data_ptr<scalar_t>(),
                k,
                1,
                grad_weight.mutable_data_ptr<scalar_t>(),
                n);
          }
        }

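        // Do Bias: reduce grad_output over the batch and spatial
        // dimensions (dims 0, 2, 3), leaving one value per channel.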
        if (grad_bias.defined()) {
          at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
        }

        // Resize
        if (is_batch) {
          grad_output.resize_({n_output_plane, output_height, output_width});
          input.resize_({input.size(1), input_height, input_width});
        }
      }); // end of dispatch
}
} // namespace

TORCH_IMPL_FUNC(slow_conv_transpose2d_structured_cuda)
(const Tensor& input,
 const Tensor& weight,
 IntArrayRef kernel_size,
 OptionalTensorRef bias_opt,
 IntArrayRef stride,
 IntArrayRef padding,
 IntArrayRef output_padding,
 IntArrayRef dilation,
 const Tensor& output) {
  const Tensor& bias = bias_opt.getTensorRef();

  slow_conv_transpose2d_out_cuda_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);
}

std::tuple<Tensor&, Tensor&, Tensor&> slow_conv_transpose2d_backward_out_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  if (grad_input.defined()) {
    slow_conv_transpose2d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose2d_acc_grad_parameters_cuda_template(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor&, Tensor&, Tensor&>(
      grad_input, grad_weight, grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cuda(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose2d_backward_out_cuda_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose2d_acc_grad_parameters_cuda_template(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_CUDA_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cuda);

} // namespace at::native