1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Config.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorOperators.h>
6 #include <ATen/native/ConvolutionMM3d.h>
7 #include <ATen/native/ConvUtils.h>
8 #include <ATen/native/Pool.h>
9 #include <ATen/native/cpu/DepthwiseConvKernel.h>
10 #include <ATen/native/utils/ParamUtils.h>
11 #include <ATen/native/xnnpack/Engine.h>
12 #include <c10/core/GradMode.h>
13 #include <c10/util/accumulate.h>
14 #include <c10/util/irange.h>
15 #include <c10/macros/Macros.h>
16 #include <limits>
17 #include <utility>
18
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #else
22 #include <ATen/ops/permute.h>
23 #endif
24
25 #if AT_NNPACK_ENABLED()
26 #include <nnpack.h>
27 #endif
28
29 #if AT_MKLDNN_ENABLED()
30 #include <ATen/native/mkldnn/Utils.h>
31 #endif
32
33 #ifndef AT_PER_OPERATOR_HEADERS
34 #include <ATen/Functions.h>
35 #include <ATen/NativeFunctions.h>
36 #else
37 #include <ATen/ops/_conv_depthwise2d.h>
38 #include <ATen/ops/_convolution.h>
39 #include <ATen/ops/_convolution_double_backward_native.h>
40 #include <ATen/ops/_convolution_mode.h>
41 #include <ATen/ops/_convolution_mode_native.h>
42 #include <ATen/ops/_convolution_native.h>
43 #include <ATen/ops/_mps_convolution.h>
44 #include <ATen/ops/_mps_convolution_transpose.h>
45 #include <ATen/ops/_nnpack_available.h>
46 #include <ATen/ops/_nnpack_spatial_convolution.h>
47 #include <ATen/ops/_slow_conv2d_backward.h>
48 #include <ATen/ops/_unsafe_view.h>
49 #include <ATen/ops/cat.h>
50 #include <ATen/ops/constant_pad_nd.h>
51 #include <ATen/ops/conv1d_native.h>
52 #include <ATen/ops/conv2d_native.h>
53 #include <ATen/ops/conv3d_native.h>
54 #include <ATen/ops/conv_depthwise3d.h>
55 #include <ATen/ops/conv_transpose1d_native.h>
56 #include <ATen/ops/conv_transpose2d_native.h>
57 #include <ATen/ops/conv_transpose3d_native.h>
58 #include <ATen/ops/convolution.h>
59 #include <ATen/ops/convolution_backward_native.h>
60 #include <ATen/ops/convolution_backward_overrideable.h>
61 #include <ATen/ops/convolution_backward_overrideable_native.h>
62 #include <ATen/ops/convolution_native.h>
63 #include <ATen/ops/convolution_overrideable.h>
64 #include <ATen/ops/convolution_overrideable_native.h>
65 #include <ATen/ops/cudnn_convolution.h>
66 #include <ATen/ops/cudnn_convolution_transpose.h>
67 #include <ATen/ops/empty.h>
68 #include <ATen/ops/empty_like.h>
69 #include <ATen/ops/empty_native.h>
70 #include <ATen/ops/miopen_convolution.h>
71 #include <ATen/ops/miopen_convolution_transpose.h>
72 #include <ATen/ops/miopen_depthwise_convolution.h>
73 #include <ATen/ops/mkldnn_convolution.h>
74 #include <ATen/ops/mps_convolution_backward.h>
75 #include <ATen/ops/mps_convolution_transpose_backward.h>
76 #include <ATen/ops/slow_conv3d.h>
77 #include <ATen/ops/slow_conv_dilated2d.h>
78 #include <ATen/ops/slow_conv_dilated3d.h>
79 #include <ATen/ops/slow_conv_transpose2d.h>
80 #include <ATen/ops/slow_conv_transpose3d.h>
81 #include <ATen/ops/thnn_conv2d.h>
82 #include <ATen/ops/view_as_real.h>
83 #include <ATen/ops/zeros.h>
84 #include <ATen/ops/zeros_like.h>
85 #endif
86
87 constexpr int MIOPEN_DIM_MAX = 5;
88
89 namespace at::native {
90
91
92 static bool conv_benchmark_empty_cache = true;
93
94 // Check workload to activate fast depthwise FP16 cudnn conv kernels
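// The branches below are heuristic cut-offs over batch size, channel count and spatial width.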
95 template <typename T>
96 bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
97 auto w = at::symint::size<T>(input, 3); // same as h
98 auto ch = at::symint::size<T>(input, 1);
99 auto bs = at::symint::size<T>(input, 0);
100 if (stride==1) {
101 if (w >= 7) {
102 // All batch sizes and nb_channels
103 if (w >= 112) {
104 return true;
105 }
106
107 // large nb_channels
108 if (ch >= 1024) {
109 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
110 if (w >= 56) {
111 return true;
112 } else if (bs >= 32) {
113 return true;
114 }
115 }
116
117 // batch_size specific
118 if (bs >= 128) {
119 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
120 if (ch >= 512) {
121 return true;
122 } else if (ch >= 64) {
123 if (w >= 14) {
124 return true;
125 }
126 } else if ((ch >= 32) && (w >=28)) {
127 return true;
128 }
129 } else if (bs >= 64) {
130 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
131 if ((ch >= 256) && (w >= 14)) {
132 return true;
133 } else if ((ch >= 32) && (w >= 28)) {
134 return true;
135 }
136 } else if (bs >= 32) {
137 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
138 if ((ch >= 256) && (w >= 14)) {
139 return true;
140 } else if ((ch >= 128) && (w >= 28)) {
141 return true;
142 } else if ((ch >= 32) && (w >= 56)) {
143 return true;
144 }
145 } else if (bs >= 16) {
146 if ((ch >= 1024) && (w >= 14)) {
147 return true;
148 }
149 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
150 if ((ch >= 256) && (w >= 28)) {
151 return true;
152 } else if ((ch >= 32) && (w >= 56)) {
153 return true;
154 }
155 } else if (bs >= 8) {
156 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
157 if ((ch >= 512) && (w >= 28)) {
158 return true;
159 } else if ((ch >= 64) && (w >= 56)) {
160 return true;
161 }
162 }
163 }
164 } else if (stride==2) {
165 if (ch < 256) {
166 return false;
167 }
168
169 if (w >= 7) {
170 if (bs >= 128) {
171 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
172 if (ch >= 1024) {
173 return true;
174 } else if ((ch >= 512) && (w >= 14)) {
175 return true;
176 } else if (w >= 28) {
177 return true;
178 }
179 } else if (bs >= 64) {
180 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
181 if ((ch >= 512) && (w >= 14)) {
182 return true;
183 } else if (w >= 28) {
184 return true;
185 }
186 } else if (bs >= 32) {
187 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
188 if ((ch >= 1024) && (w >= 14)) {
189 return true;
190 } else if (w >= 28) {
191 return true;
192 }
193 } else if (bs >= 16) {
194 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
195 if ((ch >= 512) && (w >= 28)) {
196 return true;
197 } else if (w >= 56) {
198 return true;
199 }
200 } else if (bs >= 8) {
201 // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
202 if ((ch >= 1024) && (w >= 28)) {
203 return true;
204 } else if (w >= 56) {
205 return true;
206 }
207 } else if (bs >= 1) {
208 if ((ch >= 512) && (w >=112)) {
209 return true;
210 }
211 }
212 }
213 }
214 return false;
215 }
216
217 // simplified version for cudnn 8.2 and above
218 template <typename T>
219 bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, T stride, const at::Tensor& weight) {
220 // 1D conv
221 if(at::symint::size<T>(input, 2) == 1 && stride == 1){
222 return true;
223 }
224
225 // 2d conv
226 // only square filters
227 if (at::symint::size<T>(weight, 2) != at::symint::size<T>(weight, 3)) return false;
228 auto filter = at::symint::size<T>(weight, 3);
229 // only 1/3/5 filter
230 if (filter != 1 && filter != 3 && filter != 5) return false;
231 // we don't enforce square input but only check width to reduce heuristic space
232 if (at::symint::size<T>(input, 3) < 7) return false; // min width 7
233 auto w = at::symint::size<T>(input, 3);
234 // only 1/2 stride, use cudnn for all stride 1
235 if (stride == 1) return true;
236 if (stride != 2) return false;
237
238 auto ch = at::symint::size<T>(input, 1);
239 auto bs = at::symint::size<T>(input, 0);
240 // special case since bs == 1 shows good perf in lots of cases
241 if (bs == 1) {
242 if (filter == 1 && w <= 28) return true;
243 if (filter == 3 || filter == 5) return true;
244 } else {
245 if (filter == 1 && bs <= 16 && ch >= 128 && w <= 7) return true;
246 if (filter == 3 || filter == 5) {
247 if ((ch >= 512) || (ch >= 256 && w >= 28)) return true;
248 }
249 }
250 return false;
251 }
252
253
254 #if defined(C10_MOBILE)
255 static bool xnnpack_use_convolution2d(
256 const Tensor& input,
257 const Tensor& weight,
258 const at::OptionalIntArrayRef bias_sizes_opt,
259 const IntArrayRef padding,
260 const IntArrayRef stride,
261 const IntArrayRef dilation,
262 const int64_t groups,
263 const bool transposed) {
264 return xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed);
265 }
266
267 static bool xnnpack_use_convolution2d(
268 const Tensor& input,
269 const Tensor& weight,
270 const at::OptionalSymIntArrayRef bias_sizes_opt,
271 const SymIntArrayRef padding,
272 const SymIntArrayRef stride,
273 const SymIntArrayRef dilation,
274 const c10::SymInt groups,
275 const bool transposed) {
276 // Never use xnnpack for symbolic tracing
277 return false;
278 }
279 #endif
280
281 // This struct is templated so that we can run backend selection in a dynamic
282 // shapes context; all of the real kernel selection in eager mode runs with
283 // int64_t
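// (In this file, T is instantiated as int64_t for eager mode and as c10::SymInt for symbolic shapes.)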
284 template <typename T>
285 struct ConvParams {
286 std::vector<T> stride;
287 std::vector<T> padding;
288 std::vector<T> dilation;
289 bool transposed{};
290 std::vector<T> output_padding;
291 T groups{};
292 bool benchmark{};
293 bool deterministic{};
294 bool cudnn_enabled{};
295 bool allow_tf32{};
296
297 bool is_strided() const {
298 bool is_strided = false;
299 for (const auto& s : stride) {
300 is_strided |= (s != 1);
301 }
302 return is_strided;
303 }
304
305 bool is_dilated() const {
306 bool is_dilated = false;
307 for (const auto& d : dilation) {
308 is_dilated |= (d != 1);
309 }
310 return is_dilated;
311 }
312
313 bool is_padded() const {
314 bool is_padded = false;
315 for (auto p : padding) {
316 is_padded |= (p != 0);
317 }
318 return is_padded;
319 }
320
321 bool is_output_padding_neg() const {
322 bool is_non_neg = false;
323 for (const auto& p : output_padding) {
324 is_non_neg |= (p < 0);
325 }
326 return is_non_neg;
327 }
328
329 bool is_output_padding_big() const {
330 bool is_big = false;
331 for (auto i: c10::irange(output_padding.size())) {
332 is_big |= (output_padding[i] >= stride[i]);
333 }
334 return is_big;
335 }
336
337 bool is_padding_neg() const {
338 bool is_non_neg = false;
339 for (const auto& p : padding) {
340 is_non_neg |= (p < 0);
341 }
342 return is_non_neg;
343 }
344
345 bool is_dilation_neg() const {
346 bool is_non_neg = false;
347 for (const auto& p : dilation) {
348 is_non_neg |= (p < 0);
349 }
350 return is_non_neg;
351 }
352
353 bool is_stride_nonpos() const {
354 bool is_nonpos = false;
355 for (const auto& s : stride) {
356 is_nonpos |= (s <= 0);
357 }
358 return is_nonpos;
359 }
360
361 void view1d_as_2d() {
362 if (stride.size() == 1) {
363 stride.insert(stride.begin(), 1);
364 padding.insert(padding.begin(), 0);
365 dilation.insert(dilation.begin(), 1);
366 output_padding.insert(output_padding.begin(), 0);
367 }
368 }
369
370 bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const std::optional<at::Tensor>& bias) const {
371 #if defined(__ARM_NEON__) || (defined(__riscv_v_intrinsic) && __riscv_v_intrinsic>=12000)
372 // Currently only 3x3 depthwise convolutions on tensors of float are supported.
373 return (input.ndimension() == 4) &&
374 (at::symint::size<T>(input, 1) == groups) &&
375 (weight.ndimension() == 4 ) &&
376 (at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0) &&
377 (at::symint::size<T>(weight, 1) == 1) &&
378 (at::symint::size<T>(weight, 2) == 3) &&
379 (at::symint::size<T>(weight, 3) == 3) &&
380 (input.device().is_cpu()) &&
381 (input.scalar_type() == at::kFloat) &&
382 input.is_contiguous() &&
383 (weight.device().is_cpu()) &&
384 (weight.scalar_type() == at::kFloat) &&
385 weight.is_contiguous() &&
386 (!bias.has_value() || bias->is_contiguous()) &&
387 !is_strided() &&
388 !is_dilated() &&
389 !transposed;
390 #else
391 return false;
392 #endif
393 }
394
395 bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const {
396 constexpr int64_t int_max = std::numeric_limits<int>::max();
397 auto numel_input = at::symint::numel<T>(input);
398 // empty input
399 if (numel_input == 0) {
400 return false;
401 }
402 // input size can not be reduced to the range of int by splitting the batch dim
403 auto n = at::symint::size<T>(input, 0);
404 if (numel_input / n > int_max) {
405 return true;
406 }
407 // output size can not be reduced to the range of int by splitting the batch dim
408 T outsize = 1;
409 if (transposed) {
410 auto o = conv_input_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, output_padding, stride, dilation, groups);
411 outsize = c10::multiply_integers(o.begin() + 1, o.end());
412 } else {
413 auto o = conv_output_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, stride, dilation);
414 outsize = c10::multiply_integers(o.begin() + 1, o.end());
415 }
416 return outsize > int_max;
417 }
418
419 bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const {
420 // Note [Mobile check segfaults]
421 // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest
422 // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how)
423 #if !defined(C10_MOBILE)
424 if (!detail::getCUDAHooks().compiledWithCuDNN()) {
425 return false;
426 }
427 if (needs_64bit_indexing_no_split(input, weight)) {
428 static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
429 if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
430 TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
431 " if the V8 API is not enabled or before cuDNN version 9.3+."
432 " Consider upgrading cuDNN and/or enabling the V8 API for better efficiency.");
433 return false;
434 }
435 }
436 if (!input.is_cuda() || !cudnn_enabled) {
437 return false;
438 }
439 if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) {
440 if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) {
441 return false;
442 }
443 }
444 if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) {
445 // bypass dilation checks for channels_last convolution
446 if (deterministic && is_dilated()) {
447 // cudnn doesn't support deterministic dilated convolution fully yet
448 return false;
449 }
450 if (is_dilated()) {
451 return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big();
452 }
453 }
454 return !is_output_padding_big();
455 #else
456 return false;
457 #endif
458 }
459
460 // Use cudnn for FP16 depthwise convolutions
461 bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
462 if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) {
463 // always use cudnn_depthwise for channels_last format
464 return true;
465 }
466 if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
467 long cudnn_version = detail::getCUDAHooks().versionCuDNN();
468 if (cudnn_version >= 8200) {
469 bool kernel_cond = (use_cudnn(input, weight) &&
470 input.scalar_type() == kHalf && // only for FP16
471 weight.scalar_type() == kHalf &&
472 is_depthwise(input, weight) &&
473 input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
474 !is_dilated() && // no dilation supported
475 (stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) && // square or 1d
476 at::symint::size<T>(input, 1) >= 32); // min 32 channels supported
477 if (kernel_cond) {
478 return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
479 }
480 }
481 // keep (7600 <= cudnn < 8200) code unchanged
482 bool kernel_cond = (cudnn_version >= 7600 &&
483 use_cudnn(input, weight) &&
484 input.scalar_type() == kHalf && // only for FP16
485 weight.scalar_type() == kHalf &&
486 is_depthwise(input, weight) &&
487 input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
488 at::symint::size<T>(weight, 2) == at::symint::size<T>(weight, 3) && // only square kernels
489 at::symint::size<T>(input, 2) >= 7 && // min width/height 7
490 !is_dilated() && // no dilation supported
491 stride[0] == stride[1] && // equal strides
492 ((at::symint::size<T>(weight, 3) == 3) || (at::symint::size<T>(weight, 3) == 1)) &&
493 at::symint::size<T>(input, 1) >= 32); // min 32 channels supported
494 if (kernel_cond) {
495 return check_cudnn_depthwise_workload<T>(input, stride[0]);
496 } else {
497 return false;
498 }
499 } else {
500 return false;
501 }
502 }
503
504 bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const {
505 if (needs_64bit_indexing_no_split(input, weight)) {
506 return false;
507 }
508 return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
509 && cudnn_enabled
510 && input.is_cuda()
511 && detail::getCUDAHooks().compiledWithMIOpen()
512 && input.dim() <= MIOPEN_DIM_MAX
513 && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
514 ;
515 }
516 bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const {
517 #if AT_MKLDNN_ENABLED()
518 if (!at::globalContext().userEnabledMkldnn()) {
519 return false;
520 }
521 if (transposed && is_output_padding_big()) {
522 return false;
523 }
524 if (input.device().is_cpu() &&
525 ((input.scalar_type() == at::kBFloat16 && mkldnn_bf16_device_check()) ||
526 (input.scalar_type() == at::kHalf && mkldnn_fp16_device_check()))) {
527 return true;
528 }
529 return (input.is_mkldnn()) || // input is mkldnn Tensor
530 (input.device().is_cpu() &&
531 input.scalar_type() == kFloat && // only on CPU Float Tensors
532 // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded,
533 // but THNN is faster when single-threaded.
534 (is_strided() || is_dilated() || at::symint::size<T>(input, 0) >= 16 ||
535 at::symint::size<T>(weight, -1) != 1 || at::symint::size<T>(weight, -2) != 1 || at::get_num_threads() > 1) &&
536 (groups > 1
537 || (at::symint::size<T>(weight, -1) > 3 && at::symint::size<T>(weight, -2) > 3)
538 || at::symint::size<T>(input, 0) > 1
539 || at::symint::size<T>(input, 0)*at::symint::size<T>(input, 1)*at::symint::size<T>(input, 2)*at::symint::size<T>(input, 3) > 20480) // for some cases, native is faster
540 );
541
542 #endif
543 return false;
544 }
545 bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const {
546 #if AT_NNPACK_ENABLED()
547 return at::globalContext().userEnabledNNPACK() &&
548 at::_nnpack_available() &&
549 input.device().is_cpu() &&
550 input.scalar_type() == kFloat && // only on CPU Float Tensors
551 !is_dilated() && // or dilation
552 !transposed && // or transposed tensors
553 input.ndimension() == 4 && // must be in NCHW format
554 weight.ndimension() == 4 &&
555 (at::symint::size<T>(weight, 2) < 17) && (at::symint::size<T>(weight, 3) < 17) && // NNPACK only supports kernels up to 16x16
556 (padding[0] < at::symint::size<T>(weight, 2)) && (padding[1] < at::symint::size<T>(weight, 3)) // NNPACK only supports padding < kernel_size. See https://github.com/pytorch/pytorch/issues/90142.
557 #if !defined(C10_MOBILE)
558 && at::symint::size<T>(input, 0) >= 16 // require a large enough batch size for good perf; tunable
559 #endif
560 ;
561 #endif
562 return false;
563 }
564 bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight,
565 const at::OptionalArrayRef<T> bias_sizes_opt) const {
566 #if defined(C10_MOBILE)
567 if (!transposed) {
568 // NB: for the call here, it MATTERS that we are templated. If you
569 // untemplate this to always use SymInt, the function
570 // xnnpack_use_convolution2d will always return false
571 return (at::symint::size<T>(input, 1) == groups) &&
572 xnnpack_use_convolution2d(
573 input,
574 weight,
575 bias_sizes_opt,
576 padding,
577 stride,
578 dilation,
579 groups,
580 transposed);
581 }
582 #endif
583 return false;
584 }
585
586 bool use_mps(const at::Tensor& input, const at::Tensor& weight) const {
587 // These checks need to be expanded. Currently we have a very limited set of
588 // checks for MPS.
589 #ifdef USE_MPS
590 if (needs_64bit_indexing_no_split(input, weight)) {
591 return false;
592 }
593 if (!input.is_mps()) {
594 return false;
595 }
596 return true;
597 #else
598 return false;
599 #endif
600 }
601
602 // We currently only have depthwise support for the case where groups ==
603 // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
604 // a depthwise multiplier)
605 bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
606 return input.is_cuda() &&
607 !transposed &&
608 (input.ndimension() == 4 || input.ndimension() == 5) &&
609 at::symint::size<T>(input, 1) == groups &&
610 groups > 1 && // no point if there is only a single group
611 at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0; // output channels must be a multiple of input channels
612 }
613 };
614
615 DEFINE_DISPATCH(conv_depthwise2d_backward_stub);
616 DEFINE_DISPATCH(conv_depthwise3d_backward_stub);
617 DEFINE_DISPATCH(cudnn_convolution_backward_stub);
618 DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub);
619 DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub);
620 DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub);
621 DEFINE_DISPATCH(miopen_convolution_backward_stub);
622 DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub);
623 DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub);
624 DEFINE_DISPATCH(mkldnn_convolution_backward_stub);
625 DEFINE_DISPATCH(mkldnn_convolution_transpose_stub);
626 DEFINE_DISPATCH(mkldnn_convolution_transpose_backward_stub);
627 DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub);
628 DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub);
629 DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub);
630 REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub);
631 REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub);
632 REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub);
633 REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub);
634 REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub);
635 REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub);
636 REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub);
637
638 template <typename T>
639 std::ostream& operator<<(std::ostream& out, const ConvParams<T>& params) {
640 out << "ConvParams {"
641 << " stride = " << IntArrayRef{params.stride}
642 << " padding = " << ArrayRef<T>{params.padding}
643 << " dilation = " << IntArrayRef{params.dilation}
644 << " transposed = " << params.transposed
645 << " output_padding = " << ArrayRef<T>{params.output_padding}
646 << " groups = " << params.groups
647 << " benchmark = " << params.benchmark
648 << " deterministic = " << params.deterministic
649 << " cudnn_enabled = " << params.cudnn_enabled
650 << " allow_tf32 = " << params.allow_tf32
651 << "}";
652 return out;
653 }
654
655 template <typename T>
656 static void check_shape_forward(const at::Tensor& input,
657 const c10::ArrayRef<T>& weight_sizes, const at::Tensor& bias,
658 const ConvParams<T>& params) {
659 int64_t k = input.ndimension();
660 int64_t weight_dim = weight_sizes.size();
661 auto groups = params.groups;
662 const auto& padding = params.padding;
663 const auto& dilation = params.dilation;
664 bool transposed = params.transposed;
665
666 TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
667 TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
668 TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
669 TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
670
671 TORCH_CHECK(weight_dim == k,
672 "Expected ", weight_dim, "-dimensional input for ", weight_dim,
673 "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ",
674 at::symint::sizes<T>(input), " instead");
675 TORCH_CHECK(weight_sizes[0] >= groups,
676 "Given groups=", groups, ", expected weight to be at least ", groups,
677 " at dimension 0, but got weight of size ", weight_sizes, " instead");
678 TORCH_CHECK(weight_sizes[0] % groups == 0,
679 "Given groups=", groups, ", expected weight to be divisible by ",
680 groups, " at dimension 0, but got weight of size [", weight_sizes,
681 "] instead");
682
683 if (!transposed) {
684 std::vector<T> input_shape;
685 std::vector<T> kernel_shape;
686 bool kernel_size_correct = true;
687
688 TORCH_CHECK(at::symint::size<T>(input, 1) == (weight_sizes[1] * groups),
689 "Given groups=", groups, ", weight of size ", weight_sizes,
690 ", expected input", at::symint::sizes<T>(input), " to have ",
691 (weight_sizes[1] * groups), " channels, but got ", at::symint::size<T>(input, 1),
692 " channels instead");
693
694 TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[0]),
695 "Given weight of size ", weight_sizes,
696 ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
697 ", but got bias of size ", at::symint::sizes<T>(bias), " instead");
698
699 for (const auto i : c10::irange(2, k)) {
700 input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i-2]);
701 // compute the effective kernel size, accounting for dilation
702 kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
703 if (input_shape.back() < kernel_shape.back()) {
704 kernel_size_correct = false;
705 }
706 }
707
708 TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel");
709
710 if (!kernel_size_correct) {
711 // If kernel size is incorrect
712 std::ostringstream input_ss;
713 std::ostringstream kernel_ss;
714 std::string separator = "";
715
716 for (int i = 0, len = input_shape.size(); i < len; ++i) {
717 input_ss << separator << input_shape[i];
718 kernel_ss << separator << kernel_shape[i];
719 separator = " x ";
720 }
721
722 AT_ERROR("Calculated padded input size per channel: (", input_ss.str(), "). "
723 "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
724 }
725 } else { // transposed
726 TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
727 "Given transposed=", transposed, ", weight of size ", weight_sizes,
728 ", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],
729 " channels, but got ", at::symint::size<T>(input, 1), " channels instead");
730 TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[1] * groups),
731 "Given transposed=", transposed, ", weight of size ", weight_sizes,
732 ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements",
733 ", but got bias of size ", at::symint::sizes<T>(bias), " instead");
734 }
735 }
736
737 template <typename T>
738 static void check_shape_backward(
739 const at::Tensor& input,
740 const c10::ArrayRef<T>& weight_sizes,
741 const ConvParams<T>& params) {
742 check_shape_forward<T>(input, weight_sizes, /*bias=*/ Tensor(), params);
743 }
744
745 // Given an input tensor and an expected number of spatial dimensions, checks that the
746 // input is a valid shape and returns the batched form of the input.
747 //
748 // Args:
749 // input (Tensor): Input tensor
750 // num_spatial_dims (int): Number of spatial dimensions expected for the input
751 // func_name (string): Function name to produce a nice error message for invalid input
752 //
753 // Returns a std::tuple containing:
754 // batched_input (Tensor): Input with a batch dimension
755 // is_batched (bool): Indicates whether the original input was already batched
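// e.g. conv2d accepts either (C, H, W) or (N, C, H, W); an unbatched input is unsqueezed to batch size 1 here.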
756 static std::tuple<Tensor, bool> batchify(
757 const Tensor& input,
758 const int64_t num_spatial_dims,
759 const std::string& func_name) {
760 // assume NTs are always batched
761 if (input.is_nested()) {
762 return std::make_tuple(input, true);
763 }
764 const auto dim_count_no_batch = num_spatial_dims + 1;
765 const auto dim_count_batch = dim_count_no_batch + 1;
766 const auto is_batched = (input.dim() == dim_count_batch);
767 TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
768 "Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
769 "D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
770 return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
771 }
772
773 static void check_input_same_type_as_parameters(
774 const Tensor& input,
775 const Tensor& weight,
776 const Tensor& bias) {
777 TORCH_CHECK(input.options().type_equal(weight.options()),
778 "Input type (", input.toString(), ") and weight type (", weight.toString(),
779 ") should be the same");
780 TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
781 "Input type (", input.toString(), ") and bias type (", bias.toString(),
782 ") should be the same");
783 }
784
785 static void check_input_same_type_as_parameters(
786 const Tensor& input,
787 const Tensor& weight) {
788 check_input_same_type_as_parameters(input, weight, /*bias=*/ Tensor());
789 }
790
791 #if AT_MKLDNN_ENABLED()
792 static void check_input_same_type_as_parameters(
793 const Tensor& input,
794 const Tensor& weight,
795 const Tensor& bias,
796 const ConvBackend backend) {
797 if (backend == ConvBackend::Mkldnn || backend == ConvBackend::MkldnnTranspose) {
798 TORCH_CHECK(input.options().type_equal(weight.options())
799 || (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat),
800 "Input type (", input.toString(), ") and weight type (", weight.toString(),
801 ") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");
802 TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options()))
803 || (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat),
804 "Input type (", input.toString(), ") and bias type (", bias.toString(),
805 ") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
806 } else {
807 check_input_same_type_as_parameters(input, weight, bias);
808 }
809 }
810 #endif
811
812 static auto view4d(const at::Tensor& tensor) -> at::Tensor {
813 TORCH_CHECK(tensor.ndimension() == 3,
814 "expected 3D tensor, got tensor with ", tensor.ndimension(),
815 " dimensions instead");
816 return tensor.unsqueeze(2);
817 }
818
819 static auto view3d(const at::Tensor& tensor) -> at::Tensor {
820 TORCH_CHECK(tensor.ndimension() == 4,
821 "expected 4D tensor, got tensor with ", tensor.ndimension(),
822 " dimensions instead");
823 return tensor.squeeze(2);
824 }
825
826 static at::Tensor subtensor(at::Tensor& tensor, int64_t dim, int64_t groups, int64_t g) {
827 if (!tensor.defined()) {
828 return at::Tensor();
829 }
830 const auto memory_format = tensor.suggest_memory_format();
831 int64_t n = tensor.sizes()[dim] / groups;
832 return tensor.narrow(dim, n * g, n).contiguous(memory_format);
833 }
834
835 namespace {
836
837 std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
838 auto inp_view_as_complex = at::view_as_real(inp);
839 auto dim_i = inp_view_as_complex.dim() - 1;
840 auto i_r = inp_view_as_complex.select(dim_i, 0);
841 auto i_i = inp_view_as_complex.select(dim_i, 1);
842 return std::make_pair(i_r, i_i);
843 }
844
845 at::Tensor complex_convolution(
846 const Tensor& input,
847 const Tensor& weight,
848 const Tensor& bias,
849 SymIntArrayRef stride,
850 SymIntArrayRef padding,
851 SymIntArrayRef dilation,
852 bool transposed,
853 SymIntArrayRef output_padding,
854 const c10::SymInt& groups) {
855 check_input_same_type_as_parameters(input, weight, bias);
856 auto [i_r, i_i] = complex_to_real(input.resolve_conj());
857 auto [w_r, w_i] = complex_to_real(weight.resolve_conj());
858
859 // [NOTE] Complex Convolution
860 // conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
861 // where W, x and b are all complex inputs.
862 // With Gauss Trick:
863 // a = conv(Wr, xr, br),
864 // b = conv(Wi, xi, 0),
865 // c = conv(Wr + Wi, xr + xi, bi + br)
866 // conv(W, x, b) = a - b + i(c - a - b)
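// Note: a - b = conv(Wr, xr, br) - conv(Wi, xi, 0) gives the real part, and
// c - a - b = conv(Wi, xr, bi) + conv(Wr, xi, 0) gives the imaginary part, so the
// complex result is assembled from three real convolutions instead of four.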
867 Tensor a, b, c;
868 if (!bias.defined()) {
869 a = at::convolution_symint(i_r, w_r, bias, stride, padding, dilation, transposed, output_padding, groups);
870 b = at::convolution_symint(i_i, w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
871 c = at::convolution_symint(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
872 } else {
873 auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
874 a = at::convolution_symint(i_r, w_r, b_r, stride, padding, dilation, transposed, output_padding, groups);
875 b = at::convolution_symint(i_i, w_i, Tensor(), stride, padding, dilation, transposed, output_padding, groups);
876 c = at::convolution_symint(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, transposed, output_padding, groups);
877 }
878
879 auto i = c10::Scalar(c10::complex<double>(0, 1));
880 return a - b + i * (c - a - b);
881 }
882
883 at::Tensor complex_convolution_mode(
884 const at::Tensor& input,
885 const at::Tensor& weight,
886 const std::optional<at::Tensor>& bias_opt,
887 c10::SymIntArrayRef stride,
888 c10::string_view padding,
889 c10::SymIntArrayRef dilation,
890 const c10::SymInt& groups) {
891 auto bias = bias_opt.value_or(Tensor());
892 check_input_same_type_as_parameters(input, weight, bias);
893 auto [i_r, i_i] = complex_to_real(input.resolve_conj());
894 auto [w_r, w_i] = complex_to_real(weight.resolve_conj());
895
896 // See [NOTE] Complex Convolution
897 Tensor a, b, c;
898 if (!bias.defined()) {
899 a = at::_convolution_mode_symint(i_r, w_r, bias, stride, padding, dilation, groups);
900 b = at::_convolution_mode_symint(i_i, w_i, bias, stride, padding, dilation, groups);
901 c = at::_convolution_mode_symint(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, groups);
902 } else {
903 auto [b_r, b_i] = complex_to_real(bias.resolve_conj());
904 a = at::_convolution_mode_symint(i_r, w_r, b_r, stride, padding, dilation, groups);
905 b = at::_convolution_mode_symint(i_i, w_i, Tensor(), stride, padding, dilation, groups);
906 c = at::_convolution_mode_symint(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, groups);
907 }
908
909 auto i = c10::Scalar(c10::complex<double>(0, 1));
910 return a - b + i * (c - a - b);
911 }
912
913 } // namespace
914
915 at::Tensor conv1d_symint(
916 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
917 SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
918 // See [Note: hacky wrapper removal for optional tensor]
919 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
920 const Tensor& bias = *bias_maybe_owned;
921
922 TORCH_CHECK(
923 !bias.defined() || bias.dtype() == input_.dtype(),
924 "Input type (",
925 input_.dtype().name(),
926 ") and bias type (",
927 bias.dtype().name(),
928 ") should be the same");
929
930 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
931 Tensor output;
932 if (at::isComplexType(input_.scalar_type())) {
933 output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
934 } else {
935 output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {0}, groups);
936 }
937 return is_batched ? std::move(output) : output.squeeze(0);
938 }
939
940 at::Tensor conv2d_symint(
941 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
942 SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
943 // See [Note: hacky wrapper removal for optional tensor]
944 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
945 const Tensor& bias = *bias_maybe_owned;
946
947 TORCH_CHECK(
948 !bias.defined() || bias.dtype() == input_.dtype(),
949 "Input type (",
950 input_.dtype().name(),
951 ") and bias type (",
952 bias.dtype().name(),
953 ") should be the same");
954
955 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
956 Tensor output;
957 if (at::isComplexType(input_.scalar_type())) {
958 output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
959 } else {
960 output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
961 }
962 return is_batched ? std::move(output) : output.squeeze(0);
963 }
964
965 at::Tensor conv3d_symint(
966 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
967 SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef dilation, c10::SymInt groups) {
968 // See [Note: hacky wrapper removal for optional tensor]
969 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
970 const Tensor& bias = *bias_maybe_owned;
971
972 TORCH_CHECK(
973 !bias.defined() || bias.dtype() == input_.dtype(),
974 "Input type (",
975 input_.dtype().name(),
976 ") and bias type (",
977 bias.dtype().name(),
978 ") should be the same");
979
980 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
981 Tensor output;
982 if (at::isComplexType(input_.scalar_type())) {
983 output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
984 } else {
985 output = at::convolution_symint(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
986 }
987 return is_batched ? std::move(output) : output.squeeze(0);
988 }
989
990
991 static Tensor convolution_same(
992 const Tensor &input, const Tensor &weight, const Tensor &bias,
993 SymIntArrayRef stride, SymIntArrayRef dilation, const c10::SymInt& groups) {
994
995 auto k = weight.dim();
996 TORCH_CHECK(k > 2, "weight should have at least three dimensions");
997 TORCH_CHECK(groups > 0, "non-positive groups is not supported");
998 auto dim = static_cast<size_t>(k - 2);
999 auto weight_sizes = weight.sym_sizes();
1000 auto input_sizes = input.sym_sizes();
1001 TORCH_CHECK(k == input.dim(),
1002 "Expected ", k, "-dimensional input for ",
1003 k, "-dimensional weight", weight_sizes, ", but got ",
1004 input.dim(), "-dimensional input of size ",
1005 input.sizes(), " instead");
1006 TORCH_CHECK(stride.size() == dim || stride.size() == 1U,
1007 "stride cannot broadcast to ", dim, " dimensions");
1008 TORCH_CHECK(dilation.size() == dim || dilation.size() == 1U,
1009 "dilation cannot broadcast to ", dim, " dimensions");
1010 for (auto i: c10::irange(stride.size())) {
1011 TORCH_CHECK(stride[i] == 1, "padding='same' is not supported for strided convolutions");
1012 }
1013
1014 // Calculate the correct padding
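// With stride fixed to 1 above, keeping the output size equal to the input size
// requires total padding of dilation * (kernel_size - 1) per spatial dim, split
// (possibly unevenly) between the two sides.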
1015 SymDimVector padding_l, padding_r;
1016 bool symmetric_padding = true;
1017 for (auto i: c10::irange(dim)) {
1018 auto s = stride.size() == 1 ? stride[0] : stride[i];
1019 auto d = dilation.size() == 1 ? dilation[0] : dilation[i];
1020 auto pad = pooling_same_mode_padding_lr(
1021 input_sizes[i + 2], weight_sizes[i + 2], s, d);
1022 padding_l.push_back(pad.first);
1023 padding_r.push_back(pad.second);
1024 if (pad.first != pad.second) {
1025 symmetric_padding = false;
1026 }
1027 }
1028
1029 if (symmetric_padding) {
1030 // All backends handle symmetric padding natively
1031 SymDimVector output_padding(static_cast<size_t>(dim));
1032 return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
1033 false, output_padding, groups);
1034 }
1035
1036 TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may"
1037 " require a zero-padded copy of the input be created");
1038 SmallVector<c10::SymInt, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
1039 for (auto i: c10::irange(dim)) {
1040 // Apply padding by the difference, leaving only a symmetric padding
1041 auto delta_pad = padding_r[i] - padding_l[i];
1042 auto pad_idx = 2 * (dim - 1 - i); // F.pad goes from last dim to first
1043 if (delta_pad > 0) {
1044 pad_nd[pad_idx + 1] = delta_pad;
1045 } else {
1046 pad_nd[pad_idx] = delta_pad;
1047 padding_l[i] = padding_r[i];
1048 }
1049 }
1050 auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
1051 SymDimVector output_padding(static_cast<size_t>(dim));
1052 return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
1053 dilation, false, output_padding, groups);
1054 }
1055
1056 Tensor _convolution_mode_symint(
1057 const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1058 SymIntArrayRef stride, c10::string_view padding, SymIntArrayRef dilation,
1059 c10::SymInt groups) {
1060 // See [Note: hacky wrapper removal for optional tensor]
1061 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1062 const Tensor& bias = *bias_maybe_owned;
1063
1064 if (padding == "same") {
1065 return at::native::convolution_same(
1066 input, weight, bias, stride, dilation, groups);
1067 } else if (padding == "valid") {
1068 return at::convolution_symint(
1069 input, weight, bias, stride, {{0}}, dilation, false, {{0}}, groups);
1070 }
1071 TORCH_CHECK(false, "Invalid padding string: '", padding, "'");
1072 }
1073
1074 at::Tensor conv1d_padding_symint(
1075 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
1076 c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
1077 c10::SymInt groups) {
1078 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
1079 Tensor output;
1080 if (at::isComplexType(input_.scalar_type())) {
1081 output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
1082 } else {
1083 output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
1084 }
1085 return is_batched ? std::move(output) : output.squeeze(0);
1086 }
1087
1088 at::Tensor conv2d_padding_symint(
1089 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
1090 c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
1091 c10::SymInt groups) {
1092 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
1093 Tensor output;
1094 if (at::isComplexType(input_.scalar_type())) {
1095 output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
1096 } else {
1097 output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
1098 }
1099 return is_batched ? std::move(output) : output.squeeze(0);
1100 }
1101
1102 at::Tensor conv3d_padding_symint(
1103 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias,
1104 c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation,
1105 c10::SymInt groups) {
1106 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
1107 Tensor output;
1108 if (at::isComplexType(input_.scalar_type())) {
1109 output = complex_convolution_mode(input, weight, bias, stride, padding, dilation, groups);
1110 } else {
1111 output = at::_convolution_mode_symint(input, weight, bias, stride, padding, dilation, groups);
1112 }
1113 return is_batched ? std::move(output) : output.squeeze(0);
1114 }
1115
1116 at::Tensor conv_transpose1d_symint(
1117 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1118 SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
1119 // See [Note: hacky wrapper removal for optional tensor]
1120 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1121 const Tensor& bias = *bias_maybe_owned;
1122
1123 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d");
1124 Tensor output;
1125 if (at::isComplexType(input_.scalar_type())) {
1126 output = complex_convolution(
1127 input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1128 } else {
1129 output = at::convolution_symint(
1130 input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1131 }
1132 return is_batched ? std::move(output) : output.squeeze(0);
1133 }
1134
1135 at::Tensor conv_transpose2d_symint(
1136 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1137 SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
1138 // See [Note: hacky wrapper removal for optional tensor]
1139 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1140 const Tensor& bias = *bias_maybe_owned;
1141
1142 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d");
1143 Tensor output;
1144 if (at::isComplexType(input_.scalar_type())) {
1145 output = complex_convolution(
1146 input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1147 } else {
1148 output = at::convolution_symint(
1149 input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1150 }
1151 return is_batched ? std::move(output) : output.squeeze(0);
1152 }
1153
1154 at::Tensor conv_transpose3d_symint(
1155 const Tensor& input_, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1156 SymIntArrayRef stride, SymIntArrayRef padding, SymIntArrayRef output_padding, c10::SymInt groups, SymIntArrayRef dilation) {
1157 // See [Note: hacky wrapper removal for optional tensor]
1158 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1159 const Tensor& bias = *bias_maybe_owned;
1160
1161 auto [input, is_batched] = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d");
1162 Tensor output;
1163 if (at::isComplexType(input_.scalar_type())) {
1164 output = complex_convolution(
1165 input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1166 } else {
1167 output = at::convolution_symint(
1168 input, weight, bias, stride, padding, dilation, true, output_padding, groups);
1169 }
1170 return is_batched ? std::move(output) : output.squeeze(0);
1171 }
1172
1173 at::Tensor convolution(
1174 const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1175 IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
1176 bool transposed, IntArrayRef output_padding, int64_t groups) {
1177 // See [Note: hacky wrapper removal for optional tensor]
1178 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1179 const Tensor& bias = *bias_maybe_owned;
1180
1181 auto& ctx = at::globalContext();
1182 // See Note [Enabling Deterministic Operations]
1183 bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
1184 return at::_convolution(input, weight, bias, stride, padding, dilation,
1185 transposed, output_padding, groups,
1186 ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
1187 }
1188
1189 at::Tensor convolution_overrideable(
1190 const Tensor& input, const Tensor& weight, const std::optional<Tensor>& bias_opt,
1191 IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
1192 bool transposed, IntArrayRef output_padding, int64_t groups) {
1193 TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with a tensor backend other than CPU/CUDA/MKLDNN; if this is intended, please use TORCH_LIBRARY_IMPL to override this function.");
1194 }
1195
1196 // Function to select the convolution backend based on the inputs and params.
1197 // This overload is used within the convolution internals but not exposed to python.
1198 // NB: The forward pass provides a bias tensor while the backward pass provides
1199 // a bool indicating whether the bias is defined. This is done to save memory by
1200 // avoiding saving the full bias tensor for backward.
1201 template <typename T>
1202 ConvBackend _select_conv_backend(
1203 const Tensor& input,
1204 const Tensor& weight,
1205 const std::optional<Tensor>& bias,
1206 const at::OptionalArrayRef<T> bias_sizes_opt,
1207 const bool need_backward,
1208 const ConvParams<T>& params) {
1209
1210 // don't send empty inputs through backends
1211 if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) {
1212 return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty;
1213 } else if (at::symint::numel<T>(input) == 0) {
1214 TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes<T>(input));
1215 }
1216
1217 if (params.is_depthwise(input, weight)) {
1218 if (params.use_cudnn_depthwise(input, weight)) {
1219 return ConvBackend::Cudnn;
1220 } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
1221 return ConvBackend::MiopenDepthwise;
1222 } else {
1223 if (input.ndimension() == 4) {
1224 return ConvBackend::CudaDepthwise2d;
1225 } else if (input.ndimension() == 5) {
1226 return ConvBackend::CudaDepthwise3d;
1227 } else {
1228 // unsupported
1229 }
1230 }
1231 } else if (params.use_cudnn(input, weight)) {
1232 if (params.transposed) {
1233 return ConvBackend::CudnnTranspose;
1234 } else {
1235 return ConvBackend::Cudnn;
1236 }
1237 } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
1238 if (params.transposed) {
1239 return ConvBackend::MiopenTranspose;
1240 } else {
1241 return ConvBackend::Miopen;
1242 }
1243 } else if (params.use_mkldnn(input, weight)) {
1244 if (params.transposed) {
1245 return ConvBackend::MkldnnTranspose;
1246 } else {
1247 return ConvBackend::Mkldnn;
1248 }
1249 } else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) {
1250 // Using prepacked conv is preferred, but XNNPACK is still the fastest
1251 // option for NHWC.
1252 return ConvBackend::Xnnpack2d;
1253 // the 3x3 depthwise convolution implementation is inference-only
1254 } else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) {
1255 return ConvBackend::Winograd3x3Depthwise;
1256 } else if (
1257 !params.transposed && (input.ndimension() == 5) &&
1258 (input.device().is_cpu()) &&
1259 !params.is_dilated()) {
1260 // fast path for grouped conv3d
1261 return ConvBackend::Slow3d;
1262 } else if (input.device().is_cpu() || input.is_cuda()) {
1263 // backends without support for groups
1264 if (params.transposed) {
1265 if (input.ndimension() == 4) {
1266 return ConvBackend::SlowTranspose2d;
1267 } else if (input.ndimension() == 5) {
1268 return ConvBackend::SlowTranspose3d;
1269 } else {
1270 // unsupported
1271 }
1272 } else { /* Not transposed */
1273 if (input.ndimension() == 4) {
1274 if (params.is_dilated()) {
1275 return ConvBackend::SlowDilated2d;
1276 } else { /* dim == 4, non-dilated */
1277 if (params.use_nnpack(input, weight)) {
1278 return ConvBackend::NnpackSpatial;
1279 } else {
1280 /* CPU implementation has specialized MM kernels
1281 for non-dilated case here */
1282 return ConvBackend::Slow2d;
1283 }
1284 }
1285 } else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) {
1286 return ConvBackend::SlowDilated3d;
1287 } else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */
1288 /* CPU implementation has specialized MM kernels
1289 for non-dilated case here */
1290 return ConvBackend::Slow3d;
1291 } else {
1292 // unsupported
1293 }
1294 }
1295 } else if (params.use_mps(input, weight)) {
1296 if (params.transposed) {
1297 return ConvBackend::MpsTranspose;
1298 } else {
1299 return ConvBackend::Mps;
1300 }
1301 } else {
1302     // Only reached when the input is on a backend with an out-of-source implementation.
1303 return ConvBackend::Overrideable;
1304 }
1305
1306 // Error out if no suitable backend was found.
1307 AT_ERROR("unsupported ConvNd parameters");
1308 }
1309
1310 // Selects a backend for convolution based on the inputs and params.
1311 ConvBackend select_conv_backend(
1312 const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_opt,
1313 SymIntArrayRef stride_, SymIntArrayRef padding_, SymIntArrayRef dilation_,
1314 bool transposed_, SymIntArrayRef output_padding_, c10::SymInt groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) {
1315 c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
1316 const Tensor& bias = *bias_maybe_owned;
1317
1318 auto& ctx = at::globalContext();
1319 auto k = weight_r.ndimension();
1320 int64_t dim = k - 2;
1321 ConvParams<c10::SymInt> params;
1322 params.stride = expand_param_if_needed(stride_, "stride", dim);
1323 params.padding = expand_param_if_needed(padding_, "padding", dim);
1324 params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
1325 params.transposed = transposed_;
1326 params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
1327 params.groups = std::move(groups_);
1328 params.benchmark = ctx.benchmarkCuDNN();
1329 params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
1330 params.cudnn_enabled = ctx.userEnabledCuDNN();
1331 params.allow_tf32 = ctx.allowTF32CuDNN();
1332
1333 auto input = input_r;
1334 auto weight = weight_r;
1335 check_shape_forward(input, weight.sym_sizes(), bias, params);
1336
1337 // Expand 1d -> 2d.
1338 // This is only done for backends that don't natively support 1d spatial input.
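  // view4d inserts a size-1 height dimension, so a (N, C, L) input is treated as
  // (N, C, 1, L), and view1d_as_2d prepends a matching stride of 1, padding of 0 and
  // dilation of 1 for that dummy dimension.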
1339 if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
1340 // avoid accidentally going through NHWC for permuted 3d input.
1341 input = input.contiguous();
1342 params.view1d_as_2d();
1343 input = view4d(input);
1344 weight = view4d(weight);
1345 }
1346
1347 auto bias_sizes = bias.defined() ? std::optional<SymIntArrayRef>(bias.sym_sizes()) : bias_sizes_opt;
1348 bool need_backward = GradMode::is_enabled() &&
1349 (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
1350 return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params);
1351 }
1352
1353 // For BC reasons, have a copy that does not require bias_opt
1354 static ConvBackend select_conv_backend(
1355 const Tensor& input,
1356 const Tensor& weight,
1357 const at::OptionalIntArrayRef bias_sizes_opt,
1358 const bool need_backward,
1359 const ConvParams<int64_t>& params) {
1360 return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params);
1361 }
1362
1363 static at::Tensor _convolution_nogroup_backend(
1364 const Tensor& input,
1365 const Tensor& weight,
1366 const Tensor& bias,
1367 const ConvBackend backend,
1368 const ConvParams<int64_t>& params) {
1369 auto kernel_size = weight.sizes().slice(2);
1370 switch(backend) {
1371 case ConvBackend::NnpackSpatial:
1372 #if AT_NNPACK_ENABLED()
1373 return at::_nnpack_spatial_convolution(input, weight, bias, params.padding, params.stride);
1374 #else
1375 TORCH_INTERNAL_ASSERT(false, "NnpackSpatial backend was selected in PyTorch compiled without nnpack support");
1376 #endif
1377 case ConvBackend::Slow2d:
1378 return at::thnn_conv2d(input, weight, kernel_size, bias, params.stride, params.padding);
1379 case ConvBackend::SlowDilated2d:
1380 return at::slow_conv_dilated2d(
1381 input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
1382 case ConvBackend::SlowDilated3d:
1383 return at::slow_conv_dilated3d(
1384 input, weight, kernel_size, bias, params.stride, params.padding, params.dilation);
1385 case ConvBackend::SlowTranspose2d:
1386 return at::slow_conv_transpose2d(
1387 input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
1388 case ConvBackend::SlowTranspose3d:
1389 return at::slow_conv_transpose3d(
1390 input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation);
1391 default:
1392 TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
1393 }
1394 }
1395
1396 static inline std::vector<int64_t> calc_output_size(
1397 const Tensor& input,
1398 const Tensor& weight,
1399 const ConvParams<int64_t>& params) {
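  // For a regular convolution, conv_output_size applies the usual per-dimension formula
  //   out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1,
  // while for a transposed convolution conv_input_size inverts it:
  //   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1.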
1400 std::vector<int64_t> output_size = params.transposed ?
1401 conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding,
1402 params.stride, params.dilation, params.groups) :
1403 conv_output_size(input.sizes(), weight.sizes(), params.padding, params.stride, params.dilation);
1404
1405   // Handle an empty (zero) channel dimension.
1406 if (input.size(input_channels_dim) == 0) {
1407 output_size[output_channels_dim] = 0;
1408 }
1409 return output_size;
1410 }
1411
1412 static inline at::MemoryFormat determine_backend_memory_format(
1413 const Tensor& input,
1414 const Tensor& weight,
1415 const ConvBackend backend) {
1416 at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
1417 #if !defined(C10_MOBILE)
1418 auto k = weight.ndimension();
1419 // See Note [Mobile check segfaults]
1420 switch(backend) {
1421 case ConvBackend::Cudnn:
1422 case ConvBackend::CudnnTranspose:
1423 if (detail::getCUDAHooks().compiledWithCuDNN()) {
1424 backend_memory_format = cudnn_conv_suggest_memory_format(input, weight);
1425 }
1426 break;
1427 case ConvBackend::Miopen:
1428 case ConvBackend::MiopenDepthwise:
1429 case ConvBackend::MiopenTranspose:
1430 if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
1431 TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
1432 "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
1433 backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
1434 }
1435 break;
1436 case ConvBackend::Mkldnn:
1437 case ConvBackend::MkldnnTranspose:
1438 if (mkldnn_conv_use_channels_last(input, weight)) {
1439 backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
1440 }
1441 break;
1442 case ConvBackend::Slow2d:
1443 case ConvBackend::SlowDilated2d:
1444 case ConvBackend::SlowTranspose2d:
1445 if (thnn_conv_use_channels_last(input, weight)) {
1446 backend_memory_format = at::MemoryFormat::ChannelsLast;
1447 }
1448 break;
1449 case ConvBackend::Overrideable:
1450 if (xpu_conv_use_channels_last(input, weight)) {
1451 backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
1452 }
1453 break;
1454 default:
1455 backend_memory_format = at::MemoryFormat::Contiguous;
1456 }
1457 #endif
1458 return backend_memory_format;
1459 }
1460
1461 at::MemoryFormat _determine_backend_memory_format(
1462 const Tensor& input,
1463 const Tensor& weight,
1464 const ConvBackend backend) {
1465 return determine_backend_memory_format(input, weight, backend);
1466 }
1467
1468 at::Tensor _convolution(
1469 const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
1470 IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
1471 bool transposed_, IntArrayRef output_padding_, int64_t groups_,
1472 bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) {
1473 // See [Note: hacky wrapper removal for optional tensor]
1474 c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
1475 const Tensor& bias_r = *bias_r_maybe_owned;
1476
1477 auto input = input_r;
1478 auto weight = weight_r;
1479 auto bias = bias_r;
1480 auto k = weight.ndimension();
1481 c10::IntArrayRef weight_sizes = weight.sizes();
1482 int64_t dim = k - 2;
1483
1484 TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
1485 TORCH_CHECK(groups_ > 0, "non-positive groups is not supported");
1486
1487 ConvParams<int64_t> params;
1488 params.stride = expand_param_if_needed(stride_, "stride", dim);
1489 params.padding = expand_param_if_needed(padding_, "padding", dim);
1490 params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
1491 params.transposed = transposed_;
1492 params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
1493 params.groups = groups_;
1494 params.benchmark = benchmark;
1495 params.deterministic = deterministic;
1496 params.cudnn_enabled = cudnn_enabled;
1497 params.allow_tf32 = allow_tf32;
1498
1499 check_shape_forward(input, weight_sizes, bias, params);
1500
1501 // Expand 1d -> 2d.
1502 // This is only done for backends that don't natively support 1d spatial input.
1503 if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
1504 // avoid accidentally going through NHWC for permuted 3d input.
1505 input = input.contiguous();
1506 params.view1d_as_2d();
1507 input = view4d(input);
1508 weight = view4d(weight);
1509 }
1510
1511 // Select appropriate backend to use.
1512 auto bias_sizes_opt = bias.defined() ? std::optional<IntArrayRef>(bias.sizes()) : std::nullopt;
1513 bool need_backward = GradMode::is_enabled() &&
1514 (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad()));
1515 ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params);
1516 at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
1517
1518 // Call the backend.
1519 Tensor output;
1520 auto kernel_size = weight.sizes().slice(2);
1521 switch (backend) {
1522 case ConvBackend::CudaDepthwise2d:
1523 output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias,
1524 params.stride, params.padding, params.dilation);
1525 break;
1526 case ConvBackend::CudaDepthwise3d:
1527 output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias,
1528 params.stride, params.padding, params.dilation);
1529 break;
1530 case ConvBackend::Cudnn:
1531 check_input_same_type_as_parameters(input, weight, bias);
1532 output = at::cudnn_convolution(
1533 input.contiguous(backend_memory_format), weight, params.padding, params.stride,
1534 params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
1535 if (bias.defined()) {
1536 output.add_(reshape_bias(input.dim(), bias));
1537 }
1538 break;
1539 case ConvBackend::CudnnTranspose:
1540 check_input_same_type_as_parameters(input, weight, bias);
1541 output = at::cudnn_convolution_transpose(
1542 input.contiguous(backend_memory_format), weight, params.padding, params.output_padding,
1543 params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
1544 if (bias.defined()) {
1545 output.add_(reshape_bias(input.dim(), bias));
1546 }
1547 break;
1548 case ConvBackend::Empty:
1549 {
1550 Tensor weight_view;
1551 // Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where
1552 // view size is not compatible with input tensor's size and stride.
1553 if(weight.is_contiguous()) {
1554 weight_view = at::_unsafe_view(weight, -1);
1555 } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) {
1556 weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1);
1557 } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
1558 weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1);
1559 } else {
1560 weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1);
1561 }
1562
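      // Multiplying by (an element of) the flattened weight presumably keeps the empty
      // output connected to both input and weight in the autograd graph while producing
      // the right dtype; the result is then reshaped to the expected output size.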
1563 output = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]);
1564 if (bias.defined()) {
1565 output.add_(bias[0]);
1566 }
1567 output = output.view(calc_output_size(input, weight, params));
1568 break;
1569 }
1570 case ConvBackend::Miopen:
1571 check_input_same_type_as_parameters(input, weight, bias);
1572 output = at::miopen_convolution(
1573 input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
1574 params.dilation, params.groups, params.benchmark, params.deterministic);
1575 break;
1576 case ConvBackend::MiopenDepthwise:
1577 output = at::miopen_depthwise_convolution(
1578 input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride,
1579 params.dilation, params.groups, params.benchmark, params.deterministic);
1580 break;
1581 case ConvBackend::MiopenTranspose:
1582 check_input_same_type_as_parameters(input, weight, bias);
1583 output = at::miopen_convolution_transpose(
1584 input.contiguous(backend_memory_format), weight, bias, params.padding, params.output_padding,
1585 params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
1586 break;
1587 case ConvBackend::Mkldnn:
1588 #if AT_MKLDNN_ENABLED()
1589 check_input_same_type_as_parameters(input, weight, bias, backend);
1590 if (!input.is_mkldnn()) {
1591 // need to ensure contiguous for non-mkldnn tensors
1592 input = input.contiguous(backend_memory_format);
1593 weight = weight.contiguous(backend_memory_format);
1594 bias = bias.defined() ? bias.contiguous() : bias;
1595 }
1596 output = at::mkldnn_convolution(
1597 input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
1598 #else
1599 TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
1600 #endif
1601 break;
1602 case ConvBackend::MkldnnTranspose:
1603 #if AT_MKLDNN_ENABLED()
1604 check_input_same_type_as_parameters(input, weight, bias, backend);
1605 if (!input.is_mkldnn()) {
1606 // need to ensure contiguous for non-mkldnn tensors
1607 input = input.contiguous(backend_memory_format);
1608 weight = weight.contiguous(backend_memory_format);
1609 bias = bias.defined() ? bias.contiguous() : bias;
1610 }
1611 output = mkldnn_convolution_transpose_stub(input.device().type(),
1612 input, weight, bias, params.padding, params.output_padding, params.stride, params.dilation, params.groups);
1613 #else
1614 TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
1615 #endif
1616 break;
1617 case ConvBackend::MkldnnEmpty:
1618 #if AT_MKLDNN_ENABLED()
1619 output = empty_mkldnn(
1620 calc_output_size(input, weight, params), optTypeMetaToScalarType(input.options().dtype_opt()),
1621 input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
1622 #else
1623 TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
1624 #endif
1625 break;
1626 case ConvBackend::Overrideable:
1627 output = at::convolution_overrideable(
1628 input, weight, bias, params.stride, params.padding, params.dilation, params.transposed,
1629 params.output_padding, params.groups);
1630 break;
1631 case ConvBackend::Slow3d:
1632 output = at::slow_conv3d(input, weight, kernel_size, bias, params.stride, params.padding);
1633 break;
1634 case ConvBackend::Winograd3x3Depthwise:
1635 output = convolution_depthwise3x3_winograd_stub(
1636 input.device().type(), input, weight, bias, params.stride, params.padding, params.groups);
1637 break;
1638 case ConvBackend::Xnnpack2d:
1639 output = xnnpack::convolution2d(
1640 input, weight, bias, params.padding, params.stride, params.dilation, params.groups);
1641 break;
1642 // Handle backends that don't natively support groups > 1.
1643 case ConvBackend::NnpackSpatial:
1644 case ConvBackend::Slow2d:
1645 case ConvBackend::SlowDilated2d:
1646 case ConvBackend::SlowDilated3d:
1647 case ConvBackend::SlowTranspose2d:
1648 case ConvBackend::SlowTranspose3d:
1649 input = input.contiguous(backend_memory_format);
1650 weight = weight.contiguous(backend_memory_format);
1651 if (params.groups == 1) {
1652 output = _convolution_nogroup_backend(input, weight, bias, backend, params);
1653 } else {
1654 std::vector<Tensor> outputs(params.groups);
1655 for (const auto g : c10::irange(params.groups)) {
1656 auto input_g = subtensor(input, 1, params.groups, g);
1657 auto weight_g = subtensor(weight, 0, params.groups, g);
1658 auto bias_g = subtensor(bias, 0, params.groups, g);
1659 outputs[g] = _convolution_nogroup_backend(input_g, weight_g, bias_g, backend, params);
1660 }
1661 output = at::cat(outputs, 1);
1662 }
1663 break;
1664 case ConvBackend::Mps:
1665 #ifdef USE_MPS
1666 TORCH_CHECK(input.options().type_equal(weight.options()),
1667 "Input type (", input.toString(), ") and weight type (", weight.toString(),
1668 ") should be the same");
1669 TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
1670 "Input type (", input.toString(), ") and bias type (", bias.toString(),
1671 ") should be the same");
1672
1673 output = at::_mps_convolution(input, weight, bias.defined() ? bias.contiguous() : bias,
1674 params.padding, params.stride, params.dilation,
1675 params.groups);
1676 #else
1677     TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
1678 #endif
1679 break;
1680 case ConvBackend::MpsTranspose:
1681 #ifdef USE_MPS
1682 TORCH_CHECK(input.options().type_equal(weight.options()),
1683 "Input type (", input.toString(), ") and weight type (", weight.toString(),
1684 ") should be the same");
1685 TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
1686 "Input type (", input.toString(), ") and bias type (", bias.toString(),
1687 ") should be the same");
1688 output = at::_mps_convolution_transpose(
1689 input.contiguous(backend_memory_format), weight,
1690 params.padding, params.output_padding,
1691 params.stride, params.dilation, params.groups);
1692 if (bias.defined()) {
1693 output.add_(reshape_bias(input.dim(), bias));
1694 }
1695 #else
1696     TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
1697 #endif
1698 break;
1699 }
1700
1701 if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
1702 output = view3d(output);
1703 }
1704
1705 return output;
1706 }
1707
1708 at::Tensor _convolution(
1709 const Tensor& input_r, const Tensor& weight_r, const std::optional<Tensor>& bias_r_opt,
1710 IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
1711 bool transposed_, IntArrayRef output_padding_, int64_t groups_,
1712 bool benchmark, bool deterministic, bool cudnn_enabled)
1713 {
1714 // See [Note: hacky wrapper removal for optional tensor]
1715 c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
1716 const Tensor& bias_r = *bias_r_maybe_owned;
1717
1718 return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
1719 }
1720
1721 std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
1722 const Tensor& grad_output, const Tensor& input, const Tensor& weight,
1723 IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation,
1724 bool transposed, IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
1725   TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with a tensor backend other than CPU/CUDA/MKLDNN; if this is intended, please use TORCH_LIBRARY_IMPL to override this function");
1726 return std::tuple<Tensor, Tensor, Tensor>(
1727 at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
1728 at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT),
1729 at::empty({}));
1730 }
1731
1732 static Tensor subvariable(const Tensor& var, int64_t dim, int64_t groups, int64_t g) {
1733 int64_t n = var.sizes()[dim] / groups;
1734 auto result = var.narrow(dim, n * g, n);
1735 return result;
1736 }
1737
1738 std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const std::optional<Tensor>& ggI_opt, const std::optional<Tensor>& ggW_r_opt, const std::optional<Tensor>& ggb_opt,
1739 const Tensor& gO_r, const Tensor& weight_r, const Tensor& input,
1740 IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_,
1741 bool transposed_, IntArrayRef output_padding_, int64_t groups_,
1742 std::array<bool, 3> output_mask) {
1743 // See [Note: hacky wrapper removal for optional tensor]
1744 c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt);
1745 const Tensor& ggI = *ggI_maybe_owned;
1746 const Tensor& ggW_r = c10::value_or_else(ggW_r_opt, [] {return Tensor();});
1747 const Tensor& ggb = c10::value_or_else(ggb_opt, [] {return Tensor();});
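  // ggI, ggW and ggb are the incoming "gradients of gradients", i.e. the tangents of
  // grad_input, grad_weight and grad_bias produced by the first backward pass; gO is the
  // original grad_output.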
1748
1749
1750 auto ggW = ggW_r;
1751 auto gO = gO_r;
1752 auto weight = weight_r;
1753
1754 int64_t dim = weight.ndimension() - 2;
1755 ConvParams<int64_t> params;
1756 params.stride = expand_param_if_needed(stride_, "stride", dim);
1757 params.padding = expand_param_if_needed(padding_, "padding", dim);
1758 params.dilation = expand_param_if_needed(dilation_, "dilation", dim);
1759 params.transposed = transposed_;
1760 params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim);
1761 // TODO: hacky way of inferring the groups number for grouped Conv3D
1762 // See: https://github.com/pytorch/pytorch/pull/36355
1763 if (!params.transposed && input.dim() > 4) {
1764 // Avoid undefined behavior when num channels == 0; params are unused for that case.
1765 params.groups = (weight.size(1) > 0) ? input.size(1) / weight.size(1) : -1;
1766 } else {
1767 params.groups = groups_;
1768 }
1769
1770 // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb
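  // Convolution is linear in the input, the weight and the bias separately, so the
  // perturbations ggI, ggW and ggb each contribute an additive term to the output
  // perturbation ggO.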
1771 Tensor ggO;
1772 if (input.numel() != 0) {
1773 if (ggI.defined()) {
1774 if (weight.is_cuda()) {
1775 weight = weight.contiguous();
1776 }
1777 ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
1778 }
1779
1780 if (ggW.defined()) {
1781 if (ggW.is_cuda()) {
1782 ggW = ggW.contiguous();
1783 }
1784 auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups);
1785 if (ggO.defined()) {
1786 ggO = ggO + ggW_term;
1787 } else {
1788 ggO = ggW_term;
1789 }
1790 }
1791 }
1792
1793 if (ggb.defined()) {
1794 // View as (1, ggb.size(0), 1, 1...)
1795
1796     // Build the broadcastable (1, C, 1, 1, ...) view shape
1797 std::vector<int64_t> new_size(gO.ndimension(), 1);
1798 new_size[1] = ggb.sizes()[0];
1799 auto ggb_contiguous = ggb.contiguous();
1800 auto ggb_view = ggb_contiguous.view(new_size);
1801
1802 // Expand
1803 auto ggb_expanded = ggb_view.expand(gO.sizes());
1804
1805 if (ggO.defined()) {
1806 ggO = ggO + ggb_expanded;
1807 } else {
1808 ggO = ggb_expanded;
1809 }
1810 }
1811
1812 // Compute gW = conv(ggI, gO)
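  // The weight gradient is itself a convolution between ggI and gO with the batch and
  // channel dimensions swapped (the transposes below) and with stride and dilation
  // exchanging roles (the std::swap below).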
1813 Tensor gW;
1814 if (ggI.defined()) {
1815
1816 // Modified params with correct padding
1817 ConvParams<int64_t> gw_conv_params(params);
1818
1819 // Disable groups as they are handled separately
1820 auto groups = gw_conv_params.groups;
1821 gw_conv_params.groups = 1;
1822 std::swap(gw_conv_params.dilation, gw_conv_params.stride);
1823
1824 // Transpose gO and ggI to accumulate over batch
1825 auto gOt = gO.transpose(0, 1);
1826 auto ggIt = ggI.transpose(0, 1);
1827
1828 Tensor gWt;
1829 // Compute conv
1830 if (input.numel() != 0) {
1831 if (groups == 1) {
1832
1833 if (gOt.is_cuda()) {
1834 gOt = gOt.contiguous();
1835 }
1836 // Compute conv
1837 if (params.transposed) {
1838 gw_conv_params.transposed = false;
1839 gWt = at::convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1840 } else {
1841 gWt = at::convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1842 }
1843 } else {
1844 std::vector<Tensor> gWt_list(groups);
1845 for (const auto g : c10::irange(groups)) {
1846 auto ggIt_g = subvariable(ggIt, 0, groups, g);
1847 auto gOt_g = subvariable(gOt, 0, groups, g);
1848 if (gOt_g.is_cuda()) {
1849 gOt_g = gOt_g.contiguous();
1850 }
1851
1852 // Compute conv
1853 if (params.transposed) {
1854 gw_conv_params.transposed = false;
1855 gWt_list[g] = at::convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1856 } else {
1857 gWt_list[g] = at::convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups);
1858 }
1859 }
1860
1861 gWt = at::cat(gWt_list, 1);
1862 }
1863
1864 // Transpose gW to match chan_in and chan_out
1865 gW = gWt.transpose(0, 1);
1866
1867 // narrow gW to only relevant portion
1868 // we do it this way instead of narrowing the input itself because
1869 // the ConvForward kernels don't support asymmetric padding.
1870 auto gW_size = gW.sizes();
1871 auto w_size = weight.sizes();
1872 for (const auto i : c10::irange(2, static_cast<int64_t>(gW_size.size()))) {
1873 if (gW_size[i] > w_size[i]) {
1874 gW = gW.narrow(i, 0, w_size[i]);
1875 gW_size = gW.sizes();
1876 }
1877 }
1878 }
1879 }
1880
1881 // Compute gI = convT(gO, ggW) if !transposed
1882   //         gI = conv(gO, ggW) if transposed
1883 Tensor gI;
1884 if (input.numel() != 0) {
1885 if (ggW.defined()) {
1886 ConvParams<int64_t> gi_conv_params(params);
1887 gi_conv_params.transposed = !params.transposed;
1888
1889 if (params.transposed) {
1890 if (gO.is_cuda()) {
1891 gO = gO.contiguous();
1892 }
1893 gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
1894
1895 // narrow gI to only relevant portion
1896 // we do it this way because negative output_padding is not supported
1897 // TODO: figure out if we can narrow gO and save some compute,
1898 // rather than narrowing the computed gI
1899 auto gI_size = gI.sizes();
1900 auto i_size = input.sizes();
1901 for (const auto i : c10::irange(2, static_cast<int64_t>(gI_size.size()))) {
1902 if (gI_size[i] > i_size[i]) {
1903 gI = gI.narrow(i, 0, i_size[i]);
1904 gI_size = gI.sizes();
1905 }
1906 }
1907 } else {
1908 // calculate output_padding
1909 // TODO: figure out why this needs to be computed...
1910 auto kernel_size = weight.sizes().slice(2);
1911 auto input_shape = input.sizes().slice(2);
1912 auto grad_output_shape = gO.sizes().slice(2);
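        // For each spatial dim, a transposed conv with zero output_padding reconstructs an
        // input extent of (kernel - 1) * dilation - 2 * padding + stride * (grad_output - 1) + 1;
        // if the real input was larger (the forward conv ignored trailing input elements),
        // output_padding is set to make up the difference.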
1913
1914 for (const auto i : c10::irange(kernel_size.size())) {
1915         // Check whether the whole input has been used
1916 auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i]
1917 - 2 * gi_conv_params.padding[i]
1918 + (gi_conv_params.stride[i] * (grad_output_shape[i] - 1) + 1);
1919 if (expected_input_shape != input_shape[i]) {
1920 gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape;
1921 }
1922 }
1923
1924 if (gO.is_cuda()) {
1925 gO = gO.contiguous();
1926 }
1927
1928 gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups);
1929 }
1930 }
1931 }
1932
1933 return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
1934 }
1935
1936 static std::tuple<at::Tensor, at::Tensor, at::Tensor> _convolution_backward_nogroup_backend(
1937 const Tensor& grad_output,
1938 const Tensor& input,
1939 const Tensor& weight,
1940 const std::array<bool, 3> output_mask,
1941 const ConvBackend backend,
1942 const ConvParams<int64_t>& params) {
1943 auto kernel_size = weight.sizes().slice(2);
1944 switch(backend) {
1945 case ConvBackend::Slow2d:
1946 return at::_slow_conv2d_backward(
1947 grad_output, input, weight, kernel_size, params.stride, params.padding, output_mask);
1948 // NB: nnpack backward does not support strided convolutions; use slow impl instead
1949 case ConvBackend::NnpackSpatial:
1950 case ConvBackend::SlowDilated2d:
1951 return slow_conv_dilated2d_backward_stub(
1952 input.device().type(),
1953 grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
1954 case ConvBackend::SlowDilated3d:
1955 return slow_conv_dilated3d_backward_stub(
1956 input.device().type(),
1957 grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask);
1958 case ConvBackend::SlowTranspose2d:
1959 return slow_conv_transpose2d_backward_stub(
1960 input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
1961 params.output_padding, params.dilation, output_mask);
1962 case ConvBackend::SlowTranspose3d:
1963 return slow_conv_transpose3d_backward_stub(
1964 input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding,
1965 params.output_padding, params.dilation, output_mask);
1966 default:
1967 TORCH_CHECK(false, "Unsupported conv nogroup backend encountered");
1968 }
1969 }
1970
1971 // Backward pass for convolution. Computes gradients for input, weight, and bias depending on the
1972 // output_mask setting. This function supports 1D, 2D, or 3D spatial convolution and currently requires
1973 // a single batch dimension to be present.
1974 //
1975 // Args:
1976 // grad_output_: tensor of shape (N, C_out, L_out), (N, C_out, H_out, W_out), or (N, C_out, D_out, H_out, W_out)
1977 // input_: tensor of shape (N, C_in, L_in), (N, C_in, H_in, W_in), or (N, C_in, D_in, H_in, W_in)
1978 // weight_: tensor of shape (C_out, C_in // groups, *kernel_size); dimension of kernel_size must match the number
1979 // of input spatial dimensions
1980 // bias_sizes_opt: if specified, indicates that a bias was used in the forward pass and contains the shape
1981 // of the bias. While the bias shape can be computed from other inputs, it is provided to this function for
1982 // ease of use. The bias shape is (weight.shape[0]) for normal convolution and (weight.shape[1] * groups)
1983 // for transposed convolution.
1984 // stride: single value or an array with dimension matching the number of input spatial dimensions
1985 // padding: single value or an array with dimension matching the number of input spatial dimensions
1986 // dilation: single value or an array with dimension matching the number of input spatial dimensions
1987 // transposed: boolean indicating whether the convolution is transposed
1988 // output_padding: single value or dimension == number of input spatial dimensions; only supported when
1989 // transposed is true
1990 // groups: number of groups for grouped convolution
1991 // output_mask: 3-dim boolean array specifying which gradients to compute in input, weight, bias order
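//
// Example (illustrative only): for a 2d convolution with input of shape (8, 3, 32, 32),
// weight of shape (16, 3, 3, 3), unit stride and dilation, padding 1, groups=1 and a bias
// of shape (16), a call could look like
//   convolution_backward(grad_output, input, weight, /*bias_sizes_opt=*/{{16}},
//                        /*stride=*/{1, 1}, /*padding=*/{1, 1}, /*dilation=*/{1, 1},
//                        /*transposed=*/false, /*output_padding=*/{0, 0}, /*groups=*/1,
//                        /*output_mask=*/{true, true, true});
// returning (grad_input, grad_weight, grad_bias) with the shapes of (input, weight, bias).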
1992 std::tuple<Tensor, Tensor, Tensor> convolution_backward(
1993 const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
1994 const at::OptionalIntArrayRef bias_sizes_opt,
1995 IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding,
1996 int64_t groups, std::array<bool, 3> output_mask) {
1997 auto grad_output = grad_output_;
1998 auto input = input_;
1999 auto weight = weight_;
2000
2001 auto k = weight.ndimension();
2002 int64_t dim = k - 2;
2003
2004 TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
2005
2006 auto& ctx = at::globalContext();
2007 ConvParams<int64_t> params;
2008 params.stride = expand_param_if_needed(stride, "stride", dim);
2009 params.padding = expand_param_if_needed(padding, "padding", dim);
2010 params.dilation = expand_param_if_needed(dilation, "dilation", dim);
2011 params.transposed = transposed;
2012 params.output_padding = expand_param_if_needed(output_padding, "output_padding", dim);
2013 params.groups = groups;
2014 params.benchmark = ctx.benchmarkCuDNN();
2015 params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
2016 params.cudnn_enabled = ctx.userEnabledCuDNN();
2017 params.allow_tf32 = ctx.allowTF32CuDNN();
2018
2019 // Validate inputs.
2020 check_shape_backward(input, weight.sizes(), params);
2021 TORCH_CHECK(input.dim() == grad_output.dim(),
2022 "Expected input and grad_output to have the same number of dimensions, but got: ",
2023 input.dim(), " and ", grad_output.dim());
2024
2025 // output_padding is only supported for transposed convolutions
2026 if (!params.transposed) {
2027 for (auto pad : params.output_padding) {
2028 TORCH_CHECK(pad == 0, "output_padding is not supported for non-transposed convolutions; got: ",
2029 params.output_padding);
2030 }
2031 }
2032
2033 // Expand 1d -> 2d.
2034 // This is only done for backends that don't natively support 1d spatial input.
2035 if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
2036 // avoid accidentally going through NHWC for permuted 3d input.
2037 input = input.contiguous();
2038 params.view1d_as_2d();
2039 grad_output = view4d(grad_output);
2040 input = view4d(input);
2041 weight = view4d(weight);
2042 }
2043
2044 // Select appropriate backend to use.
2045 ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, /*need_backward=*/ true, params);
2046 at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend);
2047
2048 // Call the backend.
2049 Tensor backend_grad_input, backend_grad_weight, backend_grad_bias;
2050 auto kernel_size = weight.sizes().slice(2);
2051 switch(backend) {
2052 case ConvBackend::CudaDepthwise2d:
2053 {
2054 std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2055 std::tie(backend_grad_input, backend_grad_weight) =
2056 conv_depthwise2d_backward_stub(input.device().type(), grad_output, input,
2057 weight, kernel_size, params.stride, params.padding, params.dilation, input_weight_output_mask);
2058 break;
2059 }
2060 case ConvBackend::CudaDepthwise3d:
2061       TORCH_CHECK(input.ndimension() == 5, "Expected a 5d input for 3d depthwise convolution backward");
2062 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2063 conv_depthwise3d_backward_stub(
2064 input.device().type(), grad_output, input, weight, kernel_size, params.stride,
2065 params.padding, params.dilation, output_mask);
2066 break;
2067 case ConvBackend::Cudnn:
2068 {
2069 check_input_same_type_as_parameters(input, weight);
2070 std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2071 std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_backward_stub(
2072 input.device().type(),
2073 // Only make input contiguous when it is necessary for the backwards computation
2074 output_mask[1] ? input.contiguous(backend_memory_format) : input,
2075 grad_output, weight, params.padding, params.stride,
2076 params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
2077 input_weight_output_mask);
2078 break;
2079 }
2080 case ConvBackend::Mps:
2081 {
2082 #ifdef USE_MPS
2083 check_input_same_type_as_parameters(input, weight);
2084 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2085 at::mps_convolution_backward(input, grad_output, weight, params.padding,
2086 params.stride, params.dilation, params.groups, output_mask);
2087 #else
2088       TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
2089 #endif
2090 break;
2091 }
2092 case ConvBackend::MpsTranspose:
2093 {
2094 #ifdef USE_MPS
2095 check_input_same_type_as_parameters(input, weight);
2096 std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2097 std::tie(backend_grad_input, backend_grad_weight) = at::mps_convolution_transpose_backward(
2098 // Only make input contiguous when it is necessary for the backwards computation
2099 output_mask[1] ? input.contiguous(backend_memory_format) : input,
2100 grad_output, weight, params.padding, params.output_padding,
2101 params.stride, params.dilation, params.groups, input_weight_output_mask);
2102 #else
2103       TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch compiled without MPS support");
2104 #endif
2105 break;
2106 }
2107 case ConvBackend::CudnnTranspose:
2108 {
2109 check_input_same_type_as_parameters(input, weight);
2110 std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]};
2111 std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_transpose_backward_stub(
2112 input.device().type(),
2113 // Only make input contiguous when it is necessary for the backwards computation
2114 output_mask[1] ? input.contiguous(backend_memory_format) : input,
2115 grad_output, weight, params.padding, params.output_padding,
2116 params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32,
2117 input_weight_output_mask);
2118 break;
2119 }
2120 case ConvBackend::Empty:
2121 if (output_mask[0]) {
2122 backend_grad_input = at::zeros_like(input);
2123 }
2124 if (output_mask[1]) {
2125 backend_grad_weight = at::zeros_like(weight);
2126 }
2127 if (output_mask[2]) {
2128 backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
2129 }
2130 break;
2131 case ConvBackend::MkldnnEmpty:
2132 #if AT_MKLDNN_ENABLED()
2133 if (output_mask[0]) {
2134 if (input.is_mkldnn()) {
2135 backend_grad_input = empty_mkldnn(input.sizes(), optTypeMetaToScalarType(input.options().dtype_opt()),
2136 input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt());
2137 backend_grad_input.zero_();
2138 } else {
2139 backend_grad_input = at::zeros_like(input);
2140 }
2141 }
2142 if (output_mask[1]) {
2143 // mkldnn weight is not supported during training by the mkldnn backend
2144 backend_grad_weight = at::zeros_like(weight);
2145 }
2146 if (output_mask[2]) {
2147 // mkldnn bias is not supported during training by the mkldnn backend
2148 backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options());
2149 }
2150 #else
2151 TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support");
2152 #endif
2153 break;
2154 case ConvBackend::Miopen:
2155 check_input_same_type_as_parameters(input, weight);
2156 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2157 miopen_convolution_backward_stub(
2158 input.device().type(),
2159 input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
2160 params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
2161 break;
2162 case ConvBackend::MiopenDepthwise:
2163 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2164 miopen_depthwise_convolution_backward_stub(
2165 input.device().type(),
2166 input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride,
2167 params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
2168 break;
2169 case ConvBackend::MiopenTranspose:
2170 check_input_same_type_as_parameters(input, weight);
2171 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2172 miopen_convolution_transpose_backward_stub(
2173 input.device().type(),
2174 input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding,
2175 params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, output_mask);
2176 break;
2177 case ConvBackend::Mkldnn:
2178 TORCH_CHECK(!weight.is_mkldnn(),
2179 "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
2180 if (!input.is_mkldnn()) {
2181 input = input.contiguous(backend_memory_format);
2182 weight = weight.contiguous(backend_memory_format);
2183 }
2184 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2185 mkldnn_convolution_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
2186 params.stride, params.dilation, params.groups, output_mask);
2187 break;
2188 case ConvBackend::MkldnnTranspose:
2189 TORCH_CHECK(!weight.is_mkldnn(),
2190 "The MKLDNN backend does not support weight as an MKLDNN tensor during training");
2191 if (!input.is_mkldnn()) {
2192 input = input.contiguous(backend_memory_format);
2193 weight = weight.contiguous(backend_memory_format);
2194 }
2195 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2196 mkldnn_convolution_transpose_backward_stub(input.device().type(), input, grad_output, weight, params.padding,
2197 params.output_padding, params.stride, params.dilation, params.groups, output_mask);
2198 break;
2199 case ConvBackend::Overrideable:
2200       // Only reached when the input is on a backend with an out-of-source implementation.
2201 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2202 at::convolution_backward_overrideable(grad_output, input, weight, params.stride, params.padding,
2203 params.dilation, params.transposed, params.output_padding, params.groups, output_mask);
2204 break;
2205 case ConvBackend::Slow3d:
2206 // Note that no CUDA implementation of this kernel exists currently.
2207 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2208 slow_conv3d_backward_cpu(
2209 grad_output, input, weight, kernel_size,
2210 params.stride, params.padding, output_mask);
2211 break;
2212 // Handle backends that don't natively support groups > 1.
2213 case ConvBackend::NnpackSpatial:
2214 case ConvBackend::Slow2d:
2215 case ConvBackend::SlowDilated2d:
2216 case ConvBackend::SlowDilated3d:
2217 case ConvBackend::SlowTranspose2d:
2218 case ConvBackend::SlowTranspose3d:
2219 {
2220 input = input.contiguous(backend_memory_format);
2221 weight = weight.contiguous(backend_memory_format);
2222 if (params.groups == 1) {
2223 std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) =
2224 _convolution_backward_nogroup_backend(
2225 grad_output, input, weight, output_mask, backend, params);
2226 } else {
2227 std::vector<Tensor> backend_grad_inputs(params.groups);
2228 std::vector<Tensor> backend_grad_weights(params.groups);
2229 std::vector<Tensor> backend_grad_biases(params.groups);
2230 for (int g = 0; g < params.groups; ++g) {
2231 auto grad_output_g = subtensor(grad_output, 1, params.groups, g);
2232 auto input_g = subtensor(input, 1, params.groups, g);
2233 auto weight_g = subtensor(weight, 0, params.groups, g);
2234 std::tie(backend_grad_inputs[g], backend_grad_weights[g], backend_grad_biases[g]) =
2235 _convolution_backward_nogroup_backend(
2236 grad_output_g, input_g, weight_g, output_mask, backend, params);
2237 }
2238 if (output_mask[0]) {
2239 backend_grad_input = at::cat(backend_grad_inputs, 1);
2240 }
2241 if (output_mask[1]) {
2242 backend_grad_weight = at::cat(backend_grad_weights, 0);
2243 }
2244 if (output_mask[2]) {
2245 backend_grad_bias = at::cat(backend_grad_biases, 0);
2246 }
2247 }
2248 break;
2249 }
2250 // Backward is not supported for these backends.
2251 case ConvBackend::Winograd3x3Depthwise:
2252 TORCH_CHECK(false, "Backward is not supported for depthwise 3x3 winograd");
2253 break;
2254 case ConvBackend::Xnnpack2d:
2255 TORCH_CHECK(false, "Backward is not supported for xnnpack");
2256 break;
2257 }
2258
2259 // Convert 2D inputs back to 1D for backends that don't natively support 1D
2260 // spatial inputs.
2261 if (output_mask[0]) {
2262 if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
2263 backend_grad_input = view3d(backend_grad_input);
2264 }
2265 }
2266 if (output_mask[1]) {
2267 if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) {
2268 backend_grad_weight = view3d(backend_grad_weight);
2269 }
2270 }
2271 if (output_mask[2]) {
2272 if (!backend_grad_bias.defined()) {
2273 // Calculate bias gradients outside of the backend for those that don't support it.
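      // The bias is broadcast over the batch and all spatial dimensions in the forward
      // pass, so its gradient is grad_output summed over every dimension except channels.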
2274 backend_grad_bias = grad_output.sum((dim == 3) ? IntArrayRef{0, 2, 3, 4} : IntArrayRef{0, 2, 3});
2275 }
2276 }
2277
2278 return std::make_tuple(backend_grad_input, backend_grad_weight, backend_grad_bias);
2279 }
2280
2281 void _cudnn_set_conv_benchmark_empty_cache(bool enable) {
2282 conv_benchmark_empty_cache = enable;
2283 }
2284
2285 bool _cudnn_get_conv_benchmark_empty_cache() {
2286 return conv_benchmark_empty_cache;
2287 }
2288
2289
2290
2291 } // namespace at::native
2292