#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <ATen/div_rtn.h>
#include <ATen/native/ConvolutionMM3d.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/TransposeType.h>
#include <ATen/native/Unfold3d.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/slow_conv3d_forward.h>
#include <ATen/ops/slow_conv3d_forward_native.h>
#include <ATen/ops/slow_conv3d_native.h>
#include <ATen/ops/sum.h>
#endif

constexpr int64_t CONV3D_GRAIN_SALT = 20;

namespace at::native {

namespace {

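// compute_columns3d materializes the vol2col (3-D im2col) buffer that the
// GEMM-based convolution below consumes. For each batch element the buffer
// has shape {n_input_plane * kT * kH * kW, oT * oH * oW}, where each output
// extent is (input + 2 * pad - kernel) / stride + 1. For a 1x1x1 kernel with
// unit stride, no padding and a single group, the buffer is simply a view on
// the contiguous input, so no copy is made.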
static Tensor compute_columns3d(
    const Tensor& input_,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef kernel_size,
    const int64_t groups) {
  const Tensor input = input_.contiguous();
  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];
  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;
  const int64_t n_input_plane = input.size(dim_planes);
  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);
  const int64_t output_depth =
      (input_depth + 2 * pad_depth - kernel_depth) / stride_depth + 1;
  const int64_t output_height =
      (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width =
      (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
  const int64_t batch_size = input.size(0);

  Tensor columns;
  if ((kernel_depth == 1) && (kernel_height == 1) && (kernel_width == 1) &&
      (pad_depth == 0) && (pad_height == 0) && (pad_width == 0) &&
      (stride_depth == 1) && (stride_height == 1) && (stride_width == 1) && (groups == 1)) {
    // Columns are just a view on the input for this special case.
    columns = input.view({batch_size, n_input_plane, output_height * output_width * output_depth}).detach();
  } else {
    columns = at::empty({batch_size,
                        n_input_plane * kernel_depth * kernel_height * kernel_width,
                        output_depth * output_height * output_width},
                        input.options());

    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "compute_columns3d", [&] {
      auto input_a = input.accessor<const scalar_t, 5>();
      auto columns_a = columns.accessor<scalar_t, 3>();

      at::parallel_for(0, batch_size, CONV3D_GRAIN_SALT, [&](int64_t start, int64_t end) {
        for (const auto t : c10::irange(start, end)) {
          auto input_t = input_a[t];
          auto columns_t = columns_a[t];
          Unfold3dCopyCPU(
            c10::CppTypeToScalarType<scalar_t>::value,
            input_t.data(),
            n_input_plane,
            input_depth,
            input_height,
            input_width,
            output_depth,
            output_height,
            output_width,
            kernel_depth,
            kernel_height,
            kernel_width,
            stride_depth,
            stride_height,
            stride_width,
            pad_depth,
            pad_height,
            pad_width,
            columns_t.data());
        }
      });
    });
  }

  return columns;
}

static inline void slow_conv3d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int64_t kernel_depth,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_depth,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_depth,
    int64_t pad_height,
    int64_t pad_width,
    int64_t groups,
    bool weight_optional) {
  TORCH_CHECK(
      kernel_width > 0 && kernel_height > 0 && kernel_depth > 0,
      "kernel size should be greater than zero, but got: ",
      kernel_depth,
      " x ",
      kernel_height,
      " x ",
      kernel_width,
      " (TxHxW)");
  TORCH_CHECK(
      stride_width > 0 && stride_height > 0 && stride_depth > 0,
      "stride should be greater than zero, but got: ",
      stride_depth,
      " x ",
      stride_height,
      " x ",
      stride_width,
      " (TxHxW)");
  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() > 0 && (weight.dim() == 2 || weight.dim() == 5),
        "non-empty 2D or 5D weight tensor expected, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(0));
    }
  } else {
    TORCH_CHECK(weight_optional, "weight tensor is undefined");
  }

  const int64_t ndim = input.dim();
  const int64_t dim_batch = 0;
  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;

  // Allow for empty batch size but not other dimensions
  bool valid_empty = ndim == 5 && input.size(dim_batch) == 0 &&
      input.size(dim_planes) != 0 && input.size(dim_depth) != 0 &&
      input.size(dim_height) != 0 && input.size(dim_width) != 0;

  TORCH_CHECK(
      (input.numel() > 0 || valid_empty) && ndim == 5,
      "non-empty 5D input tensor expected but got: ",
      input.sizes());

  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);

  const int64_t exact_input_depth = input_depth + 2 * pad_depth;
  const int64_t exact_input_height = input_height + 2 * pad_height;
  const int64_t exact_input_width = input_width + 2 * pad_width;

  TORCH_CHECK(
      exact_input_depth >= kernel_depth &&
          exact_input_height >= kernel_height &&
          exact_input_width >= kernel_width,
      "Calculated padded input size per channel: (",
      exact_input_depth,
      " x ",
      exact_input_height,
      " x ",
      exact_input_width,
      "). ",
      "Kernel size: (",
      kernel_depth,
      " x ",
      kernel_height,
      " x ",
      kernel_width,
      "). Kernel size can't be greater than actual input size");

  const int64_t output_depth =
      div_rtn<int64_t>(exact_input_depth - kernel_depth, stride_depth) + 1;
  const int64_t output_height =
      div_rtn<int64_t>(exact_input_height - kernel_height, stride_height) + 1;
  const int64_t output_width =
      div_rtn<int64_t>(exact_input_width - kernel_width, stride_width) + 1;

  TORCH_CHECK(
      output_depth >= 1 && output_width >= 1 && output_height >= 1,
      "Given input size per channel: (",
      input_depth,
      " x ",
      input_height,
      " x ",
      input_width,
      "). "
      "Calculated output size per channel: (",
      output_depth,
      " x ",
      output_height,
      " x ",
      output_width,
      "). Output size is too small");

  if (weight.defined()) {
    int64_t n_input_plane = weight.size(1);
    if (weight.dim() == 2) {
      n_input_plane /= (kernel_height * kernel_width);
    }
    // To support grouped conv we need to check that input.size(dim_planes)
    // is a multiple of weight.size(dim_planes)
    TORCH_CHECK(groups > 0, "non-zero group size expected");
    check_dim_size(input, ndim, dim_planes, n_input_plane * groups);
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      int64_t n_output_plane = weight.size(0);
      check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
    } else if (bias.defined()) {
      TORCH_CHECK(bias.numel() > 0, "non-empty bias tensor expected");
      const int64_t n_output_plane = bias.dim() == 0 ? 1 : bias.size(0);
      check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dim_depth, output_depth);
    check_dim_size(grad_output, ndim, dim_height, output_height);
    check_dim_size(grad_output, ndim, dim_width, output_width);
  }
}

static Tensor view_weight_2d(const Tensor& weight_) {
  Tensor weight = weight_.contiguous();
  if (weight.dim() == 5) {
    const int64_t s1 = weight.size(0);
    const int64_t s2 =
        weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);
    return weight.view({s1, s2});
  } else {
    return weight;
  }
}

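// Computes one batch element of the forward pass as a grouped GEMM: for each
// group g, output[g] = weight[g] * finput[g], where weight[g] is
// (n_output_plane / groups) x k, finput[g] is k x (oT * oH * oW), and
// k = (n_input_plane / groups) * kT * kH * kW. When a bias is present the
// output frame has already been pre-filled with it by the caller, so
// beta == 1 lets the GEMM accumulate onto the bias instead of overwriting it.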
template <typename scalar_t>
static void slow_conv3d_update_output_frame(
    TensorAccessor<const scalar_t, 4> input,
    TensorAccessor<scalar_t, 4> output,
    TensorAccessor<const scalar_t, 2> weight,
    bool has_bias,
    TensorAccessor<const scalar_t, 2> finput,
    int64_t kernel_depth,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_depth,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_depth,
    int64_t pad_height,
    int64_t pad_width,
    int64_t n_input_plane,
    int64_t groups,
    int64_t input_depth,
    int64_t input_height,
    int64_t input_width,
    int64_t n_output_plane,
    int64_t output_depth,
    int64_t output_height,
    int64_t output_width) {
  const int beta = has_bias ? 1 : 0;

  // Compute out = weight * input
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  const int64_t m = output_depth * output_height * output_width;
  const int64_t n = (n_output_plane / groups);
  const int64_t k = (n_input_plane / groups) * kernel_depth * kernel_height * kernel_width;

  const int64_t lda = m;
  const int64_t ldb = k;
  const int64_t ldc = m;

  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::NoTranspose,
      TransposeType::NoTranspose,
      groups, m, n, k,
      static_cast<scalar_t>(1),
      finput.data(), lda, finput.stride(0) * k,
      weight.data(), ldb, weight.stride(0) * n,
      static_cast<scalar_t>(beta),
      output.data(), ldc, output.stride(0) * n);
}

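// Computes the input gradient for one batch element: a grouped GEMM forms
// fgrad_input = weight[g]^T * grad_output[g] in the unfolded column layout,
// and Unfold3dAccCPU then folds the columns back into grad_input,
// accumulating contributions from overlapping receptive fields.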
template <typename scalar_t>
void slow_conv3d_backward_update_grad_input_frame(
    TensorAccessor<scalar_t, 4> grad_input,
    TensorAccessor<const scalar_t, 4> grad_output,
    TensorAccessor<const scalar_t, 2> weight,
    TensorAccessor<scalar_t, 2> fgrad_input,
    int64_t kernel_depth,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_depth,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_depth,
    int64_t pad_height,
    int64_t pad_width,
    int64_t groups) {
  // Compute fgrad_input = weight.T * grad_output.reshape({grad_output.shape(0), -1})
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  const int64_t m = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);
  const int64_t n = weight.size(1);
  const int64_t k = weight.size(0) / groups;

  const int64_t lda = m;
  const int64_t ldb = n;
  const int64_t ldc = m;

  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::NoTranspose,
      TransposeType::Transpose,
      groups, m, n, k,
      static_cast<scalar_t>(1),
      grad_output.data(), lda, grad_output.stride(0) * k,
      weight.data(), ldb, weight.stride(0) * k,
      static_cast<scalar_t>(0),
      fgrad_input.data(), ldc, fgrad_input.stride(0) * n);

  Unfold3dAccCPU(
      c10::CppTypeToScalarType<scalar_t>::value,
      fgrad_input.data(),
      grad_input.size(0),
      grad_input.size(1),
      grad_input.size(2),
      grad_input.size(3),
      grad_output.size(1),
      grad_output.size(2),
      grad_output.size(3),
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      grad_input.data());
}

void slow_conv3d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    int64_t groups) {
  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];

  slow_conv3d_shape_check(
      input,
      grad_output,
      weight,
      Tensor(),
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      groups,
      false);

  const Tensor weight2d = view_weight_2d(weight);
  const Tensor grad_output_contiguous = grad_output.contiguous();
  grad_input.resize_as_(input);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;
  const int64_t n_input_plane = input.size(dim_planes);
  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);
  const int64_t output_depth =
      (input_depth + 2 * pad_depth - kernel_depth) / stride_depth + 1;
  const int64_t output_height =
      (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width =
      (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
  const int64_t batch_size = input.size(0);

  Tensor fgrad_input = at::empty({batch_size,
      n_input_plane * kernel_depth * kernel_height * kernel_width,
      output_depth * output_height * output_width}, input.options());

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu_grad_input", [&] {
    auto grad_input_a = grad_input.accessor<scalar_t, 5>();
    auto grad_output_a = grad_output_contiguous.accessor<const scalar_t, 5>();
    auto fgrad_input_a = fgrad_input.accessor<scalar_t, 3>();
    auto weight_2d_a = weight2d.accessor<const scalar_t, 2>();
    at::parallel_for(0, batch_size, CONV3D_GRAIN_SALT,
                    [&](int64_t start, int64_t end) {
        for (const auto t : c10::irange(start, end)) {
          auto grad_input_t = grad_input_a[t];
          auto grad_output_t = grad_output_a[t];
          auto fgrad_input_t = fgrad_input_a[t];
          slow_conv3d_backward_update_grad_input_frame(
              grad_input_t,
              grad_output_t,
              weight_2d_a,
              fgrad_input_t,
              kernel_depth,
              kernel_height,
              kernel_width,
              stride_depth,
              stride_height,
              stride_width,
              pad_depth,
              pad_height,
              pad_width,
              groups);
        }
    });
  });
}

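// Accumulates the weight gradient for one batch element via a grouped GEMM:
// grad_weight[g] += grad_output[g] * finput[g]^T. Since beta == 1, every
// frame in the batch adds its contribution to the same grad_weight buffer.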
template <typename scalar_t>
void slow_conv3d_backward_weight_frame(
    TensorAccessor<scalar_t, 2> grad_weight,
    TensorAccessor<const scalar_t, 4> grad_output,
    TensorAccessor<const scalar_t, 2> finput,
    int64_t groups) {
  // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  const int64_t m = grad_weight.size(1);
  const int64_t n = grad_weight.size(0) / groups;
  const int64_t k = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);

  const int64_t lda = k;
  const int64_t ldb = k;
  const int64_t ldc = m;

  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::Transpose,
      TransposeType::NoTranspose,
      groups, m, n, k,
      static_cast<scalar_t>(1),
      finput.data(), lda, finput.stride(0) * m,
      grad_output.data(), ldb, grad_output.stride(0) * n,
      static_cast<scalar_t>(1),
      grad_weight.data(), ldc, grad_weight.stride(0) * n);
}

static void slow_conv3d_backward_parameters_out_cpu_template(
    Tensor& grad_weight,
    const Tensor& input,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    int64_t groups) {
  CheckedFrom c = "slow_conv3d_backward_parameters_cpu";
  auto grad_weight_arg = TensorArg(grad_weight, "grad_weight_arg", 0);

  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];

  slow_conv3d_shape_check(
      input,
      grad_output,
      grad_weight,
      {},
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      groups,
      true);

  Tensor grad_weight_2d = view_weight_2d(grad_weight);
  checkContiguous(c, grad_weight_arg);

  auto grad_output_contiguous = grad_output.contiguous();

  const int64_t batch_size = input.size(0);
  Tensor finput = compute_columns3d(input, stride, padding, kernel_size, groups);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu_grad_weight", [&] {
    auto grad_weight_2d_a = grad_weight_2d.accessor<scalar_t, 2>();
    auto grad_output_a = grad_output_contiguous.accessor<const scalar_t, 5>();
    auto finput_a = finput.accessor<const scalar_t, 3>();
    for (const auto t : c10::irange(batch_size)) {
      auto grad_output_t = grad_output_a[t];
      auto finput_t = finput_a[t];
      slow_conv3d_backward_weight_frame(
          grad_weight_2d_a, grad_output_t, finput_t, groups);
    }
  });
}

} // namespace

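// Forward entry point of the "slow" (GEMM-based) 3-D convolution on CPU. The
// input is unfolded into columns once per batch element, the bias (if any) is
// broadcast into the output, and each frame is then computed with a single
// grouped GEMM. As a purely illustrative shape example: an input of
// {2, 4, 8, 8, 8} with weight {8, 4, 3, 3, 3}, stride {1, 1, 1} and padding
// {1, 1, 1} produces an output of {2, 8, 8, 8, 8}, since each spatial extent
// is (8 + 2 * 1 - 3) / 1 + 1 = 8.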
Tensor& slow_conv3d_forward_out_cpu(const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];

  // TODO: hacky way of deciding the number of groups.
  // Assumes the group count has already been checked in upstream functions;
  // weight is {out_channels, in_channels / groups, kT, kH, kW}, so
  // groups == self.size(1) / weight.size(1).
  const int64_t groups = weight.size(1) > 0 ? self.size(1) / weight.size(1) : 0;

  slow_conv3d_shape_check(
      self,
      Tensor(),
      weight,
      bias,
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      groups,
      false);

  const Tensor input = self.contiguous();
  const Tensor weight_2d = view_weight_2d(weight);

  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;

  const int64_t n_input_plane = input.size(dim_planes);
  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);
  const int64_t n_output_plane = weight_2d.size(0);
  const int64_t output_depth =
      (input_depth + 2 * pad_depth - kernel_depth) / stride_depth + 1;
  const int64_t output_height =
      (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width =
      (input_width + 2 * pad_width - kernel_width) / stride_width + 1;

  Tensor finput = compute_columns3d(input, stride, padding, kernel_size, groups);
  const int64_t batch_size = input.size(0);
  output.resize_(
      {batch_size, n_output_plane, output_depth, output_height, output_width});
  if (bias.defined()) {
    output.copy_(bias.reshape({-1, 1, 1, 1}));
  }

  TORCH_CHECK(output.is_contiguous(), "slow_conv3d output must be contiguous");

  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu", [&] {
    auto input_a = input.accessor<const scalar_t, 5>();
    auto output_a = output.accessor<scalar_t, 5>();
    auto finput_a = finput.accessor<const scalar_t, 3>();
    auto weight_2d_a = weight_2d.accessor<const scalar_t, 2>();

    at::parallel_for(
        0, batch_size, CONV3D_GRAIN_SALT, [&](int64_t start, int64_t end) {
          for (const auto t : c10::irange(start, end)) {
            auto input_t = input_a[t];
            auto output_t = output_a[t];
            auto finput_t = finput_a[t];
            slow_conv3d_update_output_frame(
                input_t,
                output_t,
                weight_2d_a,
                bias.defined(),
                finput_t,
                kernel_depth,
                kernel_height,
                kernel_width,
                stride_depth,
                stride_height,
                stride_width,
                pad_depth,
                pad_height,
                pad_width,
                n_input_plane,
                groups,
                input_depth,
                input_height,
                input_width,
                n_output_plane,
                output_depth,
                output_height,
                output_width);
          }
        });
  });

  return output;
}

Tensor slow_conv3d_forward_cpu(
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto output = at::empty({0}, self.options());
  at::native::slow_conv3d_forward_out_cpu(
      self,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output);
  return output;
}

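// Dispatches the three backward computations requested by the caller:
// grad_input and grad_weight reuse the GEMM-based helpers above, while
// grad_bias is just the sum of grad_output over the batch and spatial
// dimensions {0, 2, 3, 4}.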
static std::tuple<Tensor&, Tensor&, Tensor&> slow_conv3d_backward_out_cpu(const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  // TODO: hacky way of determining the group size
  int64_t groups = self.size(1) / weight.size(1);
  if (grad_input.defined()) {
    slow_conv3d_backward_out_cpu_template(
        grad_input,
        grad_output,
        self,
        weight,
        kernel_size,
        stride,
        padding,
        groups);
  }

  if (grad_bias.defined()) {
    at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3, 4});
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
    slow_conv3d_backward_parameters_out_cpu_template(
        grad_weight,
        self,
        grad_output,
        kernel_size,
        stride,
        padding,
        groups);
  }

  return std::tuple<Tensor&, Tensor&, Tensor&>(
      grad_input, grad_weight, grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  }

  at::native::slow_conv3d_backward_out_cpu(
      grad_output,
      self,
      weight,
      kernel_size,
      stride,
      padding,
      grad_input,
      grad_weight,
      grad_bias);

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

Tensor& slow_conv3d_out(const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return at::slow_conv3d_forward_out(
      output,
      self,
      weight,
      kernel_size,
      bias,
      stride,
      padding);
}

Tensor slow_conv3d(
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return at::slow_conv3d_forward(self, weight, kernel_size, bias, stride, padding);
}

} // namespace at::native