#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/vol2col.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose3d_native.h>
#include <ATen/ops/sum.h>
#endif

namespace at::native {

template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda, scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);

namespace {

static inline void slow_conv_transpose3d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_depth,
    int kernel_width,
    int kernel_height,
    int stride_depth,
    int stride_width,
    int stride_height,
    int padding_depth,
    int padding_width,
    int padding_height,
    int dilation_depth,
    int dilation_width,
    int dilation_height,
    int output_padding_depth,
    int output_padding_width,
    int output_padding_height,
    int weight_nullable) {
  TORCH_CHECK(
      input.numel() != 0 && (input.dim() == 4 || input.dim() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ",
      input.sizes());
  TORCH_CHECK(
      stride_depth > 0 && stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_depth > 0 && dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_depth: ",
      dilation_depth,
      ", dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_depth < stride_depth ||
       output_padding_depth < dilation_depth) &&
          (output_padding_width < stride_width ||
           output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation,",
      " but got output_padding_depth: ",
      output_padding_depth,
      " output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_depth: ",
      dilation_depth,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  // number of input & output planes and kernel size are indirectly defined by
  // the weight tensor
  if (weight.defined()) {
    /* TODO: TORCH_CHECK just has 2 args: condition and message */
    TORCH_CHECK(
        weight.numel() != 0 && weight.dim() == 5,
        "non-empty 5D (n_output_plane x n_input_plane x kernel_depth",
        " x kernel_height x kernel_width) tensor ",
        "expected for weight, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimd = 1;
  int dimh = 2;
  int dimw = 3;

  if (ndim == 5) {
    dimf++;
    dimd++;
    dimh++;
    dimw++;
  }

  if (weight.defined()) {
    const int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  const int64_t input_width = input.size(dimw);
  const int64_t input_height = input.size(dimh);
  const int64_t input_depth = input.size(dimd);
  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;
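  // For example, with input_depth = 4, stride_depth = 2, padding_depth = 1,
  // kernel_depth = 3, dilation_depth = 1 and output_padding_depth = 1:
  //   output_depth = (4 - 1) * 2 - 2 * 1 + (1 * (3 - 1) + 1) + 1 = 8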

  if (output_depth < 1 || output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_depth,
        " x ",
        input_height,
        " x ",
        input_width,
        "). "
        "Calculated output size per channel: (",
        output_depth,
        " x ",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      const int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      const int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimd, output_depth);
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}

void slow_conv_transpose3d_out_cpu_template(
    Tensor& output,
    const Tensor& input_, // 4D or 5D (batch) tensor
    const Tensor& weight_, // weight tensor (n_input_plane x n_output_plane x
                           // kernel_depth x kernel_height x kernel_width)
    IntArrayRef kernel_size,
    const Tensor& bias_,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected kernel_size equals to 3, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected dilation equals to 3, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected padding equals to 3, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected stride equals to 3, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected output_padding equals to 3, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  slow_conv_transpose3d_shape_check(
      input_,
      Tensor(),
      weight_,
      bias_,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor weight = weight_.contiguous();
  Tensor bias = bias_.defined() ? bias_.contiguous() : bias_;

  const int n_input_plane = (int)weight.size(0);
  const int n_output_plane = (int)weight.size(1);

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
  }

  const int64_t input_width = input.size(4);
  const int64_t input_height = input.size(3);
  const int64_t input_depth = input.size(2);

  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  // Batch size
  const int64_t batch_size = input.size(0);

  // Resize output
  output.resize_(
      {batch_size, n_output_plane, output_depth, output_height, output_width});

  // Create temporary columns
  Tensor columns = at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options());
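  // Each column of `columns` corresponds to one input position (d, h, w) and
  // stacks the kernel_depth * kernel_height * kernel_width contributions for
  // every output channel, so the buffer is
  // (C_out * kD * kH * kW) x (D_in * H_in * W_in).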

  // Define a buffer of ones, for bias accumulation
  Tensor ones = bias.defined() ? at::ones({output_depth, output_height, output_width}, input_.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND3(at::ScalarType::Long, at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(), "slow_conv_transpose3d_out_cpu", [&] {
        // Helpers
        Tensor input_n;
        Tensor output_n;

        int64_t elt;
        // For each elt in batch, do:
        for (elt = 0; elt < batch_size; ++elt) {
          // Matrix multiply per output:
          input_n = input.select(0, elt);
          output_n = output.select(0, elt);

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          const int64_t m =
              weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);
          const int64_t n = columns.size(1);
          const int64_t k = weight.size(0);

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
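          // Conceptually (in row-major terms) this computes columns (m x n) as
          // weight reshaped to (k, m), transposed, times input_n reshaped to
          // (k, n): every (c_out, kd, kh, kw) row against every input position.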
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::Transpose,
              n,
              m,
              k,
              static_cast<scalar_t>(1),
              input_n.const_data_ptr<scalar_t>(),
              n,
              weight.const_data_ptr<scalar_t>(),
              m,
              static_cast<scalar_t>(0),
              columns.mutable_data_ptr<scalar_t>(),
              n);

          // Unpack columns back into output:
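          // col2vol is the adjoint of vol2col: the overlapping kernel
          // contributions held in `columns` are summed back into the 3-D
          // output volume.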
          at::native::col2vol<scalar_t>(
              columns.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_depth,
              output_height,
              output_width,
              input_depth,
              input_height,
              input_width,
              kernel_depth,
              kernel_height,
              kernel_width,
              padding_depth,
              padding_height,
              padding_width,
              stride_depth,
              stride_height,
              stride_width,
              dilation_depth,
              dilation_height,
              dilation_width,
              output_n.data_ptr<scalar_t>());

          // Do Bias after:
          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          const int64_t m_ = n_output_plane;
          const int64_t n_ = output_depth * output_height * output_width;
          const int64_t k_ = 1;

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
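          // With k_ == 1 this is a rank-1 update: bias (m_ x 1) times
          // ones (1 x n_), so bias[c] is added to every spatial location of
          // output channel c.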
          if (bias.defined()) {
            cpublas::gemm(
                TransposeType::Transpose,
                TransposeType::NoTranspose,
                n_,
                m_,
                k_,
                static_cast<scalar_t>(1),
                ones.const_data_ptr<scalar_t>(),
                k_,
                bias.const_data_ptr<scalar_t>(),
                k_,
                static_cast<scalar_t>(1),
                output_n.mutable_data_ptr<scalar_t>(),
                n_);
          }
        }

        // Resize output
        if (is_batch) {
          output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
        }
      });
}

void slow_conv_transpose3d_backward_out_cpu_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected kernel_size equals to 3, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected dilation equals to 3, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected padding equals to 3, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected stride equals to 3, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected output_padding equals to 3, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  // number of input & output planes and kernel size are indirectly defined by
  // the weight tensor
  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor weight = weight_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  const int64_t n_input_plane = weight.size(0);
  const int64_t n_output_plane = weight.size(1);

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
                         grad_output.size(0),
                         grad_output.size(1),
                         grad_output.size(2),
                         grad_output.size(3)});
  }

  const int64_t input_width = input.size(4);
  const int64_t input_height = input.size(3);
  const int64_t input_depth = input.size(2);

  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  // Batch size
  const int64_t batch_size = input.size(0);

  // Resize and zero grad_input
  grad_input.resize_(
      {batch_size, n_input_plane, input_depth, input_height, input_width});
  grad_input.zero_();

  // Create temporary columns
  bool need_columns = (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
      stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
      dilation_depth != 1 || dilation_height != 1 ||
      dilation_width != 1 || padding_depth != 0 ||
      padding_height != 0 || padding_width != 0);
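  // When the kernel is 1x1x1 with unit stride, unit dilation and no padding,
  // vol2col is a plain copy, so grad_output can be fed to the GEMM directly
  // and the temporary buffer is skipped.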
  Tensor grad_columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(), "slow_conv_transpose3d_backward_out_cpu", [&] {
        // Helpers
        Tensor grad_input_n;
        Tensor grad_output_n;

        int64_t elt;
        // For each elt in batch, do:
        for (elt = 0; elt < batch_size; ++elt) {
          // Matrix multiply per sample:
          grad_input_n = grad_input.select(0, elt);
          grad_output_n = grad_output.select(0, elt);

          if (need_columns) {
            // Extract columns:
            at::native::vol2col<scalar_t>(
                grad_output_n.const_data_ptr<scalar_t>(),
                n_output_plane,
                output_depth,
                output_height,
                output_width,
                input_depth,
                input_height,
                input_width,
                kernel_depth,
                kernel_height,
                kernel_width,
                padding_depth,
                padding_height,
                padding_width,
                stride_depth,
                stride_height,
                stride_width,
                dilation_depth,
                dilation_height,
                dilation_width,
                grad_columns.mutable_data_ptr<scalar_t>());
          }

          // M,N,K are dims of matrix A and B
          // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
          const int64_t m = weight.size(0);
          const int64_t n = input_depth * input_height * input_width;
          const int64_t k =
              weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);
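          // Conceptually (in row-major terms): grad_input_n (m x n) =
          // weight reshaped to (m, k) times the (k x n) column buffer, i.e.
          // the gradient w.r.t. the input of a transposed convolution is an
          // ordinary convolution-style GEMM.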

          // Do GEMM (note: this is a bit confusing because gemm assumes
          // column-major matrices)
          auto gemm_in_ptr = need_columns ? grad_columns.const_data_ptr<scalar_t>()
              : grad_output_n.const_data_ptr<scalar_t>();
          cpublas::gemm(
              TransposeType::NoTranspose,
              TransposeType::NoTranspose,
              n,
              m,
              k,
              static_cast<scalar_t>(1),
              gemm_in_ptr,
              n,
              weight.const_data_ptr<scalar_t>(),
              k,
              static_cast<scalar_t>(0),
              grad_input_n.mutable_data_ptr<scalar_t>(),
              n);
        }

        // Resize back to the unbatched shapes
        if (is_batch) {
          grad_output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
          grad_input.resize_(
              {n_input_plane, input_depth, input_height, input_width});
        }
      });
}

void slow_conv_transpose3d_acc_grad_parameters_cpu(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected kernel_size equals to 3, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected dilation equals to 3, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected padding equals to 3, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected stride equals to 3, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected output_padding equals to 3, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  // number of input & output planes and kernel size are indirectly defined by
  // the grad_weight tensor
  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      1);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t n_output_plane;
  if (grad_weight.defined()) {
    n_output_plane = grad_weight.size(1);
  } else if (grad_bias.defined()) {
    n_output_plane = grad_bias.size(0);
  } else {
    return;
  }

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  if (grad_weight.defined()) {
    TORCH_CHECK(grad_weight.is_contiguous(), "grad_weight needs to be contiguous");
  }
  if (grad_bias.defined()) {
    TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous");
  }

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
                         grad_output.size(0),
                         grad_output.size(1),
                         grad_output.size(2),
                         grad_output.size(3)});
  }

  const int64_t input_width = input.size(4);
  const int64_t input_height = input.size(3);
  const int64_t input_depth = input.size(2);

  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  // Batch size
  const int64_t batch_size = input.size(0);

  // Create temporary columns
  bool need_columns = (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
      stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
      dilation_depth != 1 || dilation_height != 1 ||
      dilation_width != 1 || padding_depth != 0 ||
      padding_height != 0 || padding_width != 0);
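  // As in the backward pass above, the column buffer is skipped when vol2col
  // would be an identity copy.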
  Tensor columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(),
      "slow_conv_transpose3d_acc_grad_parameters_cpu",
      [&] {
        // Helpers
        Tensor input_n;
        Tensor grad_output_n;

        scalar_t scale = static_cast<scalar_t>(scale_);

        int64_t elt;
        // For each elt in batch, do:
        for (elt = 0; elt < batch_size; ++elt) {
          // Matrix multiply per output:
          grad_output_n = grad_output.select(0, elt);

          // Do Weight:
          if (grad_weight.defined()) {
            // Matrix multiply per output:
            input_n = input.select(0, elt);

            if (need_columns) {
              // Extract columns:
              at::native::vol2col<scalar_t>(
                  grad_output_n.const_data_ptr<scalar_t>(),
                  n_output_plane,
                  output_depth,
                  output_height,
                  output_width,
                  input_depth,
                  input_height,
                  input_width,
                  kernel_depth,
                  kernel_height,
                  kernel_width,
                  padding_depth,
                  padding_height,
                  padding_width,
                  stride_depth,
                  stride_height,
                  stride_width,
                  dilation_depth,
                  dilation_height,
                  dilation_width,
                  columns.mutable_data_ptr<scalar_t>());
            }

            // M,N,K are dims of matrix A and B
            // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
            const int64_t n = n_output_plane * kernel_width * kernel_height * kernel_depth;
            const int64_t m = input_n.size(0); // n_input_plane
            const int64_t k = input_depth * input_height * input_width;
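            // Conceptually (in row-major terms) this accumulates
            //   grad_weight (m x n) += scale * input_n reshaped to (m, k)
            //   times the transpose of the (n, k) column buffer,
            // i.e. an outer-product accumulation over all spatial positions.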

            // Do GEMM (note: this is a bit confusing because gemm assumes
            // column-major matrices)
            auto gemm_in_ptr = need_columns ? columns.const_data_ptr<scalar_t>()
                : grad_output_n.const_data_ptr<scalar_t>();
            cpublas::gemm(
                TransposeType::Transpose,
                TransposeType::NoTranspose,
                n,
                m,
                k,
                static_cast<scalar_t>(scale),
                gemm_in_ptr,
                k,
                input_n.const_data_ptr<scalar_t>(),
                k,
                static_cast<scalar_t>(1),
                grad_weight.mutable_data_ptr<scalar_t>(),
                n);
          }
        }

        if (grad_bias.defined()) {
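          // grad_bias[c] accumulates grad_output over the batch dimension and
          // all spatial positions of channel c.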
          at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3, 4});
        }

        // Resize
        if (is_batch) {
          grad_output.resize_(
              {n_output_plane, output_depth, output_height, output_width});
          input.resize_(
              {input.size(1), input_depth, input_height, input_width});
        }
      });
}

} // namespace

Tensor& slow_conv_transpose3d_out_cpu(const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  slow_conv_transpose3d_out_cpu_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}

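// Expected shapes (batch mode), for reference:
//   input:  (N, C_in, D_in, H_in, W_in)
//   weight: (C_in, C_out, kD, kH, kW)
//   bias:   (C_out), optional
//   output: (N, C_out, D_out, H_out, W_out), each spatial dim computed as
//           (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding
// A sketch of a call through the dispatcher (tensor names and values purely
// illustrative):
//   auto y = at::slow_conv_transpose3d(x, w, /*kernel_size=*/{3, 3, 3}, b,
//       /*stride=*/{2, 2, 2}, /*padding=*/{1, 1, 1}, /*output_padding=*/{1, 1, 1},
//       /*dilation=*/{1, 1, 1});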
Tensor slow_conv_transpose3d_cpu(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  slow_conv_transpose3d_out_cpu_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}

static std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose3d_backward_out_cpu_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose3d_acc_grad_parameters_cpu(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cpu);

} // namespace at::native