#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/vol2col.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/slow_conv_transpose3d_native.h>
#include <ATen/ops/sum.h>
#endif

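// Naive CPU implementation of 3-D transposed convolution ("slow", vol2col /
// col2vol + GEMM based): forward, gradient w.r.t. input, and accumulation of
// the weight and bias gradients.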
namespace at::native {

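// Forward declaration of the BLAS-style gemv helper defined in another
// translation unit (y := alpha * op(A) * x + beta * y); it is not referenced
// directly in this file.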
template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda, scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);

namespace {

static inline void slow_conv_transpose3d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const Tensor& bias,
    int kernel_depth,
    int kernel_width,
    int kernel_height,
    int stride_depth,
    int stride_width,
    int stride_height,
    int padding_depth,
    int padding_width,
    int padding_height,
    int dilation_depth,
    int dilation_width,
    int dilation_height,
    int output_padding_depth,
    int output_padding_width,
    int output_padding_height,
    int weight_nullable) {
  TORCH_CHECK(
      input.numel() != 0 && (input.dim() == 4 || input.dim() == 5),
      "non-empty 4D or 5D (batch mode) tensor expected for input, but got: ",
      input.sizes());
  TORCH_CHECK(
      stride_depth > 0 && stride_width > 0 && stride_height > 0,
      "stride should be greater than zero, but got stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width);
  TORCH_CHECK(
      dilation_depth > 0 && dilation_width > 0 && dilation_height > 0,
      "dilation should be greater than zero, but got dilation_depth: ",
      dilation_depth,
      ", dilation_height: ",
      dilation_height,
      ", dilation_width: ",
      dilation_width);
  TORCH_CHECK(
      (output_padding_depth < stride_depth ||
       output_padding_depth < dilation_depth) &&
          (output_padding_width < stride_width ||
           output_padding_width < dilation_width) &&
          (output_padding_height < stride_height ||
           output_padding_height < dilation_height),
      "output padding must be smaller than either stride or dilation,",
      " but got output_padding_depth: ",
      output_padding_depth,
      " output_padding_height: ",
      output_padding_height,
      " output_padding_width: ",
      output_padding_width,
      " stride_depth: ",
      stride_depth,
      " stride_height: ",
      stride_height,
      " stride_width: ",
      stride_width,
      " dilation_depth: ",
      dilation_depth,
      " dilation_height: ",
      dilation_height,
      " dilation_width: ",
      dilation_width);

  // number of input & output planes and kernel size is indirectly defined by
  // the weight tensor
  if (weight.defined()) {
    /* TODO: TORCH_CHECK should just take 2 args: condition and message */
    TORCH_CHECK(
        weight.numel() != 0 && weight.dim() == 5,
        "non-empty 5D (n_input_plane x n_output_plane x kernel_depth",
        " x kernel_height x kernel_width) tensor ",
        "expected for weight, but got: ",
        weight.sizes());
    if (bias.defined()) {
      check_dim_size(bias, 1, 0, weight.size(1));
    }
  } else if (!weight_nullable) {
    AT_ERROR("weight tensor is expected to be non-nullable");
  }

  int ndim = input.dim();
  int dimf = 0;
  int dimd = 1;
  int dimh = 2;
  int dimw = 3;

  if (ndim == 5) {
    dimf++;
    dimd++;
    dimh++;
    dimw++;
  }

  if (weight.defined()) {
    const int64_t n_input_plane = weight.size(0);
    check_dim_size(input, ndim, dimf, n_input_plane);
  }

  const int64_t input_width = input.size(dimw);
  const int64_t input_height = input.size(dimh);
  const int64_t input_depth = input.size(dimd);
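  // Transposed-convolution output size, per spatial dimension:
  //   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1 + output_padding
  // i.e. the inverse of the regular convolution output-size formula.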
  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  if (output_depth < 1 || output_width < 1 || output_height < 1) {
    AT_ERROR(
        "Given input size per channel: (",
        input_depth,
        " x ",
        input_height,
        " x ",
        input_width,
        "). "
        "Calculated output size per channel: (",
        output_depth,
        " x ",
        output_height,
        " x ",
        output_width,
        "). Output size is too small");
  }

  if (grad_output.defined()) {
    if (weight.defined()) {
      const int64_t n_output_plane = weight.size(1);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    } else if (bias.defined()) {
      const int64_t n_output_plane = bias.size(0);
      check_dim_size(grad_output, ndim, dimf, n_output_plane);
    }
    check_dim_size(grad_output, ndim, dimd, output_depth);
    check_dim_size(grad_output, ndim, dimh, output_height);
    check_dim_size(grad_output, ndim, dimw, output_width);
  }
}

void slow_conv_transpose3d_out_cpu_template(
    Tensor& output,
    const Tensor& input_, // 4D or 5D (batch) tensor
    const Tensor& weight_, // weight tensor (n_input_plane x n_output_plane x
                           // kernel_depth x kernel_height x kernel_width)
    IntArrayRef kernel_size,
    const Tensor& bias_,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected that kernel_size has 3 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected that dilation has 3 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected that padding has 3 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected that stride has 3 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected that output_padding has 3 elements, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  slow_conv_transpose3d_shape_check(
      input_,
      Tensor(),
      weight_,
      bias_,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor weight = weight_.contiguous();
  Tensor bias = bias_.defined() ? bias_.contiguous() : bias_;

  const int n_input_plane = (int)weight.size(0);
  const int n_output_plane = (int)weight.size(1);

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
  }

  const int64_t input_width = input.size(4);
  const int64_t input_height = input.size(3);
  const int64_t input_depth = input.size(2);

  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  // Batch size + input planes
  const int64_t batch_size = input.size(0);

  // Resize output
  output.resize_(
      {batch_size, n_output_plane, output_depth, output_height, output_width});

  // Create temporary columns
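  // Shape: (n_output_plane * kernel_width * kernel_height * kernel_depth) x
  // (input_depth * input_height * input_width); column j holds the
  // contribution of input location j to every output tap.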
  Tensor columns = at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options());

  // Define a buffer of ones, for bias accumulation
  Tensor ones = bias.defined() ? at::ones({output_depth, output_height, output_width}, input_.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND3(at::ScalarType::Long, at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(), "slow_conv_transpose3d_out_cpu", [&] {
    // Helpers
    Tensor input_n;
    Tensor output_n;

    int64_t elt;
    // For each elt in batch, do:
    for (elt = 0; elt < batch_size; ++elt) {
      // Matrix multiply per output:
      input_n = input.select(0, elt);
      output_n = output.select(0, elt);

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      const int64_t m =
          weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);
      const int64_t n = columns.size(1);
      const int64_t k = weight.size(0);

      // Do GEMM (note: this is a bit confusing because gemm assumes
      // column-major matrices)
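      // In row-major terms this computes columns (m x n) = weight^T (m x k) *
      // input_n (k x n): each of the n input locations is expanded into its
      // n_output_plane * kD * kH * kW output taps, which col2vol below
      // scatters and accumulates into the output volume.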
      cpublas::gemm(
          TransposeType::NoTranspose,
          TransposeType::Transpose,
          n,
          m,
          k,
          static_cast<scalar_t>(1),
          input_n.const_data_ptr<scalar_t>(),
          n,
          weight.const_data_ptr<scalar_t>(),
          m,
          static_cast<scalar_t>(0),
          columns.mutable_data_ptr<scalar_t>(),
          n);

      // Unpack columns back into the output:
      at::native::col2vol<scalar_t>(
          columns.const_data_ptr<scalar_t>(),
          n_output_plane,
          output_depth,
          output_height,
          output_width,
          input_depth,
          input_height,
          input_width,
          kernel_depth,
          kernel_height,
          kernel_width,
          padding_depth,
          padding_height,
          padding_width,
          stride_depth,
          stride_height,
          stride_width,
          dilation_depth,
          dilation_height,
          dilation_width,
          output_n.data_ptr<scalar_t>());

      // Do Bias after:
      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      const int64_t m_ = n_output_plane;
      const int64_t n_ = output_depth * output_height * output_width;
      const int64_t k_ = 1;

      // Do GEMM (note: this is a bit confusing because gemm assumes
      // column-major matrices)
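      // This is a rank-1 update, output_n(c, :) += bias[c] * ones(:), i.e. the
      // bias of each output channel is broadcast-added over every output voxel.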
      if (bias.defined()) {
        cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            n_,
            m_,
            k_,
            static_cast<scalar_t>(1),
            ones.const_data_ptr<scalar_t>(),
            k_,
            bias.const_data_ptr<scalar_t>(),
            k_,
            static_cast<scalar_t>(1),
            output_n.mutable_data_ptr<scalar_t>(),
            n_);
      }
    }

    // Resize output
    if (is_batch) {
      output.resize_(
          {n_output_plane, output_depth, output_height, output_width});
      input.resize_(
          {n_input_plane, input_depth, input_height, input_width});
    }
  });
}

void slow_conv_transpose3d_backward_out_cpu_template(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_input,
    const Tensor& weight_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected that kernel_size has 3 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected that dilation has 3 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected that padding has 3 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected that stride has 3 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected that output_padding has 3 elements, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  // number of input & output planes and kernel size is indirectly defined by
  // the weight tensor
  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      weight_,
      Tensor(),
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      0);

  Tensor input = input_.contiguous();
  Tensor weight = weight_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  const int64_t n_input_plane = weight.size(0);
  const int64_t n_output_plane = weight.size(1);

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
        grad_output.size(0),
        grad_output.size(1),
        grad_output.size(2),
        grad_output.size(3)});
  }

  const int64_t input_width = input.size(4);
  const int64_t input_height = input.size(3);
  const int64_t input_depth = input.size(2);

  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  // Batch size + input planes
  const int64_t batch_size = input.size(0);

  // Resize output
  grad_input.resize_(
      {batch_size, n_input_plane, input_depth, input_height, input_width});
  grad_input.zero_();

  // Create temporary columns
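  // vol2col is an identity copy when the kernel is 1x1x1 with unit stride and
  // dilation and zero padding, so in that case the intermediate buffer is
  // skipped and grad_output is fed to the GEMM directly.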
  bool need_columns = (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
      stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
      dilation_depth != 1 || dilation_height != 1 ||
      dilation_width != 1 || padding_depth != 0 ||
      padding_height != 0 || padding_width != 0);
  Tensor grad_columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(), "slow_conv_transpose3d_backward_out_cpu", [&] {
    // Helpers
    Tensor grad_input_n;
    Tensor grad_output_n;

    int64_t elt;
    // For each elt in batch, do:
    for (elt = 0; elt < batch_size; ++elt) {
      // Matrix multiply per sample:
      grad_input_n = grad_input.select(0, elt);
      grad_output_n = grad_output.select(0, elt);

      if (need_columns) {
        // Extract columns:
        at::native::vol2col<scalar_t>(
            grad_output_n.const_data_ptr<scalar_t>(),
            n_output_plane,
            output_depth,
            output_height,
            output_width,
            input_depth,
            input_height,
            input_width,
            kernel_depth,
            kernel_height,
            kernel_width,
            padding_depth,
            padding_height,
            padding_width,
            stride_depth,
            stride_height,
            stride_width,
            dilation_depth,
            dilation_height,
            dilation_width,
            grad_columns.mutable_data_ptr<scalar_t>());
      }

      // M,N,K are dims of matrix A and B
      // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
      const int64_t m = weight.size(0);
      const int64_t n = input_depth * input_height * input_width;
      const int64_t k =
          weight.size(1) * weight.size(2) * weight.size(3) * weight.size(4);

      // Do GEMM (note: this is a bit confusing because gemm assumes
      // column-major matrices)
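      // In row-major terms: grad_input_n (n_input_plane x spatial) =
      // weight (n_input_plane x k) * columns (k x spatial), which is the
      // forward pass of the corresponding regular convolution applied to
      // grad_output.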
      auto gemm_in_ptr = need_columns ? grad_columns.const_data_ptr<scalar_t>()
          : grad_output_n.const_data_ptr<scalar_t>();
      cpublas::gemm(
          TransposeType::NoTranspose,
          TransposeType::NoTranspose,
          n,
          m,
          k,
          static_cast<scalar_t>(1),
          gemm_in_ptr,
          n,
          weight.const_data_ptr<scalar_t>(),
          k,
          static_cast<scalar_t>(0),
          grad_input_n.mutable_data_ptr<scalar_t>(),
          n);
    }

    // Resize output
    if (is_batch) {
      grad_output.resize_(
          {n_output_plane, output_depth, output_height, output_width});
      input.resize_(
          {n_input_plane, input_depth, input_height, input_width});
      grad_input.resize_(
          {n_input_plane, input_depth, input_height, input_width});
    }
  });
}

void slow_conv_transpose3d_acc_grad_parameters_cpu(
    const Tensor& input_,
    const Tensor& grad_output_,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    int scale_) {
  TORCH_CHECK(
      kernel_size.size() == 3,
      "It is expected that kernel_size has 3 elements, but got size ",
      kernel_size.size());

  TORCH_CHECK(
      dilation.size() == 3,
      "It is expected that dilation has 3 elements, but got size ",
      dilation.size());

  TORCH_CHECK(
      padding.size() == 3,
      "It is expected that padding has 3 elements, but got size ",
      padding.size());

  TORCH_CHECK(
      stride.size() == 3,
      "It is expected that stride has 3 elements, but got size ",
      stride.size());

  TORCH_CHECK(
      output_padding.size() == 3,
      "It is expected that output_padding has 3 elements, but got size ",
      output_padding.size());

  int64_t kernel_depth = kernel_size[0];
  int64_t kernel_height = kernel_size[1];
  int64_t kernel_width = kernel_size[2];
  int64_t dilation_depth = dilation[0];
  int64_t dilation_height = dilation[1];
  int64_t dilation_width = dilation[2];
  int64_t padding_depth = padding[0];
  int64_t padding_height = padding[1];
  int64_t padding_width = padding[2];
  int64_t stride_depth = stride[0];
  int64_t stride_height = stride[1];
  int64_t stride_width = stride[2];
  int64_t output_padding_depth = output_padding[0];
  int64_t output_padding_height = output_padding[1];
  int64_t output_padding_width = output_padding[2];

  // number of input & output planes and kernel size is indirectly defined by
  // the grad_weight tensor
  slow_conv_transpose3d_shape_check(
      input_,
      grad_output_,
      grad_weight,
      grad_bias,
      kernel_depth,
      kernel_width,
      kernel_height,
      stride_depth,
      stride_width,
      stride_height,
      padding_depth,
      padding_width,
      padding_height,
      dilation_depth,
      dilation_width,
      dilation_height,
      output_padding_depth,
      output_padding_width,
      output_padding_height,
      1);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t n_output_plane;
  if (grad_weight.defined()) {
    n_output_plane = grad_weight.size(1);
  } else if (grad_bias.defined()) {
    n_output_plane = grad_bias.size(0);
  } else {
    return;
  }

  Tensor input = input_.contiguous();
  Tensor grad_output = grad_output_.contiguous();

  if (grad_weight.defined()) {
    TORCH_CHECK(grad_weight.is_contiguous(), "grad_weight needs to be contiguous");
  }
  if (grad_bias.defined()) {
    TORCH_CHECK(grad_bias.is_contiguous(), "grad_bias needs to be contiguous");
  }

  bool is_batch = false;
  if (input.dim() == 4) {
    // Force batch
    is_batch = true;
    input.resize_(
        {1, input.size(0), input.size(1), input.size(2), input.size(3)});
    grad_output.resize_({1,
        grad_output.size(0),
        grad_output.size(1),
        grad_output.size(2),
        grad_output.size(3)});
  }

  const int64_t input_width = input.size(4);
  const int64_t input_height = input.size(3);
  const int64_t input_depth = input.size(2);

  const int64_t output_depth = (input_depth - 1) * stride_depth -
      2 * padding_depth + (dilation_depth * (kernel_depth - 1) + 1) +
      output_padding_depth;
  const int64_t output_height = (input_height - 1) * stride_height -
      2 * padding_height + (dilation_height * (kernel_height - 1) + 1) +
      output_padding_height;
  const int64_t output_width = (input_width - 1) * stride_width -
      2 * padding_width + (dilation_width * (kernel_width - 1) + 1) +
      output_padding_width;

  // Batch size + input planes
  const int64_t batch_size = input.size(0);

  // Create temporary columns
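  // As in the backward pass above, the vol2col buffer is only needed when the
  // kernel, stride, dilation, or padding make vol2col more than a plain copy.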
  bool need_columns = (kernel_depth != 1 || kernel_height != 1 || kernel_width != 1 ||
      stride_depth != 1 || stride_height != 1 || stride_width != 1 ||
      dilation_depth != 1 || dilation_height != 1 ||
      dilation_width != 1 || padding_depth != 0 ||
      padding_height != 0 || padding_width != 0);
  Tensor columns = need_columns ? at::empty({n_output_plane * kernel_width * kernel_height * kernel_depth,
      input_depth * input_height * input_width}, input.options()) : Tensor();

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(),
      "slow_conv_transpose3d_acc_grad_parameters_cpu",
      [&] {
    // Helpers
    Tensor input_n;
    Tensor grad_output_n;

    scalar_t scale = static_cast<scalar_t>(scale_);

    int64_t elt;
    // For each elt in batch, do:
    for (elt = 0; elt < batch_size; ++elt) {
      // Matrix multiply per output:
      grad_output_n = grad_output.select(0, elt);

      // Do Weight:
      if (grad_weight.defined()) {
        // Matrix multiply per output:
        input_n = input.select(0, elt);

        if (need_columns) {
          // Extract columns:
          at::native::vol2col<scalar_t>(
              grad_output_n.const_data_ptr<scalar_t>(),
              n_output_plane,
              output_depth,
              output_height,
              output_width,
              input_depth,
              input_height,
              input_width,
              kernel_depth,
              kernel_height,
              kernel_width,
              padding_depth,
              padding_height,
              padding_width,
              stride_depth,
              stride_height,
              stride_width,
              dilation_depth,
              dilation_height,
              dilation_width,
              columns.mutable_data_ptr<scalar_t>());
        }

        // M,N,K are dims of matrix A and B
        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
        const int64_t n = n_output_plane * kernel_width * kernel_height * kernel_depth;
        const int64_t m = input_n.size(0); // n_input_plane
        const int64_t k = input_depth * input_height * input_width;

        // Do GEMM (note: this is a bit confusing because gemm assumes
        // column-major matrices)
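        // In row-major terms: grad_weight (n_input_plane x k) +=
        // scale * input_n (n_input_plane x spatial) * columns^T (spatial x k),
        // accumulated over the batch into the existing grad_weight contents.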
        auto gemm_in_ptr = need_columns ? columns.const_data_ptr<scalar_t>()
            : grad_output_n.const_data_ptr<scalar_t>();
        cpublas::gemm(
            TransposeType::Transpose,
            TransposeType::NoTranspose,
            n,
            m,
            k,
            static_cast<scalar_t>(scale),
            gemm_in_ptr,
            k,
            input_n.const_data_ptr<scalar_t>(),
            k,
            static_cast<scalar_t>(1),
            grad_weight.mutable_data_ptr<scalar_t>(),
            n);
      }
    }

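    // grad_bias is simply grad_output summed over the batch and spatial
    // dimensions (dims 0, 2, 3, 4).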
    if (grad_bias.defined()) {
      at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3, 4});
    }

    // Resize
    if (is_batch) {
      grad_output.resize_(
          {n_output_plane, output_depth, output_height, output_width});
      input.resize_(
          {input.size(1), input_depth, input_height, input_width});
    }
  });
}

} // namespace

Tensor& slow_conv_transpose3d_out_cpu(const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    Tensor& output) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  slow_conv_transpose3d_out_cpu_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}

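// A rough usage sketch (assumption: this kernel is reached through the
// at::slow_conv_transpose3d dispatcher entry rather than called directly):
//   at::Tensor in = at::randn({1, 4, 5, 5, 5});   // (N, C_in, D, H, W)
//   at::Tensor w  = at::randn({4, 8, 3, 3, 3});   // (C_in, C_out, kD, kH, kW)
//   at::Tensor out = at::slow_conv_transpose3d(in, w, /*kernel_size=*/{3, 3, 3});
//   // out has shape (1, 8, 7, 7, 7) with the default stride/padding/dilation.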
Tensor slow_conv_transpose3d_cpu(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  slow_conv_transpose3d_out_cpu_template(
      output,
      input,
      weight,
      kernel_size,
      bias,
      stride,
      padding,
      output_padding,
      dilation);

  return output;
}

static std::tuple<Tensor, Tensor, Tensor> slow_conv_transpose3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef dilation,
    std::array<bool, 3> output_mask) {
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  if (output_mask[0]) {
    grad_input = at::empty({0}, grad_output.options());
  } else {
    grad_input = Tensor();
  }

  if (output_mask[1]) {
    grad_weight = at::empty({0}, grad_output.options());
  } else {
    grad_weight = Tensor();
  }

  if (output_mask[2]) {
    grad_bias = at::empty({0}, grad_output.options());
  } else {
    grad_bias = Tensor();
  }

  if (grad_input.defined()) {
    slow_conv_transpose3d_backward_out_cpu_template(
        input,
        grad_output,
        grad_input,
        weight,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation);
  }

  if (grad_weight.defined()) {
    grad_weight.resize_(weight.sizes());
    grad_weight.zero_();
  }

  if (grad_bias.defined()) {
    grad_bias.resize_({weight.size(1)});
    grad_bias.zero_();
  }

  if (grad_weight.defined() || grad_bias.defined()) {
    slow_conv_transpose3d_acc_grad_parameters_cpu(
        input,
        grad_output,
        grad_weight,
        grad_bias,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation,
        1);
  }

  return std::tuple<Tensor, Tensor, Tensor>(grad_input, grad_weight, grad_bias);
}

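// Register the backward kernel with the slow_conv_transpose3d_backward
// dispatch stub for all CPU capability variants.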
REGISTER_ALL_CPU_DISPATCH(slow_conv_transpose3d_backward_stub, &slow_conv_transpose3d_backward_cpu);

} // namespace at::native