// aten/src/ATen/native/NaiveDilatedConvolution.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/DilatedConvolutionUtils.h>
#include <ATen/native/im2col.h>
#include <ATen/native/vol2col.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/slow_conv_dilated2d_native.h>
#include <ATen/ops/slow_conv_dilated3d_native.h>
#endif

namespace at::native {
namespace {

// hyper-volume to column, CPU
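// Dispatches to the 3-D vol2col (dim == 3) or the 2-D im2col (dim == 2)
// kernel. Each column of the result holds one dilated receptive field, so
// the convolution below reduces to a single GEMM per batch element.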
template <typename Dtype, int64_t dim>
void hvol2col(
    const Dtype* data_hvol,
    const int channels,
    const IntArrayRef input_size,
    const IntArrayRef output_size,
    const IntArrayRef kernel_size,
    const IntArrayRef stride_size,
    const IntArrayRef pad_size,
    const IntArrayRef dilation_size,
    Dtype* data_col,
    bool is_channels_last = false) {
  if (dim == 3) {
    vol2col<Dtype>(
        data_hvol,
        channels,
        input_size[0],
        input_size[1],
        input_size[2],
        output_size[0],
        output_size[1],
        output_size[2],
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        pad_size[0],
        pad_size[1],
        pad_size[2],
        stride_size[0],
        stride_size[1],
        stride_size[2],
        dilation_size[0],
        dilation_size[1],
        dilation_size[2],
        data_col);
  }
  if (dim == 2) {
    im2col<Dtype>(
        data_hvol,
        channels,
        input_size[0],
        input_size[1],
        output_size[0],
        output_size[1],
        kernel_size[0],
        kernel_size[1],
        pad_size[0],
        pad_size[1],
        stride_size[0],
        stride_size[1],
        dilation_size[0],
        dilation_size[1],
        data_col,
        is_channels_last);
  }
}

// column to hyper-volume, CPU
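// Inverse of hvol2col: col2vol/col2im scatter-add the column matrix back
// into the hyper-volume, accumulating contributions from overlapping
// receptive fields. Used below to form grad_input.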
template <typename Dtype, int64_t dim>
void col2hvol(
    const Dtype* data_col,
    const int channels,
    const IntArrayRef input_size,
    const IntArrayRef output_size,
    const IntArrayRef kernel_size,
    const IntArrayRef stride_size,
    const IntArrayRef pad_size,
    const IntArrayRef dilation_size,
    Dtype* data_hvol,
    bool is_channels_last = false) {
  if (dim == 3) {
    col2vol<Dtype>(
        data_col,
        channels,
        input_size[0],
        input_size[1],
        input_size[2],
        output_size[0],
        output_size[1],
        output_size[2],
        kernel_size[0],
        kernel_size[1],
        kernel_size[2],
        pad_size[0],
        pad_size[1],
        pad_size[2],
        stride_size[0],
        stride_size[1],
        stride_size[2],
        dilation_size[0],
        dilation_size[1],
        dilation_size[2],
        data_hvol);
  }
  if (dim == 2) {
    col2im<Dtype>(
        data_col,
        channels,
        input_size[0],
        input_size[1],
        output_size[0],
        output_size[1],
        kernel_size[0],
        kernel_size[1],
        pad_size[0],
        pad_size[1],
        stride_size[0],
        stride_size[1],
        dilation_size[0],
        dilation_size[1],
        data_hvol,
        is_channels_last);
  }
}

/*
   check tensor data locations
*/
void slow_conv_dilated_location_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output) {
  // checking data locations of user-provided tensor arguments
  checkBackend("slow_conv_dilated_location_check", {input, weight}, Backend::CPU);
  if (bias.defined()) {
    checkBackend("slow_conv_dilated_location_check", {bias}, Backend::CPU);
  }
  if (grad_output.defined()) {
    checkBackend("slow_conv_dilated_location_check", {grad_output}, Backend::CPU);
  }
  // We do not check the data locations of other tensor arguments such as
  // output, grad_input, etc., because these are allocated based on the input
  // options and hence always share the input tensor's data location.
}

/*
  slow_conv_dilated_all_cpu_template

  Main worker. Computes tensors output, grad_input, grad_weight,
  and/or grad_bias if defined, respectively.
 */

template <int64_t dim>
void slow_conv_dilated_all_cpu_template(
    Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    bool is_channels_last = false) {
  slow_conv_dilated_location_check(input, weight, bias, grad_output);
  auto options = input.options();
  // The rear part of input tensor sizes:
  auto input_size = input.sizes().slice(2);
  // The rear part of output tensor sizes:
  auto output_size = internal::get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
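  // For each spatial dimension, get_output_size applies the usual dilated
  // convolution formula:
  //   out = (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1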
  int64_t batchSize = input.size(0);
  int64_t nInputPlane = weight.size(1);
  int64_t nOutputPlane = weight.size(0);
  // Temporary buffer:
  Tensor columns = at::empty({0}, options);
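  // The columns buffer holds the unfolded input of one batch element, laid
  // out to match the GEMM calls below: channels-last uses
  // (output spatial size) x (kernel volume * nInputPlane), contiguous uses
  // (nInputPlane * kernel volume) x (output spatial size).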
  if (output.defined() || grad_weight.defined() || grad_input.defined()) {
    const int64_t m = c10::multiply_integers(kernel_size);
    const int64_t n = c10::multiply_integers(output_size);
    if (is_channels_last) {
      columns.resize_({n, m * nInputPlane});
    } else {
      columns.resize_({nInputPlane * m, n});
    }
  }
  // Initialize
  if (grad_weight.defined()) {
    grad_weight.zero_();
  }
  if (grad_bias.defined()) {
    grad_bias.zero_();
  }
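  // When bias is defined, output_n is filled from bias inside the batch loop
  // below and the output GEMM accumulates onto it with beta=1, so explicit
  // zeroing is only needed in the bias-free case.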
  if (output.defined() && !bias.defined()) {
    output.zero_();
  }
  // Helpers
  Tensor grad_output_n;
  std::vector<int64_t> dims(dim);
  std::iota(dims.begin(), dims.end(), 1);
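  // dims = {1, ..., dim}: the spatial dimensions of grad_output_n once the
  // batch dimension has been peeled off by select(0, elt); summing over them
  // gives the per-output-channel reduction used for grad_bias.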

    AT_DISPATCH_FLOATING_TYPES_AND3(
        at::ScalarType::Long, at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "slow_conv_dilated<>", [&] {
    // For each elt in batch, do:
    for (const auto elt : c10::irange(batchSize)) {
      // Matrix multiply per output:
      Tensor input_n = input.select(0, elt);

      // Output
      if (output.defined()) {
        Tensor output_n = output.select(0, elt);
        if (bias.defined()) {
          /*
            Compute:

              output_n = bias * ones^T

            where

              bias is viewed as bias.view(nOutputPlane, 1)

              ones is viewed as ones.view(outputHeight * outputWidth, 1)

              output_n is viewed as output_n.view(nOutputPlane, outputHeight
          * outputWidth)

          gemm assumes column-major matrices:

            output_n^T = ones * bias^T
            C = alpha * op(A) * op(B)
            op(A) = 't', op(B) = 'n', alpha=1, beta=0
          */
          // The following for-loop is equivalent to the above
          // gemm setup but avoids allocation of ones tensor:
          for (const auto n : c10::irange(nOutputPlane)) {
            output_n.select(0, n).fill_(bias[n]);
          }
        }
        // Extract columns:
        hvol2col<scalar_t, dim>(
            input_n.const_data_ptr<scalar_t>(),
            nInputPlane,
            input_size,
            output_size,
            kernel_size,
            stride_size,
            pad_size,
            dilation_size,
            columns.mutable_data_ptr<scalar_t>(),
            is_channels_last);
        /*
          Compute:

            output_n = weight * columns + output_n

          where

            weight is viewed as weight.view(nOutputPlane, nInputPlane * kD *
          kH * kW)

            columns size is (nInputPlane * kH * kW) x (outputHeight *
          outputWidth)

            output_n is viewed as output_n.view(nOutputPlane, outputHeight *
          outputWidth)

          gemm assumes column-major matrices:

          channels last:
            output_n^T = weight * columns^T + output_n^T
            C = alpha * op(A) * op(B) + beta * C
            op(A) = 't', op(B) = 'n', alpha=1, beta=1

          channels first:
            output_n^T = columns^T * weight^T + output_n^T
            C = alpha * op(A) * op(B) + beta * C
            op(A) = 'n', op(B) = 'n', alpha=1, beta=1
        */
        if (is_channels_last) {
          cpublas::gemm(
              /*transa=*/TransposeType::Transpose,
              /*transb=*/TransposeType::NoTranspose,
              /*     m=*/nOutputPlane,
              /*     n=*/columns.size(0),
              /*     k=*/columns.size(1),
              /* alpha=*/static_cast<scalar_t>(1),
              /*     A=*/weight.const_data_ptr<scalar_t>(),
              /*   lda=*/columns.size(1),
              /*     B=*/columns.const_data_ptr<scalar_t>(),
              /*   ldb=*/columns.size(1),
              /*  beta=*/static_cast<scalar_t>(1),
              /*     C=*/output_n.mutable_data_ptr<scalar_t>(),
              /*   ldc=*/nOutputPlane);
        } else {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::NoTranspose,
              /*     m=*/columns.size(1),
              /*     n=*/nOutputPlane,
              /*     k=*/columns.size(0),
              /* alpha=*/static_cast<scalar_t>(1),
              /*     A=*/columns.const_data_ptr<scalar_t>(),
              /*   lda=*/columns.size(1),
              /*     B=*/weight.const_data_ptr<scalar_t>(),
              /*   ldb=*/columns.size(0),
              /*  beta=*/static_cast<scalar_t>(1),
              /*     C=*/output_n.mutable_data_ptr<scalar_t>(),
              /*   ldc=*/columns.size(1));
        }
      } else {
        // All gradients
        grad_output_n = grad_output.select(0, elt);
      }

      // Gradient of input:
      if (grad_input.defined()) {
        /*
          Compute:

            columns = weight^T * grad_output_n

          where

            weight is viewed as weight.view(nOutputPlane, nInputPlane * kH *
          kW)

            grad_output_n is viewed as grad_output_n.view(nOutputPlane,
          outputHeight * outputWidth)

            columns size is (nInputPlane * kH * kW) x (outputHeight *
          outputWidth)

          gemm assumes column-major matrices:

          channels last:
            columns^T = weight^T * grad_output_n^T
            C = alpha * op(A) * op(B) + beta * C
            op(A) = 'n', op(B) = 'n', alpha=1, beta=0

          channels first:
            columns^T = grad_output_n^T * weight
            C = alpha * op(A) * op(B) + beta * C
            op(A) = 'n', op(B) = 't', alpha=1, beta=0
         */
        if (is_channels_last) {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::NoTranspose,
              /*     m=*/columns.size(1),
              /*     n=*/columns.size(0),
              /*     k=*/nOutputPlane,
              /* alpha=*/static_cast<scalar_t>(1),
              /*     A=*/weight.const_data_ptr<scalar_t>(),
              /*   lda=*/columns.size(1),
              /*     B=*/grad_output_n.const_data_ptr<scalar_t>(),
              /*   ldb=*/nOutputPlane,
              /*  beta=*/static_cast<scalar_t>(0),
              /*     C=*/columns.mutable_data_ptr<scalar_t>(),
              /*   ldc=*/columns.size(1));
        } else {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::Transpose,
              /*     m=*/columns.size(1),
              /*     n=*/columns.size(0),
              /*     k=*/nOutputPlane,
              /* alpha=*/static_cast<scalar_t>(1),
              /*     A=*/grad_output_n.const_data_ptr<scalar_t>(),
              /*   lda=*/columns.size(1),
              /*     B=*/weight.const_data_ptr<scalar_t>(),
              /*   ldb=*/columns.size(0),
              /*  beta=*/static_cast<scalar_t>(0),
              /*     C=*/columns.mutable_data_ptr<scalar_t>(),
              /*   ldc=*/columns.size(1));
        }
        // Unpack columns back into input:
        Tensor grad_input_n = grad_input.select(0, elt);

        col2hvol<scalar_t, dim>(
            columns.data_ptr<scalar_t>(),
            nInputPlane,
            input_size,
            output_size,
            kernel_size,
            stride_size,
            pad_size,
            dilation_size,
            grad_input_n.data_ptr<scalar_t>(),
            is_channels_last);
      }

      // Gradient of weight:
      if (grad_weight.defined()) {
        // Extract columns:
        hvol2col<scalar_t, dim>(
            input_n.const_data_ptr<scalar_t>(),
            nInputPlane,
            input_size,
            output_size,
            kernel_size,
            stride_size,
            pad_size,
            dilation_size,
            columns.mutable_data_ptr<scalar_t>(),
            is_channels_last);
        scalar_t scale = 1; // TODO: expose as argument?
        /*
          Compute:

            grad_weight = scale * grad_output_n * columns^T + grad_weight

          where

            grad_output_n is viewed as grad_output_n.view(nOutputPlane,
          outputHeight * outputWidth)

            columns size is (nInputPlane * kD * kH * kW) x (outputHeight *
          outputWidth)

            grad_weight is viewed as grad_weight.view(nOutputPlane,
          nInputPlane * kH * kW)

          gemm assumes column-major matrices:

          channels last:
            grad_weight^T = scale * columns^T * grad_output_n + grad_weight^T
            C = alpha * op(A) * op(B) + beta * C
            op(A) = 'n', op(B) = 't', alpha=scale, beta=1

          channels first:
            grad_weight^T = scale * columns * grad_output_n^T + grad_weight^T
            C = alpha * op(A) * op(B) + beta * C
            op(A) = 't', op(B) = 'n', alpha=scale, beta=1
        */
        if (is_channels_last) {
          cpublas::gemm(
              /*transa=*/TransposeType::NoTranspose,
              /*transb=*/TransposeType::Transpose,
              /*     m=*/columns.size(1),
              /*     n=*/nOutputPlane,
              /*     k=*/columns.size(0),
              /* alpha=*/static_cast<scalar_t>(scale),
              /*     A=*/columns.const_data_ptr<scalar_t>(),
              /*   lda=*/columns.size(1),
              /*     B=*/grad_output_n.const_data_ptr<scalar_t>(),
              /*   ldb=*/nOutputPlane,
              /*  beta=*/static_cast<scalar_t>(1),
              /*     C=*/grad_weight.mutable_data_ptr<scalar_t>(),
              /*   ldc=*/columns.size(1));
        } else {
          cpublas::gemm(
              /*transa=*/TransposeType::Transpose,
              /*transb=*/TransposeType::NoTranspose,
              /*     m=*/columns.size(0),
              /*     n=*/nOutputPlane,
              /*     k=*/columns.size(1),
              /* alpha=*/static_cast<scalar_t>(scale),
              /*     A=*/columns.const_data_ptr<scalar_t>(),
              /*   lda=*/columns.size(1),
              /*     B=*/grad_output_n.const_data_ptr<scalar_t>(),
              /*   ldb=*/columns.size(1),
              /*  beta=*/static_cast<scalar_t>(1),
              /*     C=*/grad_weight.mutable_data_ptr<scalar_t>(),
              /*   ldc=*/columns.size(0));
        }
      }

      // Gradient of bias:
      if (grad_bias.defined()) {
        /*
          Compute:
            grad_bias = scale * grad_output_n * ones + grad_bias

          where

            grad_bias is viewed as grad_bias.view(nOutputPlane, 1)

            ones is viewed as ones.view(outputHeight * outputWidth, 1)

            grad_output_n is viewed as grad_output_n.view(nOutputPlane,
          outputHeight * outputWidth)

          gemm assumes column-major matrices:

            grad_bias^T = scale * grad_output_n * ones + grad_bias^T
            y = alpha * op(A) * x + beta * y
            op(A) = 't', alpha=scale, beta=1
         */
        // The following expression is equivalent to the above
        // gemm setup but avoids allocation of ones tensor:
        grad_bias += grad_output_n.sum(dims);
        /*
          TODO: when scale != 1 is introduced then use:
            grad_bias += scale * grad_output_n.sum(dims);
         */
      }
    }
  });

} // slow_conv_dilated_all_cpu_template

} // namespace

Tensor slow_conv_dilated2d_cpu(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  bool use_channels_last = thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;
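  // When thnn_conv_use_channels_last selects NHWC, the flag is forwarded to
  // the worker template, which switches the columns layout and GEMM operand
  // order accordingly; the 3-D variant below always takes the contiguous path.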

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<2>(
      input,
      weight,
      bias,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 4;
  auto options = input.options();
  // calculate output tensor size
  auto output_size = internal::get_output_size<2>(
      input, weight, kernel_size, stride_size, pad_size, dilation_size);
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor input_ =
      (is_batch ? input.contiguous(memory_format) : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous(memory_format);
  const Tensor bias_ = (bias.defined() ? bias.contiguous() : undefined);
  Tensor output = at::empty(output_size, options.memory_format(memory_format));
  Tensor output_ = (is_batch ? output : output.unsqueeze(0));

  slow_conv_dilated_all_cpu_template<2>(
      output_,
      input_,
      weight_,
      bias_,
      undefined,
      undefined,
      undefined,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size,
      use_channels_last);
  return output;
}

Tensor slow_conv_dilated3d_cpu(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size, const std::optional<Tensor>& bias_opt,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<3>(
      input,
      weight,
      bias,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 5;
  auto options = input.options();
  // calculate output tensor size
  auto output_size = internal::get_output_size<3>(
      input, weight, kernel_size, stride_size, pad_size, dilation_size);
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  const Tensor bias_ = (bias.defined() ? bias.contiguous() : undefined);
  Tensor output = at::empty(output_size, options);
  Tensor output_ = (is_batch ? output : output.unsqueeze(0));

  slow_conv_dilated_all_cpu_template<3>(
      output_,
      input_,
      weight_,
      bias_,
      undefined,
      undefined,
      undefined,
      undefined,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return output;
}

static std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    const std::array<bool, 3ul> output_mask) {
  bool use_channels_last = thnn_conv_use_channels_last(input, weight);
  auto memory_format = use_channels_last ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous;

  Tensor undefined;
  internal::slow_conv_dilated_shape_check<2>(
      input,
      weight,
      undefined,
      grad_output,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 4;
  auto options = grad_output.options();
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor grad_output_ =
      (is_batch ? grad_output.contiguous(memory_format)
                : grad_output.contiguous().unsqueeze(0));
  const Tensor input_ =
      (is_batch ? input.contiguous(memory_format) : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous(memory_format);
  // compute only gradients for which the corresponding output_mask is true:
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options.memory_format(memory_format)) : undefined);
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options.memory_format(memory_format)) : undefined);
  Tensor grad_bias =
      (output_mask[2] ? at::empty(weight.size(0), options) : undefined);
  Tensor grad_input_ =
      (output_mask[0] ? (is_batch ? grad_input : grad_input.unsqueeze(0))
                      : undefined);
  slow_conv_dilated_all_cpu_template<2>(
      undefined,
      input_,
      weight_,
      undefined,
      grad_output_,
      grad_input_,
      grad_weight,
      grad_bias,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size,
      use_channels_last);
  return std::tie(grad_input, grad_weight, grad_bias);
}

static std::tuple<Tensor, Tensor, Tensor> slow_conv_dilated3d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size,
    const std::array<bool, 3ul> output_mask) {
  Tensor undefined;
  internal::slow_conv_dilated_shape_check<3>(
      input,
      weight,
      undefined,
      grad_output,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  auto is_batch = input.dim() == 5;
  auto options = grad_output.options();
  // template function assumes batched tensors.  unsqueeze(0) will
  // insert batch dimension without affecting the original tensor.
  const Tensor grad_output_ =
      (is_batch ? grad_output.contiguous()
                : grad_output.contiguous().unsqueeze(0));
  const Tensor input_ =
      (is_batch ? input.contiguous() : input.contiguous().unsqueeze(0));
  const Tensor weight_ = weight.contiguous();
  // compute only gradients for which the corresponding output_mask is true:
  Tensor grad_input =
      (output_mask[0] ? at::empty(input.sizes(), options) : undefined);
  Tensor grad_weight =
      (output_mask[1] ? at::empty(weight.sizes(), options) : undefined);
  Tensor grad_bias =
      (output_mask[2] ? at::empty(weight.size(0), options) : undefined);
  Tensor grad_input_ =
      (output_mask[0] ? (is_batch ? grad_input : grad_input.unsqueeze(0))
                      : undefined);
  slow_conv_dilated_all_cpu_template<3>(
      undefined,
      input_,
      weight_,
      undefined,
      grad_output_,
      grad_input_,
      grad_weight,
      grad_bias,
      kernel_size,
      stride_size,
      pad_size,
      dilation_size);
  return std::tie(grad_input, grad_weight, grad_bias);
}

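// Register the backward kernels for every CPU capability variant of the
// corresponding dispatch stubs.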
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated2d_backward_stub, &slow_conv_dilated2d_backward_cpu);
REGISTER_ALL_CPU_DISPATCH(slow_conv_dilated3d_backward_stub, &slow_conv_dilated3d_backward_cpu);

} // namespace at::native