#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/DilatedConvolutionUtils.h>
#include <ATen/native/im2col.h>
#include <ATen/native/vol2col.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/slow_conv_dilated2d_native.h>
#include <ATen/ops/slow_conv_dilated3d_native.h>
#endif

namespace at::native {
namespace {

// hyper-volume to column, CPU
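// Unfolds a dim-dimensional (2D or 3D) input volume into the column buffer
// data_col so that the dilated convolution reduces to a single GEMM. The
// data_col layout matches the `columns` buffer allocated below:
//   channels-first: (channels * prod(kernel_size)) x prod(output_size)
//   channels-last (2D path only): prod(output_size) x (prod(kernel_size) * channels)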
template <typename Dtype, int64_t dim>
void hvol2col(
    const Dtype* data_hvol,
    const int channels,
    const IntArrayRef input_size,
    const IntArrayRef output_size,
    const IntArrayRef kernel_size,
    const IntArrayRef stride_size,
    const IntArrayRef pad_size,
    const IntArrayRef dilation_size,
    Dtype* data_col,
    bool is_channels_last = false) {
  if (dim == 3) {
    vol2col<Dtype>(
        data_hvol,
        channels,
        input_size[0],
        input_size[1],
        input_size[2],
        output_size[0],
        output_size[1],
        output_size[2],
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        pad_size[0],
        pad_size[1],
        pad_size[2],
        stride_size[0],
        stride_size[1],
        stride_size[2],
        dilation_size[0],
        dilation_size[1],
        dilation_size[2],
        data_col);
  }
  if (dim == 2) {
    im2col<Dtype>(
        data_hvol,
        channels,
        input_size[0],
        input_size[1],
        output_size[0],
        output_size[1],
        kernel_size[0],
        kernel_size[1],
        pad_size[0],
        pad_size[1],
        stride_size[0],
        stride_size[1],
        dilation_size[0],
        dilation_size[1],
        data_col,
        is_channels_last);
  }
}

// column to hyper-volume, CPU
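// Inverse of hvol2col: scatter-adds the column buffer data_col back into the
// dim-dimensional volume data_hvol, accumulating overlapping kernel windows.
// Used below to turn gradient columns into grad_input.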
template <typename Dtype, int64_t dim>
void col2hvol(
    const Dtype* data_col,
    const int channels,
    const IntArrayRef input_size,
    const IntArrayRef output_size,
    const IntArrayRef kernel_size,
    const IntArrayRef stride_size,
    const IntArrayRef pad_size,
    const IntArrayRef dilation_size,
    Dtype* data_hvol,
    bool is_channels_last = false) {
  if (dim == 3) {
    col2vol<Dtype>(
        data_col,
        channels,
        input_size[0],
        input_size[1],
        input_size[2],
        output_size[0],
        output_size[1],
        output_size[2],
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        pad_size[0],
        pad_size[1],
        pad_size[2],
        stride_size[0],
        stride_size[1],
        stride_size[2],
        dilation_size[0],
        dilation_size[1],
        dilation_size[2],
        data_hvol);
  }
  if (dim == 2) {
    col2im<Dtype>(
        data_col,
        channels,
        input_size[0],
        input_size[1],
        output_size[0],
        output_size[1],
        kernel_size[0],
        kernel_size[1],
        pad_size[0],
        pad_size[1],
        stride_size[0],
        stride_size[1],
        dilation_size[0],
        dilation_size[1],
        data_hvol,
        is_channels_last);
  }
}

/*
  check tensor data locations
*/
void slow_conv_dilated_location_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output) {
  // checking data locations of user-provided tensor arguments
  checkBackend("slow_conv_dilated_location_check", {input, weight}, Backend::CPU);
  if (bias.defined()) {
    checkBackend("slow_conv_dilated_location_check", {bias}, Backend::CPU);
  }
  if (grad_output.defined()) {
    checkBackend("slow_conv_dilated_location_check", {grad_output}, Backend::CPU);
  }
  // We do not check the data locations of the other tensor arguments
  // (output, grad_input, etc.) because they are allocated based on the
  // input's options and hence always have the same data location as the
  // input tensor.
}

/*
  slow_conv_dilated_all_cpu_template

  Main worker. Computes whichever of the tensors output, grad_input,
  grad_weight, and grad_bias are defined.
*/
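// Only the tensors that are passed in as defined get computed: output needs
// input, weight and (optionally) bias, while grad_input, grad_weight and
// grad_bias additionally need grad_output.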

template <int64_t dim>
void slow_conv_dilated_all_cpu_template(
    Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    bool is_channels_last = false) {
  slow_conv_dilated_location_check(input, weight, bias, grad_output);
  auto options = input.options();
  // The rear part of input tensor sizes:
  auto input_size = input.sizes().slice(2);
  // The rear part of output tensor sizes:
  auto output_size = internal::get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  int64_t batchSize = input.size(0);
  int64_t nInputPlane = weight.size(1);
  int64_t nOutputPlane = weight.size(0);
  // Temporary buffer:
  Tensor columns = at::empty({0}, options);
  if (output.defined() || grad_weight.defined() || grad_input.defined()) {
    const int64_t m = c10::multiply_integers(kernel_size);
    const int64_t n = c10::multiply_integers(output_size);
    if (is_channels_last) {
      columns.resize_({n, m * nInputPlane});
    } else {
      columns.resize_({nInputPlane * m, n});
    }
  }
  // Initialize
  if (grad_weight.defined()) {
    grad_weight.zero_();
  }
  if (grad_bias.defined()) {
    grad_bias.zero_();
  }
  if (output.defined() && !bias.defined()) {
    output.zero_();
  }
  // Helpers
  Tensor grad_output_n;
  std::vector<int64_t> dims(dim);
  std::iota(dims.begin(), dims.end(), 1);
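  // dims = {1, ..., dim}: the spatial dimensions of grad_output_n, reduced
  // over when accumulating grad_bias below.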

  AT_DISPATCH_FLOATING_TYPES_AND3(
      at::ScalarType::Long, at::ScalarType::BFloat16, at::ScalarType::Half,
      input.scalar_type(), "slow_conv_dilated<>", [&] {
    // For each elt in batch, do:
    for (const auto elt : c10::irange(batchSize)) {
      // Matrix multiply per output:
      Tensor input_n = input.select(0, elt);

      // Output
      if (output.defined()) {
        Tensor output_n = output.select(0, elt);
        if (bias.defined()) {
          /*
            Compute:

              output_n = bias * ones^T

            where

              bias is viewed as bias.view(nOutputPlane, 1)

              ones is viewed as ones.view(outputHeight * outputWidth, 1)

              output_n is viewed as output_n.view(nOutputPlane, outputHeight
              * outputWidth)

            gemm assumes column-major matrices:

              output_n^T = ones * bias^T
              C = alpha * op(A) * op(B)
              op(A) = 't', op(B) = 'n', alpha=1, beta=0
          */
          // The following for-loop is equivalent to the above
          // gemm setup but avoids allocation of ones tensor:
          for (const auto n : c10::irange(nOutputPlane)) {
            output_n.select(0, n).fill_(bias[n]);
          }
        }
        // Extract columns:
        hvol2col<scalar_t, dim>(
            input_n.const_data_ptr<scalar_t>(),
            nInputPlane,
            input_size,
            output_size,
            kernel_size,
            stride_size,
            pad_size,
            dilation_size,
            columns.mutable_data_ptr<scalar_t>(),
            is_channels_last);
        /*
          Compute:

            output_n = weight * columns + output_n

          where

            weight is viewed as weight.view(nOutputPlane, nInputPlane * kD *
            kH * kW)

            columns size is (nInputPlane * kH * kW) x (outputHeight *
            outputWidth)

            output_n is viewed as output_n.view(nOutputPlane, outputHeight *
            outputWidth)

          gemm assumes column-major matrices:

            channels last:
              output_n^T = weight * columns^T + output_n^T
              C = alpha * op(A) * op(B) + beta * C
              op(A) = 't', op(B) = 'n', alpha=1, beta=1

            channels first:
              output_n^T = columns^T * weight^T + output_n^T
              C = alpha * op(A) * op(B) + beta * C
              op(A) = 'n', op(B) = 'n', alpha=1, beta=1
        */
        if (is_channels_last) {
          cpublas::gemm(
              /*transa=*/TransposeType::Transpose,
              /*transb=*/TransposeType::NoTranspose,
              /* m=*/nOutputPlane,
              /* n=*/columns.size(0),
              /* k=*/columns.size(1),
              /* alpha=*/static_cast<scalar_t>(1),
              /* A=*/weight.const_data_ptr<scalar_t>(),
              /* lda=*/columns.size(1),
              /* B=*/columns.const_data_ptr<scalar_t>(),
              /* ldb=*/columns.size(1),
              /* beta=*/static_cast<scalar_t>(1),
              /* C=*/output_n.mutable_data_ptr<scalar_t>(),
              /* ldc=*/nOutputPlane);
        } else {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::NoTranspose,
              /* m=*/columns.size(1),
              /* n=*/nOutputPlane,
              /* k=*/columns.size(0),
              /* alpha=*/static_cast<scalar_t>(1),
              /* A=*/columns.const_data_ptr<scalar_t>(),
              /* lda=*/columns.size(1),
              /* B=*/weight.const_data_ptr<scalar_t>(),
              /* ldb=*/columns.size(0),
              /* beta=*/static_cast<scalar_t>(1),
              /* C=*/output_n.mutable_data_ptr<scalar_t>(),
              /* ldc=*/columns.size(1));
        }
      } else {
        // All gradients
        grad_output_n = grad_output.select(0, elt);
      }

      // Gradient of input:
      if (grad_input.defined()) {
        /*
          Compute:

            columns = weight^T * grad_output_n

          where

            weight is viewed as weight.view(nOutputPlane, nInputPlane * kH *
            kW)

            grad_output_n is viewed as grad_output_n.view(nOutputPlane,
            outputHeight * outputWidth)

            columns size is (nInputPlane * kH * kW) x (outputHeight *
            outputWidth)

          gemm assumes column-major matrices:

            channels last:
              columns^T = weight^T * grad_output_n^T
              C = alpha * op(A) * op(B) + beta * C
              op(A) = 'n', op(B) = 'n', alpha=1, beta=0

            channels first:
              columns^T = grad_output_n^T * weight
              C = alpha * op(A) * op(B) + beta * C
              op(A) = 'n', op(B) = 't', alpha=1, beta=0
        */
        if (is_channels_last) {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::NoTranspose,
              /* m=*/columns.size(1),
              /* n=*/columns.size(0),
              /* k=*/nOutputPlane,
              /* alpha=*/static_cast<scalar_t>(1),
              /* A=*/weight.const_data_ptr<scalar_t>(),
              /* lda=*/columns.size(1),
              /* B=*/grad_output_n.const_data_ptr<scalar_t>(),
              /* ldb=*/nOutputPlane,
              /* beta=*/static_cast<scalar_t>(0),
              /* C=*/columns.mutable_data_ptr<scalar_t>(),
              /* ldc=*/columns.size(1));
        } else {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::Transpose,
              /* m=*/columns.size(1),
              /* n=*/columns.size(0),
              /* k=*/nOutputPlane,
              /* alpha=*/static_cast<scalar_t>(1),
              /* A=*/grad_output_n.const_data_ptr<scalar_t>(),
              /* lda=*/columns.size(1),
              /* B=*/weight.const_data_ptr<scalar_t>(),
              /* ldb=*/columns.size(0),
              /* beta=*/static_cast<scalar_t>(0),
              /* C=*/columns.mutable_data_ptr<scalar_t>(),
              /* ldc=*/columns.size(1));
        }
        // Unpack columns back into input:
        Tensor grad_input_n = grad_input.select(0, elt);

        col2hvol<scalar_t, dim>(
            columns.data_ptr<scalar_t>(),
            nInputPlane,
            input_size,
            output_size,
            kernel_size,
            stride_size,
            pad_size,
            dilation_size,
            grad_input_n.data_ptr<scalar_t>(),
            is_channels_last);
      }

      // Gradient of weight:
      if (grad_weight.defined()) {
        // Extract columns:
        hvol2col<scalar_t, dim>(
            input_n.const_data_ptr<scalar_t>(),
            nInputPlane,
            input_size,
            output_size,
            kernel_size,
            stride_size,
            pad_size,
            dilation_size,
            columns.mutable_data_ptr<scalar_t>(),
            is_channels_last);
        scalar_t scale = 1; // TODO: expose as argument?
        /*
          Compute:

            grad_weight = scale * grad_output_n * columns^T + grad_weight

          where

            grad_output_n is viewed as grad_output_n.view(nOutputPlane,
            outputHeight * outputWidth)

            columns size is (nInputPlane * kD * kH * kW) x (outputHeight *
            outputWidth)

            grad_weight is viewed as grad_weight.view(nOutputPlane,
            nInputPlane * kH * kW)

          gemm assumes column-major matrices:

            channels last:
              grad_weight^T = scale * columns^T * grad_output_n + grad_weight^T
              C = alpha * op(A) * op(B) + beta * C
              op(A) = 'n', op(B) = 't', alpha=scale, beta=1

            channels first:
              grad_weight^T = scale * columns * grad_output_n^T + grad_weight^T
              C = alpha * op(A) * op(B) + beta * C
              op(A) = 't', op(B) = 'n', alpha=scale, beta=1
        */
        if (is_channels_last) {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::Transpose,
              /* m=*/columns.size(1),
              /* n=*/nOutputPlane,
              /* k=*/columns.size(0),
              /* alpha=*/static_cast<scalar_t>(scale),
              /* A=*/columns.const_data_ptr<scalar_t>(),
              /* lda=*/columns.size(1),
              /* B=*/grad_output_n.const_data_ptr<scalar_t>(),
              /* ldb=*/nOutputPlane,
              /* beta=*/static_cast<scalar_t>(1),
              /* C=*/grad_weight.mutable_data_ptr<scalar_t>(),
              /* ldc=*/columns.size(1));
        } else {
          cpublas::gemm(
              /*transa=*/TransposeType::Transpose,
              /*transb=*/TransposeType::NoTranspose,
              /* m=*/columns.size(0),
              /* n=*/nOutputPlane,
              /* k=*/columns.size(1),
              /* alpha=*/static_cast<scalar_t>(scale),
              /* A=*/columns.const_data_ptr<scalar_t>(),
              /* lda=*/columns.size(1),
              /* B=*/grad_output_n.const_data_ptr<scalar_t>(),
              /* ldb=*/columns.size(1),
              /* beta=*/static_cast<scalar_t>(1),
              /* C=*/grad_weight.mutable_data_ptr<scalar_t>(),
              /* ldc=*/columns.size(0));
        }
      }

      // Gradient of bias:
      if (grad_bias.defined()) {
        /*
          Compute:
            grad_bias = scale * grad_output_n * ones + grad_bias

          where

            grad_bias is viewed as grad_bias.view(nOutputPlane, 1)

            ones is viewed as ones.view(outputHeight * outputWidth, 1)

            grad_output_n is viewed as grad_output_n.view(nOutputPlane,
            outputHeight * outputWidth)

          gemm assumes column-major matrices:

            grad_bias^T = scale * grad_output_n * ones + grad_bias^T
            y = alpha * op(A) * x + beta * y
            op(A) = 't', alpha=scale, beta=1
        */
        // The following expression is equivalent to the above
        // gemm setup but avoids allocation of ones tensor:
        grad_bias += grad_output_n.sum(dims);
        /*
          TODO: when scale != 1 is introduced then use:
            grad_bias += scale * grad_output_n.sum(dims);
        */
      }
    }
  });

} // slow_conv_dilated_all_cpu_template

} // namespace

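// slow_conv_dilated2d forward on CPU. Illustrative shape summary (the
// authoritative checks live in internal::slow_conv_dilated_shape_check and
// internal::get_output_size):
//   input:  (N, C_in, iH, iW) or (C_in, iH, iW)
//   weight: (C_out, C_in, kH, kW)
//   bias:   (C_out), optional
//   output: (N, C_out, oH, oW) with
//     oH = (iH + 2 * padH - dilationH * (kH - 1) - 1) / strideH + 1
//   and analogously for oW.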
Tensor slow_conv_dilated2d_cpu(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  bool use_channels_last = thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast
                                         : at::MemoryFormat::Contiguous;

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<2>(
      input,
      weight,
      bias,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 4;
  auto options = input.options();
  // calculate output tensor size
  auto output_size = internal::get_output_size<2>(
      input, weight, kernel_size, stride_size, pad_size, dilation_size);
  // template function assumes batched tensors. unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor input_ =
      (is_batch ? input.contiguous(memory_format)
                : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous(memory_format);
  const Tensor bias_ = (bias.defined() ? bias.contiguous() : undefined);
  Tensor output = at::empty(output_size, options.memory_format(memory_format));
  Tensor output_ = (is_batch ? output : output.unsqueeze(0));

  slow_conv_dilated_all_cpu_template<2>(
      output_,
      input_,
      weight_,
      bias_,
      undefined,
      undefined,
      undefined,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size,
      use_channels_last);
  return output;
}

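// slow_conv_dilated3d forward on CPU: the 3D counterpart of
// slow_conv_dilated2d_cpu, with input (N, C_in, iD, iH, iW) or
// (C_in, iD, iH, iW) and weight (C_out, C_in, kD, kH, kW). Unlike the 2D
// path, it always runs in contiguous (channels-first) memory format.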
Tensor slow_conv_dilated3d_cpu(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<3>(
      input,
      weight,
      bias,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 5;
  auto options = input.options();
  // calculate output tensor size
  auto output_size = internal::get_output_size<3>(
      input, weight, kernel_size, stride_size, pad_size, dilation_size);
  // template function assumes batched tensors. unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  const Tensor bias_ = (bias.defined() ? bias.contiguous() : undefined);
  Tensor output = at::empty(output_size, options);
  Tensor output_ = (is_batch ? output : output.unsqueeze(0));

  slow_conv_dilated_all_cpu_template<3>(
      output_,
      input_,
      weight_,
      bias_,
      undefined,
      undefined,
      undefined,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return output;
}

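// Backward of slow_conv_dilated2d on CPU. output_mask selects which gradients
// are computed: [0] -> grad_input, [1] -> grad_weight, [2] -> grad_bias;
// masked-out entries are returned as undefined tensors.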
static std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    const std::array<bool, 3ul> output_mask) {
  bool use_channels_last = thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast
                                         : at::MemoryFormat::Contiguous;

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<2>(
      input,
      weight,
      undefined,
      grad_output,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 4;
  auto options = grad_output.options();
  // template function assumes batched tensors. unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor grad_output_ =
      (is_batch ? grad_output.contiguous(memory_format)
                : grad_output.contiguous().unsqueeze(0));
  const Tensor input_ =
      (is_batch ? input.contiguous(memory_format)
                : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous(memory_format);
  // compute only gradients for which the corresponding output_mask is true:
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options.memory_format(memory_format))
                      : undefined);
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options.memory_format(memory_format))
                      : undefined);
  Tensor grad_bias =
      (output_mask[2] ? at::empty(weight.size(0), options) : undefined);
  Tensor grad_input_ =
      (output_mask[0] ? (is_batch ? grad_input : grad_input.unsqueeze(0))
                      : undefined);
  slow_conv_dilated_all_cpu_template<2>(
      undefined,
      input_,
      weight_,
      undefined,
      grad_output_,
      grad_input_,
      grad_weight,
      grad_bias,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size,
      use_channels_last);
  return std::tie(grad_input, grad_weight, grad_bias);
}

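// Backward of slow_conv_dilated3d on CPU. Same output_mask semantics as the
// 2D variant above, but always in contiguous (channels-first) memory format
// since this file has no 3D channels-last path.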
static std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    const std::array<bool, 3ul> output_mask) {
  Tensor undefined;
  internal::slow_conv_dilated_shape_check<3>(
      input,
      weight,
      undefined,
      grad_output,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 5;
  auto options = grad_output.options();
  // template function assumes batched tensors. unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor grad_output_ =
      (is_batch ? grad_output.contiguous()
                : grad_output.contiguous().unsqueeze(0));
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  // compute only gradients for which the corresponding output_mask is true:
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options) : undefined);
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options) : undefined);
  Tensor grad_bias =
      (output_mask[2] ? at::empty(weight.size(0), options) : undefined);
  Tensor grad_input_ =
      (output_mask[0] ? (is_batch ? grad_input : grad_input.unsqueeze(0))
                      : undefined);
  slow_conv_dilated_all_cpu_template<3>(
      undefined,
      input_,
      weight_,
      undefined,
      grad_output_,
      grad_input_,
      grad_weight,
      grad_bias,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return std::tie(grad_input, grad_weight, grad_bias);
}

REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu);

} // namespace at::native