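// Naive (im2col/col2im + GEMM based) CPU implementation of 2D transposed
// convolution: the slow_conv_transpose2d forward kernel, its input-gradient
// backward, and weight/bias gradient accumulation.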
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h> // for at::parallel_for
#include <ATen/TensorMeta.h>
#include <ATen/TensorUtils.h>

#include <ATen/core/Tensor.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/im2col.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose2d_native.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/zeros.h>
#endif

#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

namespace at {
namespace {
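// Validates kernel/stride/dilation/output_padding values and the shapes of
// input, weight, bias and (optionally) grad_output. grad_output and weight may
// be undefined depending on which forward/backward path is being checked.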
static inline void slow_conv_transpose2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_height,
    int kernel_width,
    int stride_height,
    int stride_width,
    int pad_height,
    int pad_width,
    int output_padding_height,
    int output_padding_width,
    int dilation_height,
    int dilation_width,
    bool weight_nullable) {
  TORCH_CHECK(
      kernel_width > 0 && kernel_height > 0,
      "kernel size should be greater than zero, but got kernel_height: ",
      kernel_height,
      " kernel_width: ",
      kernel_width);
  TORCH_CHECK(
      stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_width < stride_width ||
       output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation, but got output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() != 0 && (weight.dim() == 2 || weight.dim() == 4),
        "non-empty 2D or 4D weight tensor expected, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  TORCH_CHECK(
      input.numel() != 0 && (ndim == 3 || ndim == 4),
      "non-empty 3D or 4D input tensor expected but got a tensor with size ",
      input.sizes());

  int64_t input_height = input.size(dimh);
  int64_t input_width = input.size(dimw);
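  // Transposed convolution output size (per spatial dim):
  //   out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1 + output_padding
  // e.g. in = 4, stride = 2, pad = 1, kernel = 3, dilation = 1, output_padding = 0
  // gives out = 3 * 2 - 2 + 3 + 0 = 7.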
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  if (output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_height,
        " x ",
        input_width,
        "). "
        "Calculated output size per channel: (",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (weight.defined()) {
    int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}
} // namespace

namespace meta {
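// Structured-kernel meta function: validates the arguments, runs the shape
// check, and declares the output as {N, out_channels, out_H, out_W} in either
// contiguous or channels-last memory format.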
TORCH_META_FUNC(slow_conv_transpose2d)
(const Tensor& input,
 const Tensor& weight,
 IntArrayRef kernel_size,
 OptionalTensorRef bias_opt,
 IntArrayRef stride,
 IntArrayRef padding,
 IntArrayRef output_padding,
 IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected output_padding equals to 2, but got size ",
      output_padding.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  slow_conv_transpose2d_shape_check(
      input,
      Tensor(),
      weight,
      bias_opt.getTensorRef(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      false);

  int n_output_plane = weight.size(1);

  bool use_channels_last = native::thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  Tensor input_ = input.contiguous(memory_format);

  if (input_.dim() == 3) {
    input_.resize_({1, input_.size(0), input_.size(1), input_.size(2)});
  }

  int64_t input_height = input_.size(2);
  int64_t input_width = input_.size(3);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input_.size(0);

  // Resize output
  TensorOptions options(input.options());
  set_output_raw_strided(
      0,
      {batch_size, n_output_plane, output_height, output_width},
      {},
      options.memory_format(memory_format));
}
} // namespace meta

namespace native {

namespace {
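// Forward pass of the naive transposed 2D convolution. For every sample in the
// batch, a single GEMM computes columns = weight^T * input (one column per
// input spatial location), and col2im scatter-adds those columns into the
// output feature map. The bias, if present, is added afterwards.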
void slow_conv_transpose2d_out_cpu_template(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  int n_input_plane = weight.size(0);
  int n_output_plane = weight.size(1);

  bool use_channels_last = thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  Tensor input_ = input.contiguous(memory_format);
  Tensor weight_ = weight.contiguous(memory_format);
  Tensor bias_ = bias.defined() ? bias.contiguous() : Tensor();

  bool is_batch = false;
  if (input_.dim() == 3) {
    // Force batch
    is_batch = true;
    input_.resize_({1, input.size(0), input.size(1), input.size(2)});
  }

  int64_t input_height = input_.size(2);
  int64_t input_width = input_.size(3);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input_.size(0);

  // Create temporary columns
  Tensor columns = at::empty({0}, input.options());
  if (use_channels_last) {
    columns.resize_({batch_size, input_height * input_width, kernel_height * kernel_width * n_output_plane});
  } else {
    columns.resize_({batch_size, n_output_plane * kernel_height * kernel_width, input_height * input_width});
  }
  columns.zero_();

  // Materialize if COW, since we cannot do so during parallel_for
  output.mutable_data_ptr();

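  // One GEMM per sample. In the contiguous layout the columns buffer is
  // (out_channels * kH * kW) x (in_H * in_W); in channels-last it is
  // (in_H * in_W) x (kH * kW * out_channels), so the operand order and
  // transposes below differ between the two branches.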
  AT_DISPATCH_FLOATING_TYPES_AND3(at::ScalarType::Long, at::ScalarType::BFloat16,
      at::ScalarType::Half, input.scalar_type(), "slow_conv_transpose2d_out_cpu", [&] {

    at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
      // For each elt in batch, do:
      for (const auto elt : c10::irange(begin, end)) {
        // Matrix multiply per output:
        Tensor input_n = input_.select(0, elt);
        Tensor output_n = output.select(0, elt);
        Tensor columns_n = columns.select(0, elt);

        if (use_channels_last) {
          int64_t m = kernel_height * kernel_width * n_output_plane;
          int64_t n = input_height * input_width;
          int64_t k = n_input_plane;

          // column-major matrices
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::NoTranspose,
              m,
              n,
              k,
              static_cast<scalar_t>(1),
              weight_.const_data_ptr<scalar_t>(),
              m,
              input_n.const_data_ptr<scalar_t>(),
              k,
              static_cast<scalar_t>(0),
              columns_n.mutable_data_ptr<scalar_t>(),
              m);
        } else {
          int64_t m = input_height * input_width;
          int64_t n = n_output_plane * kernel_height * kernel_width;
          int64_t k = n_input_plane;

          // column-major matrices
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::Transpose,
              m,
              n,
              k,
              static_cast<scalar_t>(1),
              input_n.const_data_ptr<scalar_t>(),
              m,
              weight_.const_data_ptr<scalar_t>(),
              n,
              static_cast<scalar_t>(0),
              columns_n.mutable_data_ptr<scalar_t>(),
              m);
        }

        // Unpack columns back into the output:
        col2im<scalar_t>(
            columns_n.const_data_ptr<scalar_t>(),
            n_output_plane,
            output_height,
            output_width,
            input_height,
            input_width,
            kernel_height,
            kernel_width,
            pad_height,
            pad_width,
            stride_height,
            stride_width,
            dilation_height,
            dilation_width,
            output_n.data_ptr<scalar_t>(),
            use_channels_last);
      }
    });
  });

  if (bias.defined()) {
    output.add_(bias_.reshape({-1, 1, 1}));
  }

  // Resize output
  if (is_batch) {
    output.resize_({n_output_plane, output_height, output_width});
  }
}

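// Backward pass w.r.t. the input. The gradient of a transposed convolution is
// an ordinary convolution of grad_output with the same weight: im2col gathers
// grad_output patches into columns (skipped for 1x1/stride-1/no-padding
// kernels, where grad_output can be used directly) and a GEMM with the weight
// produces grad_input.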
static void slow_conv_transpose2d_backward_out_cpu_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected output_padding equals to 2, but got size ",
      output_padding.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  int64_t n_input_plane = weight_.size(0);
  int64_t n_output_plane = weight_.size(1);

  bool use_channels_last = thnn_conv_use_channels_last(input_, weight_);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      false);

  Tensor input = input_.contiguous(memory_format);
  Tensor grad_output = grad_output_.contiguous(memory_format);
  Tensor weight = weight_.contiguous(memory_format);

  bool is_batch = false;
  if (input.dim() == 3) {
    // Force batch
    is_batch = true;
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize output
  grad_input.resize_({batch_size, n_input_plane, input_height, input_width}, memory_format);
  grad_input.zero_();

  // Create temporary columns
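  // Columns are only needed when im2col is not the identity, i.e. for kernels
  // larger than 1x1, non-unit stride/dilation, or non-zero padding; otherwise
  // grad_output feeds the GEMM directly.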
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);

  Tensor grad_columns = at::empty({0}, input.options());
  if (need_columns) {
    if (use_channels_last) {
      grad_columns.resize_({input_height * input_width, kernel_height * kernel_width * n_output_plane});
    } else {
      grad_columns.resize_({n_output_plane * kernel_height * kernel_width, input_height * input_width});
    }
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      grad_output.scalar_type(), "slow_conv_transpose2d_backward_out_cpu", [&] {
        // Helpers
        Tensor grad_input_n = Tensor();
        Tensor grad_output_n = Tensor();

        // For each elt in batch, do:
        for (const auto elt : c10::irange(batch_size)) {
          // Matrix multiply per sample:
          grad_input_n = grad_input.select(0, elt);
          grad_output_n = grad_output.select(0, elt);

          if (need_columns) {
            // Extract columns:
            im2col<scalar_t>(
                  grad_output_n.const_data_ptr<scalar_t>(),
                  n_output_plane,
                  output_height,
                  output_width,
                  input_height,
                  input_width,
                  kernel_height,
                  kernel_width,
                  pad_height,
                  pad_width,
                  stride_height,
                  stride_width,
                  dilation_height,
                  dilation_width,
                  grad_columns.data_ptr<scalar_t>(),
                  use_channels_last);
          }

          auto gemm_in_ptr = need_columns ? grad_columns.const_data_ptr<scalar_t>()
              : grad_output_n.const_data_ptr<scalar_t>();

          if (use_channels_last) {
            int64_t m = n_input_plane;
            int64_t n = input_height * input_width;
            int64_t k = n_output_plane * kernel_height * kernel_width;

            // column-major matrices
            cpublas::gemm(
                TransposeType::Transpose,
                TransposeType::NoTranspose,
                m,
                n,
                k,
                static_cast<scalar_t>(1),
                weight.const_data_ptr<scalar_t>(),
                k,
                gemm_in_ptr,
                k,
                static_cast<scalar_t>(0),
                grad_input_n.mutable_data_ptr<scalar_t>(),
                m);

          } else {
            int64_t m = input_height * input_width;
            int64_t n = n_input_plane;
            int64_t k = n_output_plane * kernel_height * kernel_width;

            // column-major matrices
            cpublas::gemm(
                TransposeType::NoTranspose,
                TransposeType::NoTranspose,
                m,
                n,
                k,
                static_cast<scalar_t>(1),
                gemm_in_ptr,
                m,
                weight.const_data_ptr<scalar_t>(),
                k,
                static_cast<scalar_t>(0),
                grad_input_n.mutable_data_ptr<scalar_t>(),
                m);
          }
        }

        // Resize output
        if (is_batch) {
          grad_input.resize_({n_input_plane, input_height, input_width});
        }
      });
}

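// Accumulates the weight gradient: for every sample, im2col gathers
// grad_output patches (or uses grad_output directly in the identity case) and
// a GEMM against the input accumulates scale * (columns @ input^T) into
// grad_weight. grad_bias is computed by the caller as a sum over grad_output.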
void slow_conv_transpose2d_acc_grad_parameters_cpu(
    const Tensor& input_,
    const Tensor& weight_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 2,
      "It is expected kernel_size equals to 2, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 2,
      "It is expected dilation equals to 2, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 2,
      "It is expected padding equals to 2, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 2,
      "It is expected stride equals to 2, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 2,
      "It is expected output_padding equals to 2, but got size ",
      output_padding.size());

  int64_t kernel_height = kernel_size[0];
  int64_t kernel_width = kernel_size[1];
  int64_t dilation_height = dilation[0];
  int64_t dilation_width = dilation[1];
  int64_t pad_height = padding[0];
  int64_t pad_width = padding[1];
  int64_t stride_height = stride[0];
  int64_t stride_width = stride[1];
  int64_t output_padding_height = output_padding[0];
  int64_t output_padding_width = output_padding[1];

  bool use_channels_last = thnn_conv_use_channels_last(input_, weight_);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  slow_conv_transpose2d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_height,
      kernel_width,
      stride_height,
      stride_width,
      pad_height,
      pad_width,
      output_padding_height,
      output_padding_width,
      dilation_height,
      dilation_width,
      true);

  int n_input_plane = weight_.size(0);
  int n_output_plane = weight_.size(1);

  Tensor input = input_.contiguous(memory_format);
  Tensor grad_output = grad_output_.contiguous(memory_format);
  TORCH_CHECK(grad_weight.is_contiguous(memory_format), "grad_weight needs to be contiguous");

  if (input.dim() == 3) {
    input.resize_({1, input.size(0), input.size(1), input.size(2)});
    grad_output.resize_(
        {1, grad_output.size(0), grad_output.size(1), grad_output.size(2)});
  }

  int64_t input_width = input.size(3);
  int64_t input_height = input.size(2);
  int64_t output_height = (input_height - 1) * stride_height - 2 * pad_height +
      (dilation_height * (kernel_height - 1) + 1) + output_padding_height;
  int64_t output_width = (input_width - 1) * stride_width - 2 * pad_width +
      (dilation_width * (kernel_width - 1) + 1) + output_padding_width;

  // Batch size + input planes
  int64_t batch_size = input.size(0);

  // Resize temporary columns
  bool need_columns = (kernel_height != 1 || kernel_width != 1 || stride_height != 1 ||
      stride_width != 1 || pad_height != 0 || pad_width != 0 ||
      dilation_height != 1 || dilation_width != 1);

  Tensor columns = at::empty({0}, input.options());
  if (need_columns) {
    if (use_channels_last) {
      columns.resize_({input_height * input_width, kernel_height * kernel_width * n_output_plane});
    } else {
      columns.resize_({n_output_plane * kernel_height * kernel_width, input_height * input_width});
    }
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(), "slow_conv_transpose2d_acc_grad_parameters_cpu", [&] {
        // Helpers
        Tensor input_n = Tensor();
        Tensor grad_output_n = Tensor();

        scalar_t scale = static_cast<scalar_t>(scale_);

        // For each elt in batch, do:
        for (const auto elt : c10::irange(batch_size)) {
          // Matrix multiply per output:
          grad_output_n = grad_output.select(0, elt);

          // Do Weight:
          if (grad_weight.defined()) {
            // Matrix multiply per output:
            input_n = input.select(0, elt);

            if (need_columns) {
              // Extract columns:
              im2col<scalar_t>(
                  grad_output_n.const_data_ptr<scalar_t>(),
                  n_output_plane,
                  output_height,
                  output_width,
                  input_height,
                  input_width,
                  kernel_height,
                  kernel_width,
                  pad_height,
                  pad_width,
                  stride_height,
                  stride_width,
                  dilation_height,
                  dilation_width,
                  columns.data_ptr<scalar_t>(),
                  use_channels_last);
            }

            auto gemm_in_ptr = need_columns ? columns.const_data_ptr<scalar_t>()
                : grad_output_n.const_data_ptr<scalar_t>();

            if (use_channels_last) {
              int64_t m = kernel_height * kernel_width * n_output_plane;
              int64_t n = n_input_plane;
              int64_t k = input_height * input_width;

              // column-major matrices
              cpublas::gemm(
                  TransposeType::NoTranspose,
                  TransposeType::Transpose,
                  m,
                  n,
                  k,
                  static_cast<scalar_t>(scale),
                  gemm_in_ptr,
                  m,
                  input_n.const_data_ptr<scalar_t>(),
                  n,
                  static_cast<scalar_t>(1),
                  grad_weight.mutable_data_ptr<scalar_t>(),
                  m);
            } else {
              int64_t m = n_output_plane * kernel_height * kernel_width;
              int64_t n = n_input_plane;
              int64_t k = input_height * input_width;

              // column-major matrices
              cpublas::gemm(
                  TransposeType::Transpose,
                  TransposeType::NoTranspose,
                  m,
                  n,
                  k,
                  static_cast<scalar_t>(scale),
                  gemm_in_ptr,
                  k,
                  input_n.const_data_ptr<scalar_t>(),
                  k,
                  static_cast<scalar_t>(1),
                  grad_weight.mutable_data_ptr<scalar_t>(),
                  m);
            }
          }
        }
      });
}

} // namespace

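// Structured kernel implementation: the meta function above has already
// checked shapes and allocated `output`, so this simply forwards to the CPU
// template.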
TORCH_IMPL_FUNC(slow_conv_transpose2d_structured_cpu)
(const Tensor& input,
 const Tensor& weight,
 IntArrayRef kernel_size,
 OptionalTensorRef bias_opt,
 IntArrayRef stride,
 IntArrayRef padding,
 IntArrayRef output_padding,
 IntArrayRef dilation,
 const Tensor& output) {
  const Tensor& bias = bias_opt.getTensorRef();

  slow_conv_transpose2d_out_cpu_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);
}

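// Non-structured backward entry point. output_mask selects which gradients to
// compute: {grad_input, grad_weight, grad_bias}. Undefined tensors are
// returned for the gradients that were not requested.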
static std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose2d_backward_out_cpu_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

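  // grad_bias is simply grad_output reduced over the batch and spatial
  // dimensions (N, H, W).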
  if (grad_bias.defined()) {
    at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes(), weight.suggest_memory_format());
    grad_weight.zero_();
    slow_conv_transpose2d_acc_grad_parameters_cpu(
        input,
        weight,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose2d_backward_stub, &slow_conv_transpose2d_backward_cpu);

} // namespace native
} // namespace at