#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <ATen/div_rtn.h>
#include <ATen/native/ConvolutionMM3d.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/TransposeType.h>
#include <ATen/native/Unfold3d.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/slow_conv3d_forward.h>
#include <ATen/ops/slow_conv3d_forward_native.h>
#include <ATen/ops/slow_conv3d_native.h>
#include <ATen/ops/sum.h>
#endif

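// Grain size used for at::parallel_for over the batch dimension in the kernels below.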
constexpr int64_t CONV3D_GRAIN_SALT = 20;

namespace at::native {

namespace {

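// Unfolds the 5-D input (im2col for volumes) into a "columns" tensor of shape
// {batch_size, n_input_plane * kD * kH * kW, output_depth * output_height * output_width}
// so the convolution can be evaluated as a matrix multiplication per batch element.
// For a 1x1x1 kernel with zero padding, unit stride and a single group, the columns
// tensor is simply a view of the contiguous input.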
static Tensor compute_columns3d(
    const Tensor& input_,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef kernel_size,
    const int64_t groups) {
  const Tensor input = input_.contiguous();
  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];
  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;
  const int64_t n_input_plane = input.size(dim_planes);
  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);
  const int64_t output_depth =
      (input_depth + 2 * pad_depth - kernel_depth) / stride_depth + 1;
  const int64_t output_height =
      (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width =
      (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
  const int64_t batch_size = input.size(0);

  Tensor columns;
  if ((kernel_depth == 1) && (kernel_height == 1) && (kernel_width == 1) &&
      (pad_depth == 0) && (pad_height == 0) && (pad_width == 0) &&
      (stride_depth == 1) && (stride_height == 1) && (stride_width == 1) && (groups == 1)) {
    // Columns are just a view on the input for this special case.
    columns = input.view({batch_size, n_input_plane, output_height * output_width * output_depth}).detach();
  } else {
    columns = at::empty({batch_size,
                         n_input_plane * kernel_depth * kernel_height * kernel_width,
                         output_depth * output_height * output_width},
                        input.options());

    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "compute_columns3d", [&] {
      auto input_a = input.accessor<const scalar_t, 5>();
      auto columns_a = columns.accessor<scalar_t, 3>();

      at::parallel_for(0, batch_size, CONV3D_GRAIN_SALT, [&](int64_t start, int64_t end) {
        for (const auto t : c10::irange(start, end)) {
          auto input_t = input_a[t];
          auto columns_t = columns_a[t];
          Unfold3dCopyCPU(
              c10::CppTypeToScalarType<scalar_t>::value,
              input_t.data(),
              n_input_plane,
              input_depth,
              input_height,
              input_width,
              output_depth,
              output_height,
              output_width,
              kernel_depth,
              kernel_height,
              kernel_width,
              stride_depth,
              stride_height,
              stride_width,
              pad_depth,
              pad_height,
              pad_width,
              columns_t.data());
        }
      });
    });
  }

  return columns;
}

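// Validates the arguments of the slow 3-D convolution kernels: positive kernel and
// stride sizes, a non-empty 2-D or 5-D weight (unless weight_optional), a 5-D input
// whose padded extent is at least the kernel size, a positive output size, channel
// counts consistent with the number of groups, and, when grad_output is defined,
// that its shape matches the expected output shape.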
static inline void slow_conv3d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int64_t kernel_depth,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_depth,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_depth,
    int64_t pad_height,
    int64_t pad_width,
    int64_t groups,
    bool weight_optional) {
  TORCH_CHECK(
      kernel_width > 0 && kernel_height > 0 && kernel_depth > 0,
      "kernel size should be greater than zero, but got: ",
      kernel_depth,
      " x ",
      kernel_height,
      " x ",
      kernel_width,
      " (TxHxW)");
  TORCH_CHECK(
      stride_width > 0 && stride_height > 0 && stride_depth > 0,
      "stride should be greater than zero, but got: ",
      stride_depth,
      " x ",
      stride_height,
      " x ",
      stride_width,
      " (TxHxW)");
  if (weight.defined()) {
    TORCH_CHECK(
        weight.numel() > 0 && (weight.dim() == 2 || weight.dim() == 5),
        "non-empty 2D or 5D weight tensor expected, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(0));
    }
  } else {
    TORCH_CHECK(weight_optional, "weight tensor is undefined");
  }

  const int64_t ndim = input.dim();
  const int64_t dim_batch = 0;
  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;

  // Allow for empty batch size but not other dimensions
  bool valid_empty = ndim == 5 && input.size(dim_batch) == 0 &&
      input.size(dim_planes) != 0 && input.size(dim_depth) != 0 &&
      input.size(dim_height) != 0 && input.size(dim_width) != 0;

  TORCH_CHECK(
      (input.numel() > 0 || valid_empty) && ndim == 5,
      "non-empty 5D input tensor expected but got: ",
      input.sizes());

  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);

  const int64_t exact_input_depth = input_depth + 2 * pad_depth;
  const int64_t exact_input_height = input_height + 2 * pad_height;
  const int64_t exact_input_width = input_width + 2 * pad_width;

  TORCH_CHECK(
      exact_input_depth >= kernel_depth &&
          exact_input_height >= kernel_height &&
          exact_input_width >= kernel_width,
      "Calculated padded input size per channel: (",
      exact_input_depth,
      " x ",
      exact_input_height,
      " x ",
      exact_input_width,
      "). ",
      "Kernel size: (",
      kernel_depth,
      " x ",
      kernel_height,
      " x ",
      kernel_width,
      "). Kernel size can't be greater than actual input size");

  const int64_t output_depth =
      div_rtn<int64_t>(exact_input_depth - kernel_depth, stride_depth) + 1;
  const int64_t output_height =
      div_rtn<int64_t>(exact_input_height - kernel_height, stride_height) + 1;
  const int64_t output_width =
      div_rtn<int64_t>(exact_input_width - kernel_width, stride_width) + 1;

  TORCH_CHECK(
      output_depth >= 1 && output_width >= 1 && output_height >= 1,
      "Given input size per channel: (",
      input_depth,
      " x ",
      input_height,
      " x ",
      input_width,
      "). "
      "Calculated output size per channel: (",
      output_depth,
      " x ",
      output_height,
      " x ",
      output_width,
      "). Output size is too small");

  if (weight.defined()) {
    int64_t n_input_plane = weight.size(1);
    if (weight.dim() == 2) {
      n_input_plane /= (kernel_height * kernel_width);
    }
    // To support grouped conv we need to check that input.size(dim_planes)
    // is a multiple of weight.size(dim_planes)
    TORCH_CHECK(groups > 0, "non-zero group size expected");
    check_dim_size(input, ndim, dim_planes, n_input_plane * groups);
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      int64_t n_output_plane = weight.size(0);
      check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
    } else if (bias.defined()) {
      TORCH_CHECK(bias.numel() > 0, "non-empty bias tensor expected");
      const int64_t n_output_plane = bias.dim() == 0 ? 1 : bias.size(0);
      check_dim_size(grad_output, ndim, dim_planes, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dim_depth, output_depth);
    check_dim_size(grad_output, ndim, dim_height, output_height);
    check_dim_size(grad_output, ndim, dim_width, output_width);
  }
}

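// Flattens a 5-D weight of shape (out_planes, in_planes / groups, kD, kH, kW) into the
// 2-D shape (out_planes, in_planes / groups * kD * kH * kW) expected by the GEMM calls
// below; 2-D weights are returned unchanged (made contiguous).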
static Tensor view_weight_2d(const Tensor& weight_) {
  Tensor weight = weight_.contiguous();
  if (weight.dim() == 5) {
    const int64_t s1 = weight.size(0);
    const int64_t s2 =
        weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);
    return weight.view({s1, s2});
  } else {
    return weight;
  }
}

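// Computes the output for a single batch element as one grouped GEMM over the
// unfolded input: per group, output (n x m) = weight (n x k) * finput (k x m), where
// m = output_depth * output_height * output_width, n = n_output_plane / groups and
// k = (n_input_plane / groups) * kD * kH * kW. When a bias is present the output has
// already been filled with it, so beta = 1 accumulates on top of the bias.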
template <typename scalar_t>
static void slow_conv3d_update_output_frame(
    TensorAccessor<const scalar_t, 4> input,
    TensorAccessor<scalar_t, 4> output,
    TensorAccessor<const scalar_t, 2> weight,
    bool has_bias,
    TensorAccessor<const scalar_t, 2> finput,
    int64_t kernel_depth,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_depth,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_depth,
    int64_t pad_height,
    int64_t pad_width,
    int64_t n_input_plane,
    int64_t groups,
    int64_t input_depth,
    int64_t input_height,
    int64_t input_width,
    int64_t n_output_plane,
    int64_t output_depth,
    int64_t output_height,
    int64_t output_width) {
  const int beta = has_bias ? 1 : 0;

  // Compute out = weight * input
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  const int64_t m = output_depth * output_height * output_width;
  const int64_t n = (n_output_plane / groups);
  const int64_t k = (n_input_plane / groups) * kernel_depth * kernel_height * kernel_width;

  const int64_t lda = m;
  const int64_t ldb = k;
  const int64_t ldc = m;

  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::NoTranspose,
      TransposeType::NoTranspose,
      groups, m, n, k,
      static_cast<scalar_t>(1),
      finput.data(), lda, finput.stride(0) * k,
      weight.data(), ldb, weight.stride(0) * n,
      static_cast<scalar_t>(beta),
      output.data(), ldc, output.stride(0) * n);
}

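// Computes the gradient w.r.t. the input for a single batch element: a grouped GEMM
// forms fgrad_input = weight^T * grad_output in column form, then Unfold3dAccCPU
// (col2vol) accumulates those columns back into grad_input.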
template <typename scalar_t>
void slow_conv3d_backward_update_grad_input_frame(
    TensorAccessor<scalar_t, 4> grad_input,
    TensorAccessor<const scalar_t, 4> grad_output,
    TensorAccessor<const scalar_t, 2> weight,
    TensorAccessor<scalar_t, 2> fgrad_input,
    int64_t kernel_depth,
    int64_t kernel_height,
    int64_t kernel_width,
    int64_t stride_depth,
    int64_t stride_height,
    int64_t stride_width,
    int64_t pad_depth,
    int64_t pad_height,
    int64_t pad_width,
    int64_t groups) {
  // Compute fgrad_input = weight.T * grad_output.reshape({grad_output.shape(0), -1})
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  const int64_t m = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);
  const int64_t n = weight.size(1);
  const int64_t k = weight.size(0) / groups;

  const int64_t lda = m;
  const int64_t ldb = n;
  const int64_t ldc = m;

  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::NoTranspose,
      TransposeType::Transpose,
      groups, m, n, k,
      static_cast<scalar_t>(1),
      grad_output.data(), lda, grad_output.stride(0) * k,
      weight.data(), ldb, weight.stride(0) * k,
      static_cast<scalar_t>(0),
      fgrad_input.data(), ldc, fgrad_input.stride(0) * n);

  Unfold3dAccCPU(
      c10::CppTypeToScalarType<scalar_t>::value,
      fgrad_input.data(),
      grad_input.size(0),
      grad_input.size(1),
      grad_input.size(2),
      grad_input.size(3),
      grad_output.size(1),
      grad_output.size(2),
      grad_output.size(3),
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      grad_input.data());
}

void slow_conv3d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    int64_t groups) {
  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];

  slow_conv3d_shape_check(
      input,
      grad_output,
      weight,
      Tensor(),
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      groups,
      false);

  const Tensor weight2d = view_weight_2d(weight);
  const Tensor grad_output_contiguous = grad_output.contiguous();
  grad_input.resize_as_(input);
  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;
  const int64_t n_input_plane = input.size(dim_planes);
  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);
  const int64_t output_depth =
      (input_depth + 2 * pad_depth - kernel_depth) / stride_depth + 1;
  const int64_t output_height =
      (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width =
      (input_width + 2 * pad_width - kernel_width) / stride_width + 1;
  const int64_t batch_size = input.size(0);

  Tensor fgrad_input = at::empty({batch_size,
      n_input_plane * kernel_depth * kernel_height * kernel_width,
      output_depth * output_height * output_width}, input.options());

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu_grad_input", [&] {
        auto grad_input_a = grad_input.accessor<scalar_t, 5>();
        auto grad_output_a = grad_output_contiguous.accessor<const scalar_t, 5>();
        auto fgrad_input_a = fgrad_input.accessor<scalar_t, 3>();
        auto weight_2d_a = weight2d.accessor<const scalar_t, 2>();
        at::parallel_for(0, batch_size, CONV3D_GRAIN_SALT,
            [&](int64_t start, int64_t end) {
          for (const auto t : c10::irange(start, end)) {
            auto grad_input_t = grad_input_a[t];
            auto grad_output_t = grad_output_a[t];
            auto fgrad_input_t = fgrad_input_a[t];
            slow_conv3d_backward_update_grad_input_frame(
                grad_input_t,
                grad_output_t,
                weight_2d_a,
                fgrad_input_t,
                kernel_depth,
                kernel_height,
                kernel_width,
                stride_depth,
                stride_height,
                stride_width,
                pad_depth,
                pad_height,
                pad_width,
                groups);
          }
        });
      });
}

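// Accumulates the gradient w.r.t. the weight for a single batch element as a grouped
// GEMM: per group, grad_weight (n x m) += grad_output (n x k) * finput^T (k x m),
// with beta = 1 so that contributions from all batch elements are summed.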
template <typename scalar_t>
void slow_conv3d_backward_weight_frame(
    TensorAccessor<scalar_t, 2> grad_weight,
    TensorAccessor<const scalar_t, 4> grad_output,
    TensorAccessor<const scalar_t, 2> finput,
    int64_t groups) {
  // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  const int64_t m = grad_weight.size(1);
  const int64_t n = grad_weight.size(0) / groups;
  const int64_t k = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);

  const int64_t lda = k;
  const int64_t ldb = k;
  const int64_t ldc = m;

  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::Transpose,
      TransposeType::NoTranspose,
      groups, m, n, k,
      static_cast<scalar_t>(1),
      finput.data(), lda, finput.stride(0) * m,
      grad_output.data(), ldb, grad_output.stride(0) * n,
      static_cast<scalar_t>(1),
      grad_weight.data(), ldc, grad_weight.stride(0) * n);
}

static void slow_conv3d_backward_parameters_out_cpu_template(
    Tensor& grad_weight,
    const Tensor& input,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    int64_t groups) {
  CheckedFrom c = "slow_conv3d_backward_parameters_cpu";
  auto grad_weight_arg = TensorArg(grad_weight, "grad_weight_arg", 0);

  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];

  slow_conv3d_shape_check(
      input,
      grad_output,
      grad_weight,
      {},
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      groups,
      true);

  Tensor grad_weight_2d = view_weight_2d(grad_weight);
  checkContiguous(c, grad_weight_arg);

  auto grad_output_contiguous = grad_output.contiguous();

  const int64_t batch_size = input.size(0);
  Tensor finput = compute_columns3d(input, stride, padding, kernel_size, groups);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu_grad_weight", [&] {
        auto grad_weight_2d_a = grad_weight_2d.accessor<scalar_t, 2>();
        auto grad_output_a = grad_output_contiguous.accessor<const scalar_t, 5>();
        auto finput_a = finput.accessor<const scalar_t, 3>();
        for (const auto t : c10::irange(batch_size)) {
          auto grad_output_t = grad_output_a[t];
          auto finput_t = finput_a[t];
          slow_conv3d_backward_weight_frame(
              grad_weight_2d_a, grad_output_t, finput_t, groups);
        }
      });
}

} // namespace

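// Forward pass: unfolds the input with compute_columns3d and runs one grouped GEMM
// per batch element (slow_conv3d_update_output_frame). Any bias is broadcast into the
// output first; the GEMM then accumulates on top of it.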
Tensor& slow_conv3d_forward_out_cpu(const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  const int64_t kernel_depth = kernel_size[0];
  const int64_t kernel_height = kernel_size[1];
  const int64_t kernel_width = kernel_size[2];
  const int64_t pad_depth = padding[0];
  const int64_t pad_height = padding[1];
  const int64_t pad_width = padding[2];
  const int64_t stride_depth = stride[0];
  const int64_t stride_height = stride[1];
  const int64_t stride_width = stride[2];

  // TODO: hacky way of deciding the groups
  // Assuming the group size is checked in upstream functions
  const int64_t groups = weight.size(1) > 0 ? self.size(1) / weight.size(1) : 0;

  slow_conv3d_shape_check(
      self,
      Tensor(),
      weight,
      bias,
      kernel_depth,
      kernel_height,
      kernel_width,
      stride_depth,
      stride_height,
      stride_width,
      pad_depth,
      pad_height,
      pad_width,
      groups,
      false);

  const Tensor input = self.contiguous();
  const Tensor weight_2d = view_weight_2d(weight);

  const int64_t dim_planes = 1;
  const int64_t dim_depth = 2;
  const int64_t dim_height = 3;
  const int64_t dim_width = 4;

  const int64_t n_input_plane = input.size(dim_planes);
  const int64_t input_depth = input.size(dim_depth);
  const int64_t input_height = input.size(dim_height);
  const int64_t input_width = input.size(dim_width);
  const int64_t n_output_plane = weight_2d.size(0);
  const int64_t output_depth =
      (input_depth + 2 * pad_depth - kernel_depth) / stride_depth + 1;
  const int64_t output_height =
      (input_height + 2 * pad_height - kernel_height) / stride_height + 1;
  const int64_t output_width =
      (input_width + 2 * pad_width - kernel_width) / stride_width + 1;

  Tensor finput = compute_columns3d(input, stride, padding, kernel_size, groups);
  const int64_t batch_size = input.size(0);
  output.resize_(
      {batch_size, n_output_plane, output_depth, output_height, output_width});
  if (bias.defined()) {
    output.copy_(bias.reshape({-1, 1, 1, 1}));
  }

  TORCH_CHECK(output.is_contiguous(), "slow_conv3d output must be contiguous");

  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu", [&] {
    auto input_a = input.accessor<const scalar_t, 5>();
    auto output_a = output.accessor<scalar_t, 5>();
    auto finput_a = finput.accessor<const scalar_t, 3>();
    auto weight_2d_a = weight_2d.accessor<const scalar_t, 2>();

    at::parallel_for(
        0, batch_size, CONV3D_GRAIN_SALT, [&](int64_t start, int64_t end) {
          for (const auto t : c10::irange(start, end)) {
            auto input_t = input_a[t];
            auto output_t = output_a[t];
            auto finput_t = finput_a[t];
            slow_conv3d_update_output_frame(
                input_t,
                output_t,
                weight_2d_a,
                bias.defined(),
                finput_t,
                kernel_depth,
                kernel_height,
                kernel_width,
                stride_depth,
                stride_height,
                stride_width,
                pad_depth,
                pad_height,
                pad_width,
                n_input_plane,
                groups,
                input_depth,
                input_height,
                input_width,
                n_output_plane,
                output_depth,
                output_height,
                output_width);
          }
        });
  });

  return output;
}

Tensor slow_conv3d_forward_cpu(
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  auto output = at::empty({0}, self.options());
  at::native::slow_conv3d_forward_out_cpu(
      self,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output);
  return output;
}

static std::tuple<Tensor&, Tensor&, Tensor&> slow_conv3d_backward_out_cpu(const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  // TODO: hacky way of determining the group size
  int64_t groups = self.size(1) / weight.size(1);
  if (grad_input.defined()) {
    slow_conv3d_backward_out_cpu_template(
        grad_input,
        grad_output,
        self,
        weight,
        kernel_size,
        stride,
        padding,
        groups);
  }

  if (grad_bias.defined()) {
    at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3, 4});
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
    slow_conv3d_backward_parameters_out_cpu_template(
        grad_weight,
        self,
        grad_output,
        kernel_size,
        stride,
        padding,
        groups);
  }

  return std::tuple<Tensor&, Tensor&, Tensor&>(
      grad_input, grad_weight, grad_bias);
}

std::tuple<Tensor, Tensor, Tensor> slow_conv3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  }

  at::native::slow_conv3d_backward_out_cpu(
      grad_output,
      self,
      weight,
      kernel_size,
      stride,
      padding,
      grad_input,
      grad_weight,
      grad_bias);

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}

Tensor& slow_conv3d_out(const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return at::slow_conv3d_forward_out(
      output,
      self,
      weight,
      kernel_size,
      bias,
      stride,
      padding);
}

Tensor slow_conv3d(
    const Tensor& self,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  return at::slow_conv3d_forward(self, weight, kernel_size, bias, stride, padding);
}

} // namespace at::native