1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Config.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorOperators.h>
6 #include <ATen/native/ConvolutionMM3d.h>
7 #include <ATen/native/ConvUtils.h>
8 #include <ATen/native/Pool.h>
9 #include <ATen/native/cpu/DepthwiseConvKernel.h>
10 #include <ATen/native/utils/ParamUtils.h>
11 #include <ATen/native/xnnpack/Engine.h>
12 #include <c10/core/GradMode.h>
13 #include <c10/util/accumulate.h>
14 #include <c10/util/irange.h>
15 #include <c10/macros/Macros.h>
16 #include <limits>
17 #include <utility>
18 
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #else
22 #include <ATen/ops/permute.h>
23 #endif
24 
25 #if AT_NNPACK_ENABLED()
26 #include <nnpack.h>
27 #endif
28 
29 #if AT_MKLDNN_ENABLED()
30 #include <ATen/native/mkldnn/Utils.h>
31 #endif
32 
33 #ifndef AT_PER_OPERATOR_HEADERS
34 #include <ATen/Functions.h>
35 #include <ATen/NativeFunctions.h>
36 #else
37 #include <ATen/ops/_conv_depthwise2d.h>
38 #include <ATen/ops/_convolution.h>
39 #include <ATen/ops/_convolution_double_backward_native.h>
40 #include <ATen/ops/_convolution_mode.h>
41 #include <ATen/ops/_convolution_mode_native.h>
42 #include <ATen/ops/_convolution_native.h>
43 #include <ATen/ops/_mps_convolution.h>
44 #include <ATen/ops/_mps_convolution_transpose.h>
45 #include <ATen/ops/_nnpack_available.h>
46 #include <ATen/ops/_nnpack_spatial_convolution.h>
47 #include <ATen/ops/_slow_conv2d_backward.h>
48 #include <ATen/ops/_unsafe_view.h>
49 #include <ATen/ops/cat.h>
50 #include <ATen/ops/constant_pad_nd.h>
51 #include <ATen/ops/conv1d_native.h>
52 #include <ATen/ops/conv2d_native.h>
53 #include <ATen/ops/conv3d_native.h>
54 #include <ATen/ops/conv_depthwise3d.h>
55 #include <ATen/ops/conv_transpose1d_native.h>
56 #include <ATen/ops/conv_transpose2d_native.h>
57 #include <ATen/ops/conv_transpose3d_native.h>
58 #include <ATen/ops/convolution.h>
59 #include <ATen/ops/convolution_backward_native.h>
60 #include <ATen/ops/convolution_backward_overrideable.h>
61 #include <ATen/ops/convolution_backward_overrideable_native.h>
62 #include <ATen/ops/convolution_native.h>
63 #include <ATen/ops/convolution_overrideable.h>
64 #include <ATen/ops/convolution_overrideable_native.h>
65 #include <ATen/ops/cudnn_convolution.h>
66 #include <ATen/ops/cudnn_convolution_transpose.h>
67 #include <ATen/ops/empty.h>
68 #include <ATen/ops/empty_like.h>
69 #include <ATen/ops/empty_native.h>
70 #include <ATen/ops/miopen_convolution.h>
71 #include <ATen/ops/miopen_convolution_transpose.h>
72 #include <ATen/ops/miopen_depthwise_convolution.h>
73 #include <ATen/ops/mkldnn_convolution.h>
74 #include <ATen/ops/mps_convolution_backward.h>
75 #include <ATen/ops/mps_convolution_transpose_backward.h>
76 #include <ATen/ops/slow_conv3d.h>
77 #include <ATen/ops/slow_conv_dilated2d.h>
78 #include <ATen/ops/slow_conv_dilated3d.h>
79 #include <ATen/ops/slow_conv_transpose2d.h>
80 #include <ATen/ops/slow_conv_transpose3d.h>
81 #include <ATen/ops/thnn_conv2d.h>
82 #include <ATen/ops/view_as_real.h>
83 #include <ATen/ops/zeros.h>
84 #include <ATen/ops/zeros_like.h>
85 #endif
86 
87 constexpr int MIOPEN_DIM_MAX = 5;
88 
89 namespace at::native {
90 
91 
92 static bool conv_benchmark_empty_cache = true;
93 
94 // Check workload to activate fast depthwise FP16 cudnn conv kernels
95 template <typename T>
96 bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
97   auto w = at::symint::size<T>(input, 3);  // same as h
98   auto ch = at::symint::size<T>(input, 1);
99   auto bs = at::symint::size<T>(input, 0);
100   if (stride==1) {
101     if (w >= 7) {
102       // All batch sizes and nb_channels
103       if (w >= 112) {
104         return true;
105       }
106 
107       // large nb_channels
108       if (ch >= 1024) {
109         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
110         if (w >= 56) {
111           return true;
112         } else if (bs >= 32) {
113           return true;
114         }
115       }
116 
117       // batch_size specific
118       if (bs >= 128) {
119         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
120         if (ch >= 512) {
121           return true;
122         } else if (ch >= 64) {
123           if (w >= 14) {
124             return true;
125           }
126         } else if ((ch >= 32) && (w >=28)) {
127           return true;
128         }
129       } else if (bs >= 64) {
130         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
131         if ((ch >= 256) && (w >= 14)) {
132           return true;
133         } else if ((ch >= 32) && (w >= 28)) {
134           return true;
135         }
136       } else if (bs >= 32) {
137         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
138         if ((ch >= 256) && (w >= 14)) {
139           return true;
140         } else if ((ch >= 128) && (w >= 28)) {
141           return true;
142         } else if ((ch >= 32) && (w >= 56)) {
143           return true;
144         }
145       } else if (bs >= 16) {
146         if ((ch >= 1024) && (w >= 14)) {
147           return true;
148         }
149         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
150         if ((ch >= 256) && (w >= 28)) {
151           return true;
152         } else if ((ch >= 32) && (w >= 56)) {
153           return true;
154         }
155       } else if (bs >= 8) {
156         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
157         if ((ch >= 512) && (w >= 28)) {
158           return true;
159         } else if ((ch >= 64) && (w >= 56)) {
160           return true;
161         }
162       }
163     }
164   } else if (stride==2) {
165     if (ch < 256) {
166       return false;
167     }
168 
169     if (w >= 7) {
170       if (bs >= 128) {
171         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
172         if (ch >= 1024) {
173           return true;
174         } else if ((ch >= 512) && (w >= 14)) {
175           return true;
176         } else if (w >= 28) {
177           return true;
178         }
179       } else if (bs >= 64) {
180         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
181         if ((ch >= 512) && (w >= 14)) {
182           return true;
183         } else if (w >= 28) {
184           return true;
185         }
186       } else if (bs >= 32) {
187         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
188         if ((ch >= 1024) && (w >= 14)) {
189           return true;
190         } else if (w >= 28) {
191           return true;
192         }
193       } else if (bs >= 16) {
194         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
195         if ((ch >= 512) && (w >= 28)) {
196           return true;
197         } else if (w >= 56) {
198           return true;
199         }
200       } else if (bs >= 8) {
201         // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
202         if ((ch >= 1024) && (w >= 28)) {
203           return true;
204         } else if (w >= 56) {
205           return true;
206         }
207       } else if (bs >= 1) {
208         if ((ch >= 512) && (w >=112)) {
209           return true;
210         }
211       }
212     }
213   }
214   return false;
215 }
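// A few representative NCHW workloads and how the heuristic above resolves them
// (hedged illustration read directly off the branches, not an exhaustive table):
//
//   // input [32, 1024, 56, 56], stride 1 -> true  (ch >= 1024 and w >= 56)
//   // input [ 4,  128,  7,  7], stride 1 -> false (no batch-size branch matches)
//   // input [ 8,  256,  7,  7], stride 2 -> false (bs >= 8 needs ch >= 1024 && w >= 28, or w >= 56)
//   bool fast = check_cudnn_depthwise_workload<int64_t>(input, /*stride=*/1);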
216 
217 // simplified version for cudnn 8.2 and above
218 template <typename T>
219 bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, T stride, const at::Tensor& weight) {
220   // 1D conv
221   if(at::symint::size<T>(input, 2) == 1 && stride == 1){
222     return true;
223   }
224 
225   // 2d conv
226   // only square filters
227   if (at::symint::size<T>(weight, 2) != at::symint::size<T>(weight, 3)) return false;
228   auto filter = at::symint::size<T>(weight, 3);
229   // only 1/3/5 filter
230   if (filter != 1 && filter != 3 && filter != 5) return false;
231   // we don't enforce square input but only check width to reduce heuristic space
232   if (at::symint::size<T>(input, 3) < 7) return false; // min width 7
233   auto w = at::symint::size<T>(input, 3);
234   // only 1/2 stride, use cudnn for all stride 1
235   if (stride == 1) return true;
236   if (stride != 2) return false;
237 
238   auto ch = at::symint::size<T>(input, 1);
239   auto bs = at::symint::size<T>(input, 0);
240   // special case since bs == 1 shows good perf in lots of cases
241   if (bs == 1) {
242     if (filter == 1 && w <= 28) return true;
243     if (filter == 3 || filter == 5) return true;
244   } else {
245     if (filter == 1 && bs <= 16 && ch >= 128 && w <= 7) return true;
246     if (filter == 3 || filter == 5) {
247       if ((ch >= 512) || (ch >= 256 && w >= 28)) return true;
248     }
249   }
250   return false;
251 }
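// Similarly hedged examples for the cudnn >= 8.2 variant above (values traced
// through the branches, illustration only):
//
//   // input [1, 64, 56, 56], 3x3 depthwise filter, stride 2 -> true  (bs == 1 and filter == 3)
//   // input [1, 64, 56, 56], 1x1 depthwise filter, stride 2 -> false (bs == 1, filter == 1 but w > 28)
//   // any input with a square 1/3/5 filter, stride 1, width >= 7 -> true (cudnn used for all stride-1 cases)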
252 
253 
254 #if defined(C10_MOBILE)
255 static bool xnnpack_use_convolution2d(
256     const Tensor& input,
257     const Tensor& weight,
258     const at::OptionalIntArrayRef bias_sizes_opt,
259     const IntArrayRef padding,
260     const IntArrayRef stride,
261     const IntArrayRef dilation,
262     const int64_t groups,
263     const bool transposed) {
264   return xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed);
265 }
266 
267 static bool xnnpack_use_convolution2d(
268     const Tensor& input,
269     const Tensor& weight,
270     const at::OptionalSymIntArrayRef bias_sizes_opt,
271     const SymIntArrayRef padding,
272     const SymIntArrayRef stride,
273     const SymIntArrayRef dilation,
274     const c10::SymInt groups,
275     const bool transposed) {
276   // Never use xnnpack for symbolic tracing
277   return false;
278 }
279 #endif
280 
281 // This struct is templated so that we can run backend selection in a dynamic
282 // shapes context; all of the real kernel selection in eager mode runs with
283 // int64_t
284 template <typename T>
285 struct ConvParams {
286   std::vector<T> stride;
287   std::vector<T> padding;
288   std::vector<T> dilation;
289   bool transposed{};
290   std::vector<T> output_padding;
291   T groups{};
292   bool benchmark{};
293   bool deterministic{};
294   bool cudnn_enabled{};
295   bool allow_tf32{};
296 
297   bool is_strided() const {
298     bool is_strided = false;
299     for (const auto& s : stride) {
300       is_strided |= (s != 1);
301     }
302     return is_strided;
303   }
304 
305   bool is_dilated() const {
306     bool is_dilated = false;
307     for (const auto& d : dilation) {
308       is_dilated |= (d != 1);
309     }
310     return is_dilated;
311   }
312 
313   bool is_padded() const {
314     bool is_padded = false;
315     for (auto p : padding) {
316       is_padded |= (p != 0);
317     }
318     return is_padded;
319   }
320 
321   bool is_output_padding_neg() const {
322     bool is_non_neg = false;
323     for (const auto& p : output_padding) {
324       is_non_neg |= (p < 0);
325     }
326     return is_non_neg;
327   }
328 
329   bool is_output_padding_big() const {
330     bool is_big = false;
331     for (auto i: c10::irange(output_padding.size())) {
332       is_big |= (output_padding[i] >= stride[i]);
333     }
334     return is_big;
335   }
336 
337   bool is_padding_neg() const {
338     bool is_non_neg = false;
339     for (const auto& p : padding) {
340       is_non_neg |= (p < 0);
341     }
342     return is_non_neg;
343   }
344 
345   bool is_dilation_neg() const {
346     bool is_non_neg = false;
347     for (const auto& p : dilation) {
348       is_non_neg |= (p < 0);
349     }
350     return is_non_neg;
351   }
352 
353   bool is_stride_nonpos() const {
354     bool is_nonpos = false;
355     for (const auto& s : stride) {
356       is_nonpos |= (s <= 0);
357     }
358     return is_nonpos;
359   }
360 
361   void view1d_as_2d() {
362     if (stride.size() == 1) {
363       stride.insert(stride.begin(), 1);
364       padding.insert(padding.begin(), 0);
365       dilation.insert(dilation.begin(), 1);
366       output_padding.insert(output_padding.begin(), 0);
367     }
368   }
369 
370   bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const std::optional<at::Tensor>& bias) const {
371 #if defined(__ARM_NEON__) || (defined(__riscv_v_intrinsic) && __riscv_v_intrinsic>=12000)
372     // Currently only 3x3 depthwise convolutions on tensors of float are supported.
373     return (input.ndimension() == 4) &&
374            (at::symint::size<T>(input, 1) == groups) &&
375            (weight.ndimension() == 4 ) &&
376            (at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0) &&
377            (at::symint::size<T>(weight, 1) == 1) &&
378            (at::symint::size<T>(weight, 2) == 3) &&
379            (at::symint::size<T>(weight, 3) == 3) &&
380            (input.device().is_cpu()) &&
381            (input.scalar_type() == at::kFloat) &&
382            input.is_contiguous() &&
383            (weight.device().is_cpu()) &&
384            (weight.scalar_type() == at::kFloat) &&
385            weight.is_contiguous() &&
386            (!bias.has_value() || bias->is_contiguous()) &&
387            !is_strided() &&
388            !is_dilated() &&
389            !transposed;
390 #else
391     return false;
392 #endif
393   }
394 
395   bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const {
396     constexpr int64_t int_max = std::numeric_limits<int>::max();
397     auto numel_input = at::symint::numel<T>(input);
398     // empty input
399     if (numel_input == 0) {
400       return false;
401     }
402     // input size can not be reduced to the range of int by splitting the batch dim
403     auto n = at::symint::size<T>(input, 0);
404     if (numel_input / n > int_max) {
405       return true;
406     }
407     // output size can not be reduced to the range of int by splitting the batch dim
408     T outsize = 1;
409     if (transposed) {
410       auto o = conv_input_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, output_padding, stride, dilation, groups);
411       outsize = c10::multiply_integers(o.begin() + 1, o.end());
412     } else {
413       auto o = conv_output_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, stride, dilation);
414       outsize = c10::multiply_integers(o.begin() + 1, o.end());
415     }
416     return outsize > int_max;
417   }
418 
419   bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const {
420   // Note [Mobile check segfaults]
421   // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest
422   // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how)
423 #if !defined(C10_MOBILE)
424     if (!detail::getCUDAHooks().compiledWithCuDNN()) {
425       return false;
426     }
427     if (needs_64bit_indexing_no_split(input, weight)) {
428       static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
429       if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
430         TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
431                         " if the V8 API is not enabled or before cuDNN version 9.3+."
432                         " Consider upgrading cuDNN and/or enabling the V8 API for better efficiency.");
433         return false;
434       }
435     }
436     if (!input.is_cuda() || !cudnn_enabled) {
437       return false;
438     }
439     if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) {
440       if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) {
441         return false;
442       }
443     }
444     if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) {
445       // bypass dilation checks for channels_last convolution
446       if (deterministic && is_dilated()) {
447         // cudnn doesn't support deterministic dilated convolution fully yet
448         return false;
449       }
450       if (is_dilated()) {
451         return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big();
452       }
453     }
454     return !is_output_padding_big();
455 #else
456     return false;
457 #endif
458   }
459 
460   // Use cudnn for FP16 depthwise convolutions
461   bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
462     if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) {
463       // always use cudnn_depthwise for channels_last format
464       return true;
465     }
466     if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
467       long cudnn_version = detail::getCUDAHooks().versionCuDNN();
468       if (cudnn_version >= 8200) {
469         bool kernel_cond =  (use_cudnn(input, weight) &&
470                              input.scalar_type() == kHalf && // only for FP16
471                              weight.scalar_type() == kHalf &&
472                              is_depthwise(input, weight) &&
473                              input.ndimension() == 4 &&   // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
474                              !is_dilated() && // no dilation supported
475                              (stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
476                              at::symint::size<T>(input, 1) >= 32); // min 32 channels supported
477         if (kernel_cond) {
478           return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
479         }
480       }
481       // keep (7600 <= cudnn < 8200) code unchanged
482       bool kernel_cond =  (cudnn_version >= 7600 &&
483                            use_cudnn(input, weight) &&
484                            input.scalar_type() == kHalf && // only for FP16
485                            weight.scalar_type() == kHalf &&
486                            is_depthwise(input, weight) &&
487                            input.ndimension() == 4 &&   // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
488                            at::symint::size<T>(weight, 2) == at::symint::size<T>(weight, 3) && // only square kernels
489                            at::symint::size<T>(input, 2) >= 7 && // min width/height 7
490                            !is_dilated() && // no dilation supported
491                            stride[0] == stride[1] && // equal strides
492                            ((at::symint::size<T>(weight, 3) == 3) || (at::symint::size<T>(weight, 3) == 1)) &&
493                            at::symint::size<T>(input, 1) >= 32); // min 32 channels supported
494       if (kernel_cond) {
495         return check_cudnn_depthwise_workload<T>(input, stride[0]);
496       } else {
497         return false;
498       }
499     } else {
500       return false;
501     }
502   }
503 
504   bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const {
505     if (needs_64bit_indexing_no_split(input, weight)) {
506       return false;
507     }
508     return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
509            && cudnn_enabled
510            && input.is_cuda()
511            && detail::getCUDAHooks().compiledWithMIOpen()
512            && input.dim() <= MIOPEN_DIM_MAX
513            && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
514            ;
515   }
516   bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const {
517 #if AT_MKLDNN_ENABLED()
518     if (!at::globalContext().userEnabledMkldnn()) {
519       return false;
520     }
521     if (transposed && is_output_padding_big()) {
522       return false;
523     }
524     if (input.device().is_cpu() &&
525         ((input.scalar_type() == at::kBFloat16 && mkldnn_bf16_device_check()) ||
526          (input.scalar_type() == at::kHalf && mkldnn_fp16_device_check()))) {
527       return true;
528     }
529     return (input.is_mkldnn()) || // input is mkldnn Tensor
530       (input.device().is_cpu() &&
531        input.scalar_type() == kFloat && // only on CPU Float Tensors
532        // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded,
533        // but THNN is faster when single-threaded.
534        (is_strided() || is_dilated() || at::symint::size<T>(input, 0) >= 16 ||
535         at::symint::size<T>(weight, -1) != 1 || at::symint::size<T>(weight, -2) != 1 || at::get_num_threads() > 1) &&
536        (groups > 1
537         || (at::symint::size<T>(weight, -1) > 3 && at::symint::size<T>(weight, -2) > 3)
538         || at::symint::size<T>(input, 0) > 1
539         || at::symint::size<T>(input, 0)*at::symint::size<T>(input, 1)*at::symint::size<T>(input, 2)*at::symint::size<T>(input, 3) > 20480) // for some case, native is faster
540         );
541 
542 #endif
543     return false;
544   }
545   bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const {
546 #if AT_NNPACK_ENABLED()
547     return at::globalContext().userEnabledNNPACK() &&
548            at::_nnpack_available() &&
549            input.device().is_cpu() &&
550            input.scalar_type() == kFloat && // only on CPU Float Tensors
551            !is_dilated() && // or dilation
552            !transposed &&   // or transposed tensors
553            input.ndimension() == 4 && // must be in NCHW format
554            weight.ndimension() == 4 &&
555            (at::symint::size<T>(weight, 2) < 17) && (at::symint::size<T>(weight, 3) < 17) && // NNPACK only supports kernels up to 16x16
556            (padding[0] < at::symint::size<T>(weight, 2)) && (padding[1] < at::symint::size<T>(weight, 3)) // NNPACK only supports padding < kernel_size. See https://github.com/pytorch/pytorch/issues/90142.
557 #if !defined(C10_MOBILE)
558            && at::symint::size<T>(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable
559 #endif
560        ;
561 #endif
562     return false;
563   }
564   bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight,
565                    const at::OptionalArrayRef<T> bias_sizes_opt) const {
566 #if defined(C10_MOBILE)
567     if (!transposed) {
568       // NB: for the call here, it MATTERS that we are templated. If you
569       // untemplate this to always use SymInt, the function
570       // xnnpack_use_convolution2d will always return false
571       return (at::symint::size<T>(input, 1) == groups) &&
572               xnnpack_use_convolution2d(
573                   input,
574                   weight,
575                   bias_sizes_opt,
576                   padding,
577                   stride,
578                   dilation,
579                   groups,
580                   transposed);
581     }
582 #endif
583     return false;
584   }
585 
586   bool use_mps(const at::Tensor& input, const at::Tensor& weight) const {
587     // These checks need to be expanded. Currently we have very limited set of
588     // checks for MPS.
589 #ifdef USE_MPS
590     if (needs_64bit_indexing_no_split(input, weight)) {
591       return false;
592     }
593     if (!input.is_mps()) {
594       return false;
595     }
596     return true;
597 #else
598     return false;
599 #endif
600   }
601 
602   // We currently only have depthwise support for the case where groups ==
603   // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
604   // a depthwise multiplier)
605   bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
606     return input.is_cuda() &&
607            !transposed &&
608            (input.ndimension() == 4 || input.ndimension() == 5) &&
609            at::symint::size<T>(input, 1) == groups &&
610            groups > 1 && // no point if there is only a single group
611            at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0; // output channels must be a multiple of input channels
612   }
613 };
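// Hedged sketch of how a ConvParams<int64_t> might be filled in and queried in
// eager mode (field values are made up for illustration; the real population
// happens in the _convolution entry points further down in this file):
//
//   ConvParams<int64_t> params;
//   params.stride = {1, 1};
//   params.padding = {1, 1};
//   params.dilation = {1, 1};
//   params.transposed = false;
//   params.output_padding = {0, 0};
//   params.groups = 1;
//   params.benchmark = at::globalContext().benchmarkCuDNN();
//   params.deterministic = at::globalContext().deterministicCuDNN();
//   params.cudnn_enabled = at::globalContext().userEnabledCuDNN();
//   params.allow_tf32 = at::globalContext().allowTF32CuDNN();
//   if (params.use_cudnn(input, weight)) { /* dispatch to the cuDNN backend */ }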
614 
615 DEFINE_DISPATCH(conv_depthwise2d_backward_stub);
616 DEFINE_DISPATCH(conv_depthwise3d_backward_stub);
617 DEFINE_DISPATCH(cudnn_convolution_backward_stub);
618 DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub);
619 DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub);
620 DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub);
621 DEFINE_DISPATCH(miopen_convolution_backward_stub);
622 DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub);
623 DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub);
624 DEFINE_DISPATCH(mkldnn_convolution_backward_stub);
625 DEFINE_DISPATCH(mkldnn_convolution_transpose_stub);
626 DEFINE_DISPATCH(mkldnn_convolution_transpose_backward_stub);
627 DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub);
628 DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub);
629 DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub);
630 REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub);
631 REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub);
632 REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub);
633 REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub);
634 REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub);
635 REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub);
636 REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub);
637 
638 template <typename T>
639 std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params) {
640   out << "ConvParams {"
641       << "  stride = " << IntArrayRef{params.stride}
642       << "  padding = " << ArrayRef<T>{params.padding}
643       << "  dilation = " << IntArrayRef{params.dilation}
644       << "  transposed = " << params.transposed
645       << "  output_padding = " << ArrayRef<T>{params.output_padding}
646       << "  groups = " << params.groups
647       << "  benchmark = " << params.benchmark
648       << "  deterministic = " << params.deterministic
649       << "  cudnn_enabled = " << params.cudnn_enabled
650       << "  allow_tf32 = " << params.allow_tf32
651       << "}";
652   return out;
653 }
654 
655 template <typename T>
656 static void check_shape_forward(const at::Tensor& input,
657                                 const c10::ArrayRef<T>& weight_sizes, const at::Tensor& bias,
658                                 const ConvParams<T>& params) {
659   int64_t k = input.ndimension();
660   int64_t weight_dim = weight_sizes.size();
661   auto groups = params.groups;
662   const auto& padding = params.padding;
663   const auto& dilation = params.dilation;
664   bool transposed = params.transposed;
665 
666   TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
667   TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
668   TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
669   TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
670 
671   TORCH_CHECK(weight_dim == k,
672            "Expected ", weight_dim, "-dimensional input for ", weight_dim,
673            "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ",
674            at::symint::sizes<T>(input), " instead");
675   TORCH_CHECK(weight_sizes[0] >= groups,
676            "Given groups=", groups, ", expected weight to be at least ", groups,
677            " at dimension 0, but got weight of size ", weight_sizes, " instead");
678   TORCH_CHECK(weight_sizes[0] % groups == 0,
679            "Given groups=", groups, ", expected weight to be divisible by ",
680            groups, " at dimension 0, but got weight of size [", weight_sizes,
681            "] instead");
682 
683   if (!transposed) {
684     std::vector<T> input_shape;
685     std::vector<T> kernel_shape;
686     bool kernel_size_correct = true;
687 
688     TORCH_CHECK(at::symint::size<T>(input, 1) == (weight_sizes[1] * groups),
689                 "Given groups=", groups, ", weight of size ", weight_sizes,
690                 ", expected input", at::symint::sizes<T>(input), " to have ",
691                 (weight_sizes[1] * groups), " channels, but got ", at::symint::size<T>(input, 1),
692                 " channels instead");
693 
694     TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[0]),
695              "Given weight of size ", weight_sizes,
696              ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
697              ", but got bias of size ", at::symint::sizes<T>(bias), " instead");
698 
699     for (const auto i : c10::irange(2, k)) {
700       input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i-2]);
701       // log new kernel size considering dilation
702       kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
703       if (input_shape.back() < kernel_shape.back()) {
704         kernel_size_correct = false;
705       }
706     }
707 
708     TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel");
709 
710     if (!kernel_size_correct) {
711       // If kernel size is incorrect
712       std::ostringstream input_ss;
713       std::ostringstream kernel_ss;
714       std::string separator = "";
715 
716       for (int i = 0, len = input_shape.size(); i < len; ++i) {
717         input_ss << separator << input_shape[i];
718         kernel_ss << separator << kernel_shape[i];
719         separator = " x ";
720       }
721 
722       AT_ERROR("Calculated padded input size per channel: (", input_ss.str(), "). "
723                "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
724     }
725   } else { // transposed
726     TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
727              "Given transposed=", transposed, ", weight of size ", weight_sizes,
728              ", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],
729              " channels, but got ", at::symint::size<T>(input, 1), " channels instead");
730     TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[1] * groups),
731              "Given transposed=", transposed, ", weight of size ", weight_sizes,
732              ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements",
733              ", but got bias of size ", at::symint::sizes<T>(bias), " instead");
734   }
735 }
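// Worked example of the grouped-channel check above (illustration only): with
// groups = 4, a weight of size [64, 8, 3, 3] and an input of size [N, 32, H, W],
// the forward (non-transposed) branch verifies size(input, 1) == weight_sizes[1] * groups,
// i.e. 32 == 8 * 4, and that any bias is 1-D with weight_sizes[0] == 64 elements.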
736 
737 template <typename T>
738 static void check_shape_backward(
739     const at::Tensor& input,
740     const c10::ArrayRef<T>& weight_sizes,
741     const ConvParams<T>& params) {
742   check_shape_forward<T>(input, weight_sizes, /*bias=*/ Tensor(), params);
743 }
744 
745 // Given an input tensor and an expected number of spatial dimensions, checks that the
746 // input is a valid shape and returns the batched form of the input.
747 //
748 // Args:
749 //     input (Tensor): Input tensor
750 //     num_spatial_dims (int): Number of spatial dimensions expected for the input
751 //     func_name (string): Function name to produce a nice error message for invalid input
752 //
753 // Returns a std::tuple containing:
754 //     batched_input (Tensor): Input with a batch dimension
755 //     is_batched (bool): Indicates whether the original input was already batched
756 static std::tuple<Tensor, bool> batchify(
757     const Tensor& input,
758     const int64_t num_spatial_dims,
759     const std::string& func_name) {
760   // assume NTs are always batched
761   if (input.is_nested()) {
762     return std::make_tuple(input, true);
763   }
764   const auto dim_count_no_batch = num_spatial_dims + 1;
765   const auto dim_count_batch = dim_count_no_batch + 1;
766   const auto is_batched = (input.dim() == dim_count_batch);
767   TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
768       "Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
769       "D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
770   return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
771 }
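// For example (illustration only): conv2d called with an unbatched [C, H, W]
// input returns is_batched == false, so the callers below run the convolution on
// input.unsqueeze(0) and squeeze the leading dimension off the result again.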
772 
773 static void check_input_same_type_as_parameters(
774     const Tensor& input,
775     const Tensor& weight,
776     const Tensor& bias) {
777   TORCH_CHECK(input.options().type_equal(weight.options()),
778       "Input type (", input.toString(), ") and weight type (", weight.toString(),
779       ") should be the same");
780   TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
781       "Input type (", input.toString(), ") and bias type (", bias.toString(),
782       ") should be the same");
783 }
784 
785 static void check_input_same_type_as_parameters(
786     const Tensor& input,
787     const Tensor& weight) {
788   check_input_same_type_as_parameters(input, weight, /*bias=*/ Tensor());
789 }
790 
791 #if AT_MKLDNN_ENABLED()
792 static void check_input_same_type_as_parameters(
793     const Tensor& input,
794     const Tensor& weight,
795     const Tensor& bias,
796     const ConvBackend backend) {
797   if (backend == ConvBackend::Mkldnn || backend == ConvBackend::MkldnnTranspose) {
798     TORCH_CHECK(input.options().type_equal(weight.options())
799         || (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat),
800         "Input type (", input.toString(), ") and weight type (", weight.toString(),
801         ") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");
802     TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options()))
803         || (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat),
804         "Input type (", input.toString(), ") and bias type (", bias.toString(),
805         ") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
806   } else {
807     check_input_same_type_as_parameters(input, weight, bias);
808   }
809 }
810 #endif
811 
812 static auto view4d(const at::Tensor& tensor) -> at::Tensor {
813   TORCH_CHECK(tensor.ndimension() == 3,
814            "expected 3D tensor, got tensor with ", tensor.ndimension(),
815            " dimensions instead");
816   return tensor.unsqueeze(2);
817 }
818 
819 static auto view3d(const at::Tensor& tensor) -> at::Tensor {
820   TORCH_CHECK(tensor.ndimension() == 4,
821            "expected 4D tensor, got tensor with ", tensor.ndimension(),
822            " dimensions instead");
823   return tensor.squeeze(2);
824 }
825 
826 static at::Tensor subtensor(at::Tensor& tensor, int64_t dim, int64_t groups, int64_t g) {
827   if (!tensor.defined()) {
828     return at::Tensor();
829   }
830   const auto memory_format = tensor.suggest_memory_format();
831   int64_t n = tensor.sizes()[dim] / groups;
832   return tensor.narrow(dim, n * g, n).contiguous(memory_format);
833 }
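// For example (illustration only): with groups = 2 and a weight tensor whose
// dim-0 size is 64, subtensor(weight, /*dim=*/0, /*groups=*/2, /*g=*/1) narrows
// out rows [32, 64) and returns them contiguous in the tensor's suggested
// memory format.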
834 
835 namespace {
836 
837 std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
838   auto inp_view_as_complex = at::view_as_real(inp);
839   auto dim_i = inp_view_as_complex.dim() - 1;
840   auto i_r = inp_view_as_complex.select(dim_i, 0);
841   auto i_i = inp_view_as_complex.select(dim_i, 1);
842   return std::make_pair(i_r, i_i);
843 }
844 
845 at::Tensor complex_convolution(
846     const Tensor& input,
847     const Tensor& weight,
848     const Tensor& bias,
849     SymIntArrayRef stride,
850     SymIntArrayRef padding,
851     SymIntArrayRef dilation,
852     bool transposed,
853     SymIntArrayRef output_padding,
854     const c10::SymInt& groups) {
855   check_input_same_type_as_parameters(input, weight, bias);
856   auto [i_r, i_i] = complex_to_real(input.resolve_conj());
857   auto [w_r, w_i] = complex_to_real(weight.resolve_conj());
858 
859   // [NOTE] Complex Convolution
860   // conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
861   // where W, x and b are all complex inputs.
862   // With Gauss Trick:
863   // a = conv(Wr, xr, br),
864   // b = conv(Wi, xi, 0),
865   // c = conv(Wr + Wi, xr + xi, bi + br)
866   // conv(W, x, b) = a - b + i(c - a - b)
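  // Expanding c with the bilinearity of conv gives a quick check of the identity above:
  //   c = conv(Wr, xr, br) + conv(Wr, xi, 0) + conv(Wi, xr, bi) + conv(Wi, xi, 0)
  //   a - b     = conv(Wr, xr, br) - conv(Wi, xi, 0)   -> real part
  //   c - a - b = conv(Wi, xr, bi) + conv(Wr, xi, 0)   -> imaginary part
  // which matches conv(W, x, b) for W = Wr + i*Wi, x = xr + i*xi, b = br + i*bi.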
867   Tensor a, b, c;
868   if (!bias.defined()) {
869     a = at::convolution_symint(i_r, w_r, bias, stride, padding, dilation, transposed, output_padding, groups);
870     b = at::convolution_symint(i_i, w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
871     c = at::convolution_symint(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
872   } else {
873     auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
874     a = at::convolution_symint(i_r, w_r, b_r, stride, padding, dilation, transposed, output_padding, groups);
875     b = at::convolution_symint(i_i, w_i, Tensor(), stride, padding, dilation, transposed, output_padding, groups);
876     c = at::convolution_symint(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, transposed, output_padding, groups);
877   }
878 
879   auto i = c10::Scalar(c10::complex<double>(0, 1));
880   return a - b + i * (c - a - b);
881 }
882 
883 at::Tensor complex_convolution_mode(
884     const at::Tensor& input,
885     const at::Tensor& weight,
886     const std::optional<at::Tensor>& bias_opt,
887     c10::SymIntArrayRef stride,
888     c10::string_view padding,
889     c10::SymIntArrayRef dilation,
890     const c10::SymInt& groups) {
891   auto bias = bias_opt.value_or(Tensor());
892   check_input_same_type_as_parameters(input, weight, bias);
893   auto [i_r, i_i] = complex_to_real(input.resolve_conj());
894   auto [w_r, w_i] = complex_to_real(weight.resolve_conj());
895 
896   // See [NOTE] Complex Convolution
897   Tensor a, b, c;
898   if (!bias.defined()) {
899     a = at::_convolution_mode_symint(i_r, w_r, bias, stride, padding, dilation, groups);
900     b = at::_convolution_mode_symint(i_i, w_i, bias, stride, padding, dilation, groups);
901     c = at::_convolution_mode_symint(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, groups);
902   } else {
903     auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
904     a = at::_convolution_mode_symint(i_r, w_r, b_r, stride, padding, dilation, groups);
905     b = at::_convolution_mode_symint(i_i, w_i, Tensor(), stride, padding, dilation, groups);
906     c = at::_convolution_mode_symint(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, groups);
907   }
908 
909   auto i = c10::Scalar(c10::complex<double>(0, 1));
910   return a - b + i * (c - a - b);
911 }
912 
913 } // namespace
914 
915 at::Tensor conv1d_symint(
916     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
917     SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
918   // See [Note: hacky wrapper removal for optional tensor]
919   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
920   const Tensor& bias = *bias_maybe_owned;
921 
922   TORCH_CHECK(
923     !bias.defined() || bias.dtype() == input_.dtype(),
924     "Input type (",
925     input_.dtype().name(),
926     ") and bias type (",
927     bias.dtype().name(),
928     ") should be the same");
929 
930   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
931   Tensor output;
932   if (at::isComplexType(input_.scalar_type())) {
933     output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
934   } else {
935     output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {0}, groups);
936   }
937   return is_batched ? std::move(output) : output.squeeze(0);
938 }
939 
940 at::Tensor conv2d_symint(
941     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
942     SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
943   // See [Note: hacky wrapper removal for optional tensor]
944   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
945   const Tensor& bias = *bias_maybe_owned;
946 
947   TORCH_CHECK(
948     !bias.defined() || bias.dtype() == input_.dtype(),
949     "Input type (",
950     input_.dtype().name(),
951     ") and bias type (",
952     bias.dtype().name(),
953     ") should be the same");
954 
955   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
956   Tensor output;
957   if (at::isComplexType(input_.scalar_type())) {
958     output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
959   } else {
960     output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
961   }
962   return is_batched ? std::move(output) : output.squeeze(0);
963 }
964 
965 at::Tensor conv3d_symint(
966     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
967     SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
968   // See [Note: hacky wrapper removal for optional tensor]
969   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
970   const Tensor& bias = *bias_maybe_owned;
971 
972   TORCH_CHECK(
973     !bias.defined() || bias.dtype() == input_.dtype(),
974     "Input type (",
975     input_.dtype().name(),
976     ") and bias type (",
977     bias.dtype().name(),
978     ") should be the same");
979 
980   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
981   Tensor output;
982   if (at::isComplexType(input_.scalar_type())) {
983     output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
984   } else {
985     output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
986   }
987   return is_batched ? std::move(output) : output.squeeze(0);
988 }
989 
990 
991 static Tensor convolution_same(
992     const Tensor &input, const Tensor &weight, const Tensor &bias,
993     SymIntArrayRef stride, SymIntArrayRef dilation, const c10::SymInt& groups) {
994 
995   auto k = weight.dim();
996   TORCH_CHECK(k > 2, "weight should have at least three dimensions");
997   TORCH_CHECK(groups > 0, "non-positive groups is not supported");
998   auto dim = static_cast<size_t>(k - 2);
999   auto weight_sizes = weight.sym_sizes();
1000   auto input_sizes = input.sym_sizes();
1001   TORCH_CHECK(k == input.dim(),
1002               "Expected ", k, "-dimensional input for ",
1003               k, "-dimensional weight", weight_sizes, ", but got ",
1004               input.dim(), "-dimensional input of size ",
1005               input.sizes(), " instead");
1006   TORCH_CHECK(stride.size() == dim || stride.size() == 1U,
1007               "stride cannot broadcast to ", dim, " dimensions");
1008   TORCH_CHECK(dilation.size() == dim || dilation.size() == 1U,
1009               "dilation cannot broadcast to ", dim, " dimensions");
1010   for (auto i: c10::irange(stride.size())) {
1011     TORCH_CHECK(stride[i] == 1, "padding='same' is not supported for strided convolutions");
1012   }
1013 
1014   // Calculate the correct padding
1015   SymDimVector padding_l, padding_r;
1016   bool symmetric_padding = true;
1017   for (auto i: c10::irange(dim)) {
1018     auto s = stride.size() == 1 ? stride[0] : stride[i];
1019     auto d = dilation.size() == 1 ? dilation[0] : dilation[i];
1020     auto pad = pooling_same_mode_padding_lr(
1021         input_sizes[i + 2], weight_sizes[i + 2], s, d);
1022     padding_l.push_back(pad.first);
1023     padding_r.push_back(pad.second);
1024     if (pad.first != pad.second) {
1025       symmetric_padding = false;
1026     }
1027   }
1028 
1029   if (symmetric_padding) {
1030     // All backends handle symmetric padding natively
1031     SymDimVector output_padding(static_cast<size_t>(dim));
1032     return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
1033                                false, output_padding, groups);
1034   }
1035 
1036   TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may"
1037                   " require a zero-padded copy of the input be created");
1038   SmallVector<c10::SymInt, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
1039   for (auto i: c10::irange(dim)) {
1040     // Apply padding by the difference, leaving only a symmetric padding
1041     auto delta_pad = padding_r[i] - padding_l[i];
1042     auto pad_idx = 2 * (dim - 1 - i);  // F.pad goes from last dim to first
1043     if (delta_pad > 0) {
1044       pad_nd[pad_idx + 1] = delta_pad;
1045     } else {
1046       pad_nd[pad_idx] = delta_pad;
1047       padding_l[i] = padding_r[i];
1048     }
1049   }
1050   auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
1051   SymDimVector output_padding(static_cast<size_t>(dim));
1052   return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
1053                                 dilation, false, output_padding, groups);
1054 }
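// Worked example of the asymmetric branch above (illustration only): a 1-D 'same'
// convolution with kernel size 4, stride 1 and dilation 1 needs a total padding of
// dilation * (kernel - 1) = 3 cells, which cannot be split evenly. The loop pads the
// input by the 1-cell difference via at::constant_pad_nd_symint and then hands the
// remaining symmetric padding (1 on each side) to at::convolution_symint.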
1055 
1056 Tensor _convolution_mode_symint(
1057     const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1058     SymIntArrayRef stride, c10::string_view padding, SymIntArrayRef dilation,
1059     c10::SymInt groups) {
1060   // See [Note: hacky wrapper removal for optional tensor]
1061   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1062   const Tensor& bias = *bias_maybe_owned;
1063 
1064   if (padding == "same") {
1065     return at::native::convolution_same(
1066         input, weight, bias, stride, dilation, groups);
1067   } else if (padding == "valid") {
1068     return at::convolution_symint(
1069         input, weight, bias, stride, {{0}}, dilation, false, {{0}}, groups);
1070   }
1071   TORCH_CHECK(false, "Invalid padding string: '", padding, "'");
1072 }
1073 
1074 at::Tensor conv1d_padding_symint(
1075     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
1076     c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
1077     c10::SymInt groups) {
1078   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
1079   Tensor output;
1080   if (at::isComplexType(input_.scalar_type())) {
1081     output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
1082   } else {
1083     output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
1084   }
1085   return is_batched ? std::move(output) : output.squeeze(0);
1086 }
1087 
1088 at::Tensor conv2d_padding_symint(
1089     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
1090     c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
1091     c10::SymInt groups) {
1092   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
1093   Tensor output;
1094   if (at::isComplexType(input_.scalar_type())) {
1095     output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
1096   } else {
1097     output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
1098   }
1099   return is_batched ? std::move(output) : output.squeeze(0);
1100 }
1101 
1102 at::Tensor conv3d_padding_symint(
1103     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
1104     c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
1105     c10::SymInt groups) {
1106   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
1107   Tensor output;
1108   if (at::isComplexType(input_.scalar_type())) {
1109     output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
1110   } else {
1111     output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
1112   }
1113   return is_batched ? std::move(output) : output.squeeze(0);
1114 }
1115 
1116 at::Tensor conv_transpose1d_symint(
1117     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1118     SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
1119   // See [Note: hacky wrapper removal for optional tensor]
1120   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1121   const Tensor& bias = *bias_maybe_owned;
1122 
1123   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
1124   Tensor output;
1125   if (at::isComplexType(input_.scalar_type())) {
1126     output = complex_convolution(
1127       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1128   } else {
1129     output = at::convolution_symint(
1130       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1131   }
1132   return is_batched ? std::move(output) : output.squeeze(0);
1133 }
1134 
1135 at::Tensor conv_transpose2d_symint(
1136     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1137     SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
1138   // See [Note: hacky wrapper removal for optional tensor]
1139   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1140   const Tensor& bias = *bias_maybe_owned;
1141 
1142   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
1143   Tensor output;
1144   if (at::isComplexType(input_.scalar_type())) {
1145     output = complex_convolution(
1146       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1147   } else {
1148     output = at::convolution_symint(
1149       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1150   }
1151   return is_batched ? std::move(output) : output.squeeze(0);
1152 }
1153 
1154 at::Tensor conv_transpose3d_symint(
1155     const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1156     SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
1157   // See [Note: hacky wrapper removal for optional tensor]
1158   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1159   const Tensor& bias = *bias_maybe_owned;
1160 
1161   auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
1162   Tensor output;
1163   if (at::isComplexType(input_.scalar_type())) {
1164     output = complex_convolution(
1165       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1166   } else {
1167     output = at::convolution_symint(
1168       input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1169   }
1170   return is_batched ? std::move(output) : output.squeeze(0);
1171 }
1172 
1173 at::Tensor convolution(
1174     const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1175     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
1176     bool transposed, IntArrayRef output_padding, int64_t groups) {
1177   // See [Note: hacky wrapper removal for optional tensor]
1178   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1179   const Tensor& bias = *bias_maybe_owned;
1180 
1181   auto& ctx = at::globalContext();
1182   // See Note [Enabling Deterministic Operations]
1183   bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
1184   return at::_convolution(input, weight, bias, stride, padding, dilation,
1185                           transposed, output_padding, groups,
1186                           ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
1187 }
1188 
1189 at::Tensor convolution_overrideable(
1190     const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1191     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
1192     bool transposed, IntArrayRef output_padding, int64_t groups) {
1193   TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable is not implemented. You are likely triggering this with a tensor backend other than CPU/CUDA/MKLDNN; if this is intended, please use TORCH_LIBRARY_IMPL to override this function.");
1194 }
1195 
1196 // Function to select the convolution backend based on the inputs and params.
1197 // This overload is used within the convolution internals but not exposed to python.
1198 // NB: The forward pass provides a bias tensor while the backward pass provides
1199 // a bool indicating whether the bias is defined. This is done to save memory by
1200 // avoiding saving the full bias tensor for backward.
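// Roughly, the priority order implemented below is: specialized depthwise kernels
// (cuDNN / MIOpen / CUDA depthwise), then cuDNN, MIOpen, MKLDNN, XNNPACK (inference
// only), the depthwise 3x3 winograd kernel (inference only), the CPU/CUDA "slow"
// fallback kernels, MPS, and finally the overrideable hook for out-of-tree backends.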
1201 template <typename T>
1202 ConvBackend _select_conv_backend(
1203     const Tensor& input,
1204     const Tensor& weight,
1205     const std::optional<Tensor>& bias,
1206     const at::OptionalArrayRef<T> bias_sizes_opt,
1207     const bool need_backward,
1208     const ConvParams<T>& params) {
1209 
1210   // don't send empty inputs through backends
1211   if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) {
1212     return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty;
1213   } else if (at::symint::numel<T>(input) == 0) {
1214     TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes<T>(input));
1215   }
1216 
1217   if (params.is_depthwise(input, weight)) {
1218     if (params.use_cudnn_depthwise(input, weight)) {
1219       return ConvBackend::Cudnn;
1220     } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
1221       return ConvBackend::MiopenDepthwise;
1222     } else {
1223       if (input.ndimension() == 4) {
1224         return ConvBackend::CudaDepthwise2d;
1225       } else if (input.ndimension() == 5) {
1226         return ConvBackend::CudaDepthwise3d;
1227       } else {
1228         // unsupported
1229       }
1230     }
1231   } else if (params.use_cudnn(input, weight)) {
1232     if (params.transposed) {
1233       return ConvBackend::CudnnTranspose;
1234     } else {
1235       return ConvBackend::Cudnn;
1236     }
1237   } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
1238     if (params.transposed) {
1239       return ConvBackend::MiopenTranspose;
1240     } else {
1241       return ConvBackend::Miopen;
1242     }
1243   } else if (params.use_mkldnn(input, weight)) {
1244     if (params.transposed) {
1245       return ConvBackend::MkldnnTranspose;
1246     } else {
1247       return ConvBackend::Mkldnn;
1248     }
1249   } else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) {
1250     // Using prepacked conv is preferred, but XNNPACK is still the fastest
1251     // option for NHWC.
1252     return ConvBackend::Xnnpack2d;
1253   // The 3x3 depthwise convolution implementation is inference-only.
1254   } else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) {
1255     return ConvBackend::Winograd3x3Depthwise;
1256   } else if (
1257       !params.transposed && (input.ndimension() == 5) &&
1258       (input.device().is_cpu()) &&
1259       !params.is_dilated()) {
1260     // fast path for grouped conv3d
1261     return ConvBackend::Slow3d;
1262   } else if (input.device().is_cpu() || input.is_cuda()) {
1263     // backends without support for groups
1264     if (params.transposed) {
1265       if (input.ndimension() == 4) {
1266         return ConvBackend::SlowTranspose2d;
1267       } else if (input.ndimension() == 5) {
1268         return ConvBackend::SlowTranspose3d;
1269       } else {
1270         // unsupported
1271       }
1272     } else {  /* Not transposed */
1273       if (input.ndimension() == 4) {
1274         if (params.is_dilated()) {
1275           return ConvBackend::SlowDilated2d;
1276         } else {  /* dim == 4, non-dilated */
1277           if (params.use_nnpack(input, weight)) {
1278             return ConvBackend::NnpackSpatial;
1279           } else {
1280             /* CPU implementation has specialized MM kernels
1281                for non-dilated case here */
1282             return ConvBackend::Slow2d;
1283           }
1284         }
1285       } else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) {
1286         return ConvBackend::SlowDilated3d;
1287       } else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */
1288         /* CPU implementation has specialized MM kernels
1289            for non-dilated case here */
1290         return ConvBackend::Slow3d;
1291       } else {
1292         // unsupported
1293       }
1294     }
1295   } else if (params.use_mps(input, weight)) {
1296     if (params.transposed) {
1297       return ConvBackend::MpsTranspose;
1298     } else {
1299       return ConvBackend::Mps;
1300     }
1301   } else {
1302     // Only reached when the input is on a backend with an out-of-source implementation.
1303     return ConvBackend::Overrideable;
1304   }
1305 
1306   // Error out if no suitable backend was found.
1307   AT_ERROR("unsupported ConvNd parameters");
1308 }
1309 
1310 // Selects a backend for convolution based on the inputs and params.
1311 ConvBackend select_conv_backend(
1312     const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_opt,
1313     SymIntArrayRef stride_, SymIntArrayRef padding_, SymIntArrayRef dilation_,
1314     bool transposed_, SymIntArrayRef output_padding_, c10::SymInt groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) {
1315   c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1316   const Tensor& bias = *bias_maybe_owned;
1317 
1318   auto& ctx = at::globalContext();
1319   auto k = weight_r.ndimension();
1320   int64_t dim = k - 2;
1321   ConvParams<c10::SymInt> params;
1322   params.stride = expand_param_if_needed(stride_, "stride", dim);
1323   params.padding = expand_param_if_needed(padding_, "padding", dim);
1324   params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
1325   params.transposed = transposed_;
1326   params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
1327   params.groups = std::move(groups_);
1328   params.benchmark = ctx.benchmarkCuDNN();
1329   params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
1330   params.cudnn_enabled = ctx.userEnabledCuDNN();
1331   params.allow_tf32 = ctx.allowTF32CuDNN();
1332 
1333   auto input = input_r;
1334   auto weight = weight_r;
1335   check_shape_forward(input, weight.sym_sizes(), bias, params);
1336 
1337   // Expand 1d -> 2d.
1338   // This is only done for backends that don't natively support 1d spatial input.
1339   if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
1340     // avoid accidentally going through NHWC for permuted 3d input.
1341     input = input.contiguous();
1342     params.view1d_as_2d();
1343     input = view4d(input);
1344     weight = view4d(weight);
1345   }
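  // For illustration (hypothetical shapes): a 1d input (N, C, L) is viewed as
  // (N, C, 1, L) and a weight (C_out, C_in / groups, kL) as (C_out, C_in / groups, 1, kL),
  // with a matching dummy entry prepended to the spatial params, so 1d convolutions can
  // reuse the 2d kernels.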
1346 
1347   auto bias_sizes = bias.defined() ? std::optional<SymIntArrayRef>(bias.sym_sizes()) : bias_sizes_opt;
1348   bool need_backward = GradMode::is_enabled() &&
1349       (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
1350   return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params);
1351 }
1352 
1353 // For BC reasons, have a copy that does not require bias_opt
1354 static ConvBackend select_conv_backend(
1355     const Tensor& input,
1356     const Tensor& weight,
1357     const at::OptionalIntArrayRef bias_sizes_opt,
1358     const bool need_backward,
1359     const ConvParams<int64_t>& params) {
1360   return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params);
1361 }
1362 
1363 static at::Tensor _convolution_nogroup_backend(
1364     const Tensor& input,
1365     const Tensor& weight,
1366     const Tensor& bias,
1367     const ConvBackend backend,
1368     const ConvParams<int64_t>& params) {
1369   auto kernel_size = weight.sizes().slice(2);
1370   switch(backend) {
1371     case ConvBackend::NnpackSpatial:
1372 #if AT_NNPACK_ENABLED()
1373       return at::_nnpack_spatial_convolution(input, weight, bias, params.padding, params.stride);
1374 #else
1375       TORCH_INTERNAL_ASSERT(false, "NnpackSpatial backend was selected in PyTorch compiled without nnpack support");
1376 #endif
1377     case ConvBackend::Slow2d:
1378       return at::thnn_conv2d(input, weight, kernel_size, bias, params.stride, params.padding);
1379     case ConvBackend::SlowDilated2d:
1380       return at::slow_conv_dilated2d(
1381           input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
1382     case ConvBackend::SlowDilated3d:
1383       return at::slow_conv_dilated3d(
1384           input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
1385     case ConvBackend::SlowTranspose2d:
1386       return at::slow_conv_transpose2d(
1387           input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
1388     case ConvBackend::SlowTranspose3d:
1389       return at::slow_conv_transpose3d(
1390           input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
1391     default:
1392       TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
1393   }
1394 }
1395 
1396 static inline std::vector<int64_t> calc_output_size(
1397     const Tensor& input,
1398     const Tensor& weight,
1399     const ConvParams<int64_t>& params) {
1400   std::vector<int64_t> output_size = params.transposed ?
1401     conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding,
1402         params.stride, params.dilation, params.groups) :
1403     conv_output_size(input.sizes(), weight.sizes(), params.padding, params.stride, params.dilation);
1404 
1405   // Handle empty # of channels.
1406   if (input.size(input_channels_dim) == 0) {
1407     output_size[output_channels_dim] = 0;
1408   }
1409   return output_size;
1410 }
1411 
1412 static inline at::MemoryFormat determine_backend_memory_format(
1413     const Tensor& input,
1414     const Tensor& weight,
1415     const ConvBackend backend) {
1416   at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
1417 #if !defined(C10_MOBILE)
1418   auto k = weight.ndimension();
1419   // See Note [Mobile check segfaults]
1420   switch(backend) {
1421     case ConvBackend::Cudnn:
1422     case ConvBackend::CudnnTranspose:
1423       if (detail::getCUDAHooks().compiledWithCuDNN()) {
1424         backend_memory_format = cudnn_conv_suggest_memory_format(input, weight);
1425       }
1426       break;
1427     case ConvBackend::Miopen:
1428     case ConvBackend::MiopenDepthwise:
1429     case ConvBackend::MiopenTranspose:
1430       if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
1431         TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
1432             "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
1433         backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
1434       }
1435       break;
1436     case ConvBackend::Mkldnn:
1437     case ConvBackend::MkldnnTranspose:
1438       if (mkldnn_conv_use_channels_last(input, weight)) {
1439         backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
1440       }
1441       break;
1442     case ConvBackend::Slow2d:
1443     case ConvBackend::SlowDilated2d:
1444     case ConvBackend::SlowTranspose2d:
1445       if (thnn_conv_use_channels_last(input, weight)) {
1446         backend_memory_format = at::MemoryFormat::ChannelsLast;
1447       }
1448       break;
1449     case ConvBackend::Overrideable:
1450       if (xpu_conv_use_channels_last(input, weight)) {
1451         backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
1452       }
1453       break;
1454     default:
1455       backend_memory_format = at::MemoryFormat::Contiguous;
1456   }
1457 #endif
1458   return backend_memory_format;
1459 }
1460 
1461 at::MemoryFormat _determine_backend_memory_format(
1462     const Tensor& input,
1463     const Tensor& weight,
1464     const ConvBackend backend)  {
1465   return determine_backend_memory_format(input, weight, backend);
1466 }
1467 
1468 at::Tensor _convolution(
1469     const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
1470     IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
1471     bool transposed_, IntArrayRef output_padding_, int64_t groups_,
1472     bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
1473   // See [Note: hacky wrapper removal for optional tensor]
1474   c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
1475   const Tensor& bias_r = *bias_r_maybe_owned;
1476 
1477   auto input = input_r;
1478   auto weight = weight_r;
1479   auto bias = bias_r;
1480   auto k = weight.ndimension();
1481   c10::IntArrayRef weight_sizes = weight.sizes();
1482   int64_t dim = k - 2;
1483 
1484   TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
1485   TORCH_CHECK(groups_ > 0, "non-positive groups are not supported");
1486 
1487   ConvParams<int64_t> params;
1488   params.stride = expand_param_if_needed(stride_, "stride", dim);
1489   params.padding = expand_param_if_needed(padding_, "padding", dim);
1490   params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
1491   params.transposed = transposed_;
1492   params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
1493   params.groups = groups_;
1494   params.benchmark = benchmark;
1495   params.deterministic = deterministic;
1496   params.cudnn_enabled = cudnn_enabled;
1497   params.allow_tf32 = allow_tf32;
1498 
1499   check_shape_forward(input, weight_sizes, bias, params);
1500 
1501   // Expand 1d -> 2d.
1502   // This is only done for backends that don't natively support 1d spatial input.
1503   if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
1504     // avoid accidentally going through NHWC for permuted 3d input.
1505     input = input.contiguous();
1506     params.view1d_as_2d();
1507     input = view4d(input);
1508     weight = view4d(weight);
1509   }
1510 
1511   // Select appropriate backend to use.
1512   auto bias_sizes_opt = bias.defined() ? std::optional<IntArrayRef>(bias.sizes()) : std::nullopt;
1513   bool need_backward = GradMode::is_enabled() &&
1514       (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
1515   ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params);
1516   at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
1517 
1518   // Call the backend.
1519   Tensor output;
1520   auto kernel_size = weight.sizes().slice(2);
1521   switch (backend) {
1522     case ConvBackend::CudaDepthwise2d:
1523       output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias,
1524           params.stride, params.padding, params.dilation);
1525       break;
1526     case ConvBackend::CudaDepthwise3d:
1527       output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias,
1528           params.stride, params.padding, params.dilation);
1529       break;
1530     case ConvBackend::Cudnn:
1531       check_input_same_type_as_parameters(input, weight, bias);
1532       output = at::cudnn_convolution(
1533           input.contiguous(backend_memory_format), weight, params.padding, params.stride,
1534           params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
1535       if (bias.defined()) {
1536         output.add_(reshape_bias(input.dim(), bias));
1537       }
1538       break;
1539     case ConvBackend::CudnnTranspose:
1540       check_input_same_type_as_parameters(input, weight, bias);
1541       output = at::cudnn_convolution_transpose(
1542           input.contiguous(backend_memory_format), weight, params.padding, params.output_padding,
1543           params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
1544       if (bias.defined()) {
1545         output.add_(reshape_bias(input.dim(), bias));
1546       }
1547       break;
1548     case ConvBackend::Empty:
1549     {
1550       Tensor weight_view;
1551       // Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where
1552       // view size is not compatible with input tensor's size and stride.
1553       if(weight.is_contiguous()) {
1554         weight_view = at::_unsafe_view(weight, -1);
1555       } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) {
1556         weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1);
1557       } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
1558         weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1);
1559       } else {
1560         weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1);
1561       }
1562 
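      // Presumably the product below only serves to produce an (empty) result with the
      // right dtype/device and autograd connectivity; it is then viewed to the output
      // size computed from the conv params.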
1563       output = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]);
1564       if (bias.defined()) {
1565         output.add_(bias[0]);
1566       }
1567       output = output.view(calc_output_size(input, weight, params));
1568       break;
1569     }
1570     case ConvBackend::Miopen:
1571       check_input_same_type_as_parameters(input, weight, bias);
1572       output = at::miopen_convolution(
1573           input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
1574           params.dilation, params.groups, params.benchmark, params.deterministic);
1575       break;
1576     case ConvBackend::MiopenDepthwise:
1577       output = at::miopen_depthwise_convolution(
1578           input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
1579           params.dilation, params.groups, params.benchmark, params.deterministic);
1580       break;
1581     case ConvBackend::MiopenTranspose:
1582       check_input_same_type_as_parameters(input, weight, bias);
1583       output = at::miopen_convolution_transpose(
1584           input.contiguous(backend_memory_format), weight, bias, params.padding, params.output_padding,
1585           params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
1586       break;
1587     case ConvBackend::Mkldnn:
1588 #if AT_MKLDNN_ENABLED()
1589       check_input_same_type_as_parameters(input, weight, bias, backend);
1590       if (!input.is_mkldnn()) {
1591         // need to ensure contiguous for non-mkldnn tensors
1592         input = input.contiguous(backend_memory_format);
1593         weight = weight.contiguous(backend_memory_format);
1594         bias = bias.defined() ? bias.contiguous() : bias;
1595       }
1596       output = at::mkldnn_convolution(
1597           input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
1598 #else
1599       TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
1600 #endif
1601       break;
1602     case ConvBackend::MkldnnTranspose:
1603 #if AT_MKLDNN_ENABLED()
1604       check_input_same_type_as_parameters(input, weight, bias, backend);
1605       if (!input.is_mkldnn()) {
1606         // need to ensure contiguous for non-mkldnn tensors
1607         input = input.contiguous(backend_memory_format);
1608         weight = weight.contiguous(backend_memory_format);
1609         bias = bias.defined() ? bias.contiguous() : bias;
1610       }
1611       output = mkldnn_convolution_transpose_stub(input.device().type(),
1612           input, weight, bias, params.padding, params.output_padding, params.stride, params.dilation, params.groups);
1613 #else
1614       TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
1615 #endif
1616       break;
1617     case ConvBackend::MkldnnEmpty:
1618 #if AT_MKLDNN_ENABLED()
1619       output = empty_mkldnn(
1620           calc_output_size(input, weight, params), optTypeMetaToScalarType(input.options().dtype_opt()),
1621           input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
1622 #else
1623       TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
1624 #endif
1625       break;
1626     case ConvBackend::Overrideable:
1627       output = at::convolution_overrideable(
1628           input, weight, bias, params.stride, params.padding, params.dilation, params.transposed,
1629           params.output_padding, params.groups);
1630       break;
1631     case ConvBackend::Slow3d:
1632       output = at::slow_conv3d(input, weight, kernel_size, bias, params.stride, params.padding);
1633       break;
1634     case ConvBackend::Winograd3x3Depthwise:
1635       output = convolution_depthwise3x3_winograd_stub(
1636           input.device().type(), input, weight, bias, params.stride, params.padding, params.groups);
1637       break;
1638     case ConvBackend::Xnnpack2d:
1639       output = xnnpack::convolution2d(
1640           input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
1641       break;
1642     // Handle backends that don't natively support groups > 1.
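    // These are handled by slicing the input along the channel dim and the weight along
    // the output-channel dim into `groups` pieces, running each piece through the
    // single-group kernel, and concatenating the per-group outputs along dim 1.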
1643     case ConvBackend::NnpackSpatial:
1644     case ConvBackend::Slow2d:
1645     case ConvBackend::SlowDilated2d:
1646     case ConvBackend::SlowDilated3d:
1647     case ConvBackend::SlowTranspose2d:
1648     case ConvBackend::SlowTranspose3d:
1649       input = input.contiguous(backend_memory_format);
1650       weight = weight.contiguous(backend_memory_format);
1651       if (params.groups == 1) {
1652         output = _convolution_nogroup_backend(input, weight, bias, backend, params);
1653       } else {
1654         std::vector<Tensor> outputs(params.groups);
1655         for (const auto g : c10::irange(params.groups)) {
1656           auto input_g = subtensor(input, 1, params.groups, g);
1657           auto weight_g = subtensor(weight, 0, params.groups, g);
1658           auto bias_g = subtensor(bias, 0, params.groups, g);
1659           outputs[g] = _convolution_nogroup_backend(input_g, weight_g, bias_g, backend, params);
1660         }
1661         output = at::cat(outputs, 1);
1662       }
1663       break;
1664     case ConvBackend::Mps:
1665 #ifdef USE_MPS
1666       TORCH_CHECK(input.options().type_equal(weight.options()),
1667                "Input type (", input.toString(), ") and weight type (", weight.toString(),
1668                ") should be the same");
1669       TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
1670                "Input type (", input.toString(), ") and bias type (", bias.toString(),
1671                ") should be the same");
1672 
1673       output = at::_mps_convolution(input, weight, bias.defined() ? bias.contiguous() : bias,
1674                                      params.padding, params.stride, params.dilation,
1675                                      params.groups);
1676 #else
1677       TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
1678 #endif
1679       break;
1680     case ConvBackend::MpsTranspose:
1681 #ifdef USE_MPS
1682       TORCH_CHECK(input.options().type_equal(weight.options()),
1683                "Input type (", input.toString(), ") and weight type (", weight.toString(),
1684                ") should be the same");
1685       TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
1686                "Input type (", input.toString(), ") and bias type (", bias.toString(),
1687                ") should be the same");
1688       output = at::_mps_convolution_transpose(
1689           input.contiguous(backend_memory_format), weight,
1690           params.padding, params.output_padding,
1691           params.stride, params.dilation, params.groups);
1692       if (bias.defined()) {
1693         output.add_(reshape_bias(input.dim(), bias));
1694       }
1695 #else
1696       TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
1697 #endif
1698       break;
1699   }
1700 
1701   if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
1702     output = view3d(output);
1703   }
1704 
1705   return output;
1706 }
1707 
1708 at::Tensor _convolution(
1709     const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
1710     IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
1711     bool transposed_, IntArrayRef output_padding_, int64_t groups_,
1712     bool benchmark, bool deterministic, bool cudnn_enabled)
1713 {
1714   // See [Note: hacky wrapper removal for optional tensor]
1715   c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
1716   const Tensor& bias_r = *bias_r_maybe_owned;
1717 
1718   return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
1719 }
1720 
1721 std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
1722         const Tensor& grad_output, const Tensor& input, const Tensor& weight,
1723         IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
1724         bool transposed, IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
1725   TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with a tensor backend other than CPU/CUDA/MKLDNN; if this is intended, please use TORCH_LIBRARY_IMPL to override this function.");
1726   return std::tuple<Tensor, Tensor, Tensor>(
1727           at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
1728           at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
1729           at::empty({}));
1730 }
1731 
1732 static Tensor subvariable(const Tensor& var, int64_t dim, int64_t groups, int64_t g) {
1733   int64_t n = var.sizes()[dim] / groups;
1734   auto result = var.narrow(dim, n * g, n);
1735   return result;
1736 }
1737 
1738 std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const std::optional<Tensor>& ggI_opt, const std::optional<Tensor>& ggW_r_opt, const std::optional<Tensor>& ggb_opt,
1739     const Tensor& gO_r, const Tensor& weight_r, const Tensor& input,
1740     IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
1741     bool transposed_, IntArrayRef output_padding_, int64_t groups_,
1742     std::array<bool, 3> output_mask) {
1743   // See [Note: hacky wrapper removal for optional tensor]
1744   c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt);
1745   const Tensor& ggI = *ggI_maybe_owned;
1746   const Tensor& ggW_r = c10::value_or_else(ggW_r_opt, [] {return Tensor();});
1747   const Tensor& ggb = c10::value_or_else(ggb_opt, [] {return Tensor();});
1748 
1749 
1750   auto ggW = ggW_r;
1751   auto gO = gO_r;
1752   auto weight = weight_r;
1753 
1754   int64_t dim = weight.ndimension() - 2;
1755   ConvParams<int64_t> params;
1756   params.stride = expand_param_if_needed(stride_, "stride", dim);
1757   params.padding = expand_param_if_needed(padding_, "padding", dim);
1758   params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
1759   params.transposed = transposed_;
1760   params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
1761   // TODO: hacky way of inferring the groups number for grouped Conv3D
1762   // See: https://github.com/pytorch/pytorch/pull/36355
1763   if (!params.transposed && input.dim() > 4) {
1764     // Avoid undefined behavior when num channels == 0; params are unused for that case.
1765     params.groups = (weight.size(1) > 0) ? input.size(1) / weight.size(1) : -1;
1766   } else {
1767     params.groups = groups_;
1768   }
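  // e.g. (hypothetical shapes): a grouped 3d conv with input (N, 8, D, H, W) and weight
  // (C_out, 2, kD, kH, kW) implies groups = 8 / 2 = 4.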
1769 
1770   // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
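  // This follows from convolution being linear in the input for a fixed weight and
  // linear in the weight for a fixed input (with the bias entering additively), so the
  // forward-mode derivative along (ggI, ggW, ggb) is the sum of the three terms.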
1771   Tensor ggO;
1772   if (input.numel() != 0) {
1773     if (ggI.defined()) {
1774       if (weight.is_cuda()) {
1775         weight = weight.contiguous();
1776       }
1777       ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
1778     }
1779 
1780     if (ggW.defined()) {
1781       if (ggW.is_cuda()) {
1782         ggW = ggW.contiguous();
1783       }
1784       auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
1785       if (ggO.defined()) {
1786         ggO = ggO + ggW_term;
1787       } else {
1788         ggO = ggW_term;
1789       }
1790     }
1791   }
1792 
1793   if (ggb.defined()) {
1794     // View as (1, ggb.size(0), 1, 1...)
1795 
1796     // Expand
1797     std::vector<int64_t> new_size(gO.ndimension(), 1);
1798     new_size[1] = ggb.sizes()[0];
1799     auto ggb_contiguous = ggb.contiguous();
1800     auto ggb_view = ggb_contiguous.view(new_size);
1801 
1802     // Expand
1803     auto ggb_expanded = ggb_view.expand(gO.sizes());
1804 
1805     if (ggO.defined()) {
1806       ggO = ggO + ggb_expanded;
1807     } else {
1808       ggO = ggb_expanded;
1809     }
1810   }
1811 
1812   // Compute gW = conv(ggI, gO)
1813   Tensor gW;
1814   if (ggI.defined()) {
1815 
1816     // Modified params with correct padding
1817     ConvParams<int64_t> gw_conv_params(params);
1818 
1819     // Disable groups as they are handled separately
1820     auto groups = gw_conv_params.groups;
1821     gw_conv_params.groups = 1;
1822     std::swap(gw_conv_params.dilation, gw_conv_params.stride);
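    // With the batch and channel dims transposed below, the weight gradient of a
    // convolution is itself a convolution of the input with the output gradient in
    // which the roles of stride and dilation are exchanged, hence the swap above.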
1823 
1824     // Transpose gO and ggI to accumulate over batch
1825     auto gOt = gO.transpose(0, 1);
1826     auto ggIt = ggI.transpose(0, 1);
1827 
1828     Tensor gWt;
1829     // Compute conv
1830     if (input.numel() != 0) {
1831       if (groups == 1) {
1832 
1833         if (gOt.is_cuda()) {
1834           gOt = gOt.contiguous();
1835         }
1836         // Compute conv
1837         if (params.transposed) {
1838           gw_conv_params.transposed = false;
1839           gWt = at::convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1840         } else {
1841           gWt = at::convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1842         }
1843       } else {
1844         std::vector<Tensor> gWt_list(groups);
1845         for (const auto g : c10::irange(groups)) {
1846           auto ggIt_g = subvariable(ggIt, 0, groups, g);
1847           auto gOt_g = subvariable(gOt, 0, groups, g);
1848           if (gOt_g.is_cuda()) {
1849             gOt_g = gOt_g.contiguous();
1850           }
1851 
1852           // Compute conv
1853           if (params.transposed) {
1854             gw_conv_params.transposed = false;
1855             gWt_list[g] = at::convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1856           } else {
1857             gWt_list[g] = at::convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1858           }
1859         }
1860 
1861         gWt = at::cat(gWt_list, 1);
1862       }
1863 
1864       // Transpose gW to match chan_in and chan_out
1865       gW = gWt.transpose(0, 1);
1866 
1867       // narrow gW to only relevant portion
1868       // we do it this way instead of narrowing the input itself because
1869       // the ConvForward kernels don't support asymmetric padding.
1870       auto gW_size = gW.sizes();
1871       auto w_size = weight.sizes();
1872       for (const auto i : c10::irange(2, static_cast<int64_t>(gW_size.size()))) {
1873         if (gW_size[i] > w_size[i]) {
1874             gW = gW.narrow(i, 0, w_size[i]);
1875             gW_size = gW.sizes();
1876         }
1877       }
1878     }
1879   }
1880 
1881   // Compute gI = convT(gO, ggW) if !transposed
1882   //         gI = conv(gO, ggw)  if transposed
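  // That is, the input gradient of a convolution is the transposed convolution of the
  // output gradient with the (second-order) weight, and vice versa when the forward op
  // was itself transposed; gi_conv_params below simply flips `transposed`.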
1883   Tensor gI;
1884   if (input.numel() != 0) {
1885     if (ggW.defined()) {
1886       ConvParams<int64_t> gi_conv_params(params);
1887       gi_conv_params.transposed = !params.transposed;
1888 
1889       if (params.transposed) {
1890         if (gO.is_cuda()) {
1891           gO = gO.contiguous();
1892         }
1893         gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
1894 
1895         // narrow gI to only relevant portion
1896         // we do it this way because negative output_padding is not supported
1897         // TODO: figure out if we can narrow gO and save some compute,
1898         // rather than narrowing the computed gI
1899         auto gI_size = gI.sizes();
1900         auto i_size = input.sizes();
1901         for (const auto i : c10::irange(2, static_cast<int64_t>(gI_size.size()))) {
1902           if (gI_size[i] > i_size[i]) {
1903             gI = gI.narrow(i, 0, i_size[i]);
1904             gI_size = gI.sizes();
1905           }
1906         }
1907       } else {
1908         // calculate output_padding
1909         // TODO: figure out why this needs to be computed...
1910         auto kernel_size = weight.sizes().slice(2);
1911         auto input_shape = input.sizes().slice(2);
1912         auto grad_output_shape = gO.sizes().slice(2);
1913 
1914         for (const auto i : c10::irange(kernel_size.size())) {
1915           // Check if whole input has been used or not
1916           auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i]
1917             - 2 * gi_conv_params.padding[i]
1918             + (gi_conv_params.stride[i] * (grad_output_shape[i] - 1) + 1);
1919           if (expected_input_shape != input_shape[i]) {
1920             gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape;
1921           }
1922         }
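        // Worked example (hypothetical sizes): kernel 3, stride 2, dilation 1, padding 0
        // and a grad_output spatial size of 3 give an expected input size of
        // (3 - 1) * 1 - 0 + (2 * (3 - 1) + 1) = 7; if the actual input size was 8, the
        // last input position was never covered, so output_padding becomes 8 - 7 = 1.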
1923 
1924         if (gO.is_cuda()) {
1925           gO = gO.contiguous();
1926         }
1927 
1928         gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
1929       }
1930     }
1931   }
1932 
1933   return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
1934 }
1935 
1936 static std::tuple<at::Tensor, at::Tensor, at::Tensor> _convolution_backward_nogroup_backend(
1937     const Tensor& grad_output,
1938     const Tensor& input,
1939     const Tensor& weight,
1940     const std::array<bool, 3> output_mask,
1941     const ConvBackend backend,
1942     const ConvParams<int64_t>& params) {
1943   auto kernel_size = weight.sizes().slice(2);
1944   switch(backend) {
1945     case ConvBackend::Slow2d:
1946       return at::_slow_conv2d_backward(
1947         grad_output, input, weight, kernel_size, params.stride, params.padding, output_mask);
1948     // NB: nnpack backward does not support strided convolutions; use slow impl instead
1949     case ConvBackend::NnpackSpatial:
1950     case ConvBackend::SlowDilated2d:
1951       return slow_conv_dilated2d_backward_stub(
1952         input.device().type(),
1953         grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
1954     case ConvBackend::SlowDilated3d:
1955       return slow_conv_dilated3d_backward_stub(
1956         input.device().type(),
1957         grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
1958     case ConvBackend::SlowTranspose2d:
1959       return slow_conv_transpose2d_backward_stub(
1960         input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
1961         params.output_padding, params.dilation, output_mask);
1962     case ConvBackend::SlowTranspose3d:
1963       return slow_conv_transpose3d_backward_stub(
1964         input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
1965         params.output_padding, params.dilation, output_mask);
1966     default:
1967       TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
1968   }
1969 }
1970 
1971 // Backward pass for convolution. Computes gradients for input, weight, and bias depending on the
1972 // output_mask setting. This function supports 1D, 2D, or 3D spatial convolution and currently requires
1973 // a single batch dimension to be present.
1974 //
1975 // Args:
1976 //   grad_output_: tensor of shape (N, C_out, L_out), (N, C_out, H_out, W_out), or (N, C_out, D_out, H_out, W_out)
1977 //   input_: tensor of shape (N, C_in, L_in), (N, C_in, H_in, W_in), or (N, C_in, D_in, H_in, W_in)
1978 //   weight_: tensor of shape (C_out, C_in // groups, *kernel_size); dimension of kernel_size must match the number
1979 //       of input spatial dimensions
1980 //   bias_sizes_opt: if specified, indicates that a bias was used in the forward pass and contains the shape
1981 //       of the bias. While the bias shape can be computed from other inputs, it is provided to this function for
1982 //       ease of use. The bias shape is (weight.shape[0]) for normal convolution and (weight.shape[1] * groups)
1983 //       for transposed convolution.
1984 //   stride: single value or an array with dimension matching the number of input spatial dimensions
1985 //   padding: single value or an array with dimension matching the number of input spatial dimensions
1986 //   dilation: single value or an array with dimension matching the number of input spatial dimensions
1987 //   transposed: boolean indicating whether the convolution is transposed
1988 //   output_padding: single value or dimension == number of input spatial dimensions; only supported when
1989 //       transposed is true
1990 //   groups: number of groups for grouped convolution
1991 //   output_mask: 3-dim boolean array specifying which gradients to compute in input, weight, bias order
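//
// As a rough sketch (hypothetical tensors), requesting all three gradients for a plain
// 2d convolution whose forward pass used a bias might look like:
//
//   std::vector<int64_t> bias_sizes{weight.size(0)};
//   auto [grad_input, grad_weight, grad_bias] = at::convolution_backward(
//       grad_output, input, weight, bias_sizes,
//       /*stride=*/{1, 1}, /*padding=*/{0, 0}, /*dilation=*/{1, 1},
//       /*transposed=*/false, /*output_padding=*/{0, 0}, /*groups=*/1,
//       /*output_mask=*/{true, true, true});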
1992 std::tuple<Tensor, Tensor, Tensor> convolution_backward(
1993     const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
1994     const at::OptionalIntArrayRef bias_sizes_opt,
1995     IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
1996     int64_t groups, std::array<bool, 3> output_mask) {
1997   auto grad_output = grad_output_;
1998   auto input = input_;
1999   auto weight = weight_;
2000 
2001   auto k = weight.ndimension();
2002   int64_t dim = k - 2;
2003 
2004   TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
2005 
2006   auto& ctx = at::globalContext();
2007   ConvParams<int64_t> params;
2008   params.stride = expand_param_if_needed(stride, "stride", dim);
2009   params.padding = expand_param_if_needed(padding, "padding", dim);
2010   params.dilation = expand_param_if_needed(dilation, "dilation", dim);
2011   params.transposed = transposed;
2012   params.output_padding = expand_param_if_needed(output_padding, "output_padding", dim);
2013   params.groups = groups;
2014   params.benchmark = ctx.benchmarkCuDNN();
2015   params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
2016   params.cudnn_enabled = ctx.userEnabledCuDNN();
2017   params.allow_tf32 = ctx.allowTF32CuDNN();
2018 
2019   // Validate inputs.
2020   check_shape_backward(input, weight.sizes(), params);
2021   TORCH_CHECK(input.dim() == grad_output.dim(),
2022       "Expected input and grad_output to have the same number of dimensions, but got: ",
2023       input.dim(), " and ", grad_output.dim());
2024 
2025   // output_padding is only supported for transposed convolutions
2026   if (!params.transposed) {
2027     for (auto pad : params.output_padding) {
2028       TORCH_CHECK(pad == 0, "output_padding is not supported for non-transposed convolutions; got: ",
2029         params.output_padding);
2030     }
2031   }
2032 
2033   // Expand 1d -> 2d.
2034   // This is only done for backends that don't natively support 1d spatial input.
2035   if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
2036     // avoid accidentally going through NHWC for permuted 3d input.
2037     input = input.contiguous();
2038     params.view1d_as_2d();
2039     grad_output = view4d(grad_output);
2040     input = view4d(input);
2041     weight = view4d(weight);
2042   }
2043 
2044   // Select appropriate backend to use.
2045   ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, /*need_backward=*/ true, params);
2046   at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
2047 
2048   // Call the backend.
2049   Tensor backend_grad_input, backend_grad_weight, backend_grad_bias;
2050   auto kernel_size = weight.sizes().slice(2);
2051   switch(backend) {
2052     case ConvBackend::CudaDepthwise2d:
2053     {
2054       std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2055       std::tie(backend_grad_input, backend_grad_weight) =
2056         conv_depthwise2d_backward_stub(input.device().type(), grad_output, input,
2057           weight, kernel_size, params.stride, params.padding, params.dilation, input_weight_output_mask);
2058       break;
2059     }
2060     case ConvBackend::CudaDepthwise3d:
2061       TORCH_CHECK(input.ndimension() == 5, "Expected a 5D input for ConvBackend::CudaDepthwise3d, but got ", input.ndimension(), "D");
2062       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2063         conv_depthwise3d_backward_stub(
2064           input.device().type(), grad_output, input, weight, kernel_size, params.stride,
2065           params.padding, params.dilation, output_mask);
2066       break;
2067     case ConvBackend::Cudnn:
2068     {
2069       check_input_same_type_as_parameters(input, weight);
2070       std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2071       std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_backward_stub(
2072           input.device().type(),
2073           // Only make input contiguous when it is necessary for the backwards computation
2074           output_mask[1] ? input.contiguous(backend_memory_format) : input,
2075           grad_output, weight, params.padding, params.stride,
2076           params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
2077           input_weight_output_mask);
2078       break;
2079     }
2080     case ConvBackend::Mps:
2081     {
2082 #ifdef USE_MPS
2083       check_input_same_type_as_parameters(input, weight);
2084       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2085         at::mps_convolution_backward(input, grad_output, weight, params.padding,
2086           params.stride, params.dilation, params.groups, output_mask);
2087 #else
2088       TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
2089 #endif
2090       break;
2091     }
2092     case ConvBackend::MpsTranspose:
2093     {
2094 #ifdef USE_MPS
2095       check_input_same_type_as_parameters(input, weight);
2096       std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2097       std::tie(backend_grad_input, backend_grad_weight) = at::mps_convolution_transpose_backward(
2098         // Only make input contiguous when it is necessary for the backwards computation
2099         output_mask[1] ? input.contiguous(backend_memory_format) : input,
2100         grad_output, weight, params.padding, params.output_padding,
2101         params.stride, params.dilation, params.groups, input_weight_output_mask);
2102 #else
2103       TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support");
2104 #endif
2105       break;
2106     }
2107     case ConvBackend::CudnnTranspose:
2108     {
2109       check_input_same_type_as_parameters(input, weight);
2110       std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2111       std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_transpose_backward_stub(
2112         input.device().type(),
2113         // Only make input contiguous when it is necessary for the backwards computation
2114         output_mask[1] ? input.contiguous(backend_memory_format) : input,
2115         grad_output, weight, params.padding, params.output_padding,
2116         params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
2117         input_weight_output_mask);
2118       break;
2119     }
2120     case ConvBackend::Empty:
2121       if (output_mask[0]) {
2122         backend_grad_input = at::zeros_like(input);
2123       }
2124       if (output_mask[1]) {
2125         backend_grad_weight = at::zeros_like(weight);
2126       }
2127       if (output_mask[2]) {
2128         backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
2129       }
2130       break;
2131     case ConvBackend::MkldnnEmpty:
2132 #if AT_MKLDNN_ENABLED()
2133       if (output_mask[0]) {
2134         if (input.is_mkldnn()) {
2135           backend_grad_input = empty_mkldnn(input.sizes(), optTypeMetaToScalarType(input.options().dtype_opt()),
2136               input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
2137           backend_grad_input.zero_();
2138         } else {
2139           backend_grad_input = at::zeros_like(input);
2140         }
2141       }
2142       if (output_mask[1]) {
2143         // mkldnn weight is not supported during training by the mkldnn backend
2144         backend_grad_weight = at::zeros_like(weight);
2145       }
2146       if (output_mask[2]) {
2147         // mkldnn bias is not supported during training by the mkldnn backend
2148         backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
2149       }
2150 #else
2151       TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
2152 #endif
2153       break;
2154     case ConvBackend::Miopen:
2155       check_input_same_type_as_parameters(input, weight);
2156       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2157         miopen_convolution_backward_stub(
2158           input.device().type(),
2159           input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
2160           params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
2161       break;
2162     case ConvBackend::MiopenDepthwise:
2163       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2164           miopen_depthwise_convolution_backward_stub(
2165             input.device().type(),
2166             input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
2167             params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
2168       break;
2169     case ConvBackend::MiopenTranspose:
2170       check_input_same_type_as_parameters(input, weight);
2171       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2172         miopen_convolution_transpose_backward_stub(
2173           input.device().type(),
2174           input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding,
2175           params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
2176       break;
2177     case ConvBackend::Mkldnn:
2178       TORCH_CHECK(!weight.is_mkldnn(),
2179           "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
2180       if (!input.is_mkldnn()) {
2181         input = input.contiguous(backend_memory_format);
2182         weight = weight.contiguous(backend_memory_format);
2183       }
2184       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2185         mkldnn_convolution_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
2186           params.stride, params.dilation, params.groups, output_mask);
2187       break;
2188     case ConvBackend::MkldnnTranspose:
2189       TORCH_CHECK(!weight.is_mkldnn(),
2190           "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
2191       if (!input.is_mkldnn()) {
2192         input = input.contiguous(backend_memory_format);
2193         weight = weight.contiguous(backend_memory_format);
2194       }
2195       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2196         mkldnn_convolution_transpose_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
2197         params.output_padding, params.stride, params.dilation, params.groups, output_mask);
2198       break;
2199     case ConvBackend::Overrideable:
2200       // Only reached when the input is on a backend with an out-of-source implementation.
2201       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2202         at::convolution_backward_overrideable(grad_output, input, weight, params.stride, params.padding,
2203           params.dilation, params.transposed, params.output_padding, params.groups, output_mask);
2204       break;
2205     case ConvBackend::Slow3d:
2206       // Note that no CUDA implementation of this kernel exists currently.
2207       std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2208         slow_conv3d_backward_cpu(
2209             grad_output, input, weight, kernel_size,
2210             params.stride, params.padding, output_mask);
2211       break;
2212     // Handle backends that don't natively support groups > 1.
2213     case ConvBackend::NnpackSpatial:
2214     case ConvBackend::Slow2d:
2215     case ConvBackend::SlowDilated2d:
2216     case ConvBackend::SlowDilated3d:
2217     case ConvBackend::SlowTranspose2d:
2218     case ConvBackend::SlowTranspose3d:
2219     {
2220       input = input.contiguous(backend_memory_format);
2221       weight = weight.contiguous(backend_memory_format);
2222       if (params.groups == 1) {
2223         std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2224           _convolution_backward_nogroup_backend(
2225             grad_output, input, weight, output_mask, backend, params);
2226       } else {
2227         std::vector<Tensor> backend_grad_inputs(params.groups);
2228         std::vector<Tensor> backend_grad_weights(params.groups);
2229         std::vector<Tensor> backend_grad_biases(params.groups);
2230         for (int g = 0; g < params.groups; ++g) {
2231           auto grad_output_g = subtensor(grad_output, 1, params.groups, g);
2232           auto input_g = subtensor(input, 1, params.groups, g);
2233           auto weight_g = subtensor(weight, 0, params.groups, g);
2234           std::tie(backend_grad_inputs[g], backend_grad_weights[g], backend_grad_biases[g]) =
2235             _convolution_backward_nogroup_backend(
2236               grad_output_g, input_g, weight_g, output_mask, backend, params);
2237         }
2238         if (output_mask[0]) {
2239           backend_grad_input = at::cat(backend_grad_inputs, 1);
2240         }
2241         if (output_mask[1]) {
2242           backend_grad_weight = at::cat(backend_grad_weights, 0);
2243         }
2244         if (output_mask[2]) {
2245           backend_grad_bias = at::cat(backend_grad_biases, 0);
2246         }
2247       }
2248       break;
2249     }
2250     // Backward is not supported for these backends.
2251     case ConvBackend::Winograd3x3Depthwise:
2252       TORCH_CHECK(false, "Backward is not supported for depthwise 3x3 winograd");
2253       break;
2254     case ConvBackend::Xnnpack2d:
2255       TORCH_CHECK(false, "Backward is not supported for xnnpack");
2256       break;
2257   }
2258 
2259   // Convert 2D inputs back to 1D for backends that don't natively support 1D
2260   // spatial inputs.
2261   if (output_mask[0]) {
2262     if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
2263       backend_grad_input = view3d(backend_grad_input);
2264     }
2265   }
2266   if (output_mask[1]) {
2267     if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
2268       backend_grad_weight = view3d(backend_grad_weight);
2269     }
2270   }
2271   if (output_mask[2]) {
2272     if (!backend_grad_bias.defined()) {
2273       // Calculate bias gradients outside of the backend for those that don't support it.
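      // grad_bias is simply grad_output reduced over the batch and all spatial dims;
      // `dim` counts spatial dims, so dim == 3 reduces over {0, 2, 3, 4}.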
2274       backend_grad_bias = grad_output.sum((dim == 3) ? IntArrayRef{0, 2, 3, 4} : IntArrayRef{0, 2, 3});
2275     }
2276   }
2277 
2278   return std::make_tuple(backend_grad_input, backend_grad_weight, backend_grad_bias);
2279 }
2280 
2281 void _cudnn_set_conv_benchmark_empty_cache(bool enable) {
2282   conv_benchmark_empty_cache = enable;
2283 }
2284 
2285 bool _cudnn_get_conv_benchmark_empty_cache() {
2286   return conv_benchmark_empty_cache;
2287 }
2288 
2289 
2290 
2291 } // namespace at::native
2292