#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/TensorGeometry.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/EmptyTensor.h>
#include <ATen/native/ConvUtils.h>

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/ConvShared.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cudnn_convolution_add_relu_native.h>
#include <ATen/ops/cudnn_convolution_native.h>
#include <ATen/ops/cudnn_convolution_relu_native.h>
#include <ATen/ops/cudnn_convolution_transpose_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

// NOTE [cuDNN API version]
//
// ConvPlaceholders.cpp contains placeholder implementations of cuDNN
// convolution used when cuDNN is not enabled. These operators only raise
// errors and do no real computation; they are implemented in terms of
// existing operators.
//
// cuDNN v7 and v8 have different APIs. ConvShared.{cpp, h} contains
// code shared by v7 and v8. Conv_v7.cpp contains the implementation of
// convolution using the cuDNN v7 API; Conv_v8.cpp contains the
// implementation using the v8 API.
//
// NOTE [ Convolution design ]
//
// cuDNN convolutions do not handle bias; bias is handled outside.
//
// The general strategy:
//
//    - cudnn_convolution (Tensor)
//      Entry point for clients
//
//    - cudnn_convolution_forward (TensorArg)
//      Entry point, which may be reused between regular
//      convolution and transposed convolution.
//
//    - raw_cudnn_convolution_forward_out (Tensor)
//      Function that has different implementations in Conv_v7.cpp
//      and Conv_v8.cpp
//
// The raw API directly invokes cuDNN and is implemented differently
// for cuDNN v7 and cuDNN v8.
//
// There are a few reasons the raw API should never be directly exposed
// via ATen:
//
//    - It takes output as a parameter (this should be computed!)
//    - It doesn't do input checking
//    - It doesn't resize output (it is assumed to be correctly sized)
//
// Where does argument checking happen? Here's the division of
// responsibility:
//    - Things that happen in at::Tensor
//      - TensorArg allocation
//    - Things that happen in TensorArg
//      - Check arguments (type, GPU, shape)
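//
// As a rough sketch of the layering (not the authoritative signatures;
// see the definitions below), a client call flows like this, assuming
// `input` and `weight` are suitably shaped CUDA tensors:
//
//   Tensor out = at::native::cudnn_convolution(
//       input, weight,
//       /*padding=*/{1, 1}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
//       /*groups=*/1, /*benchmark=*/false, /*deterministic=*/false,
//       /*allow_tf32=*/true);
//
// cudnn_convolution allocates the output, wraps its arguments in
// TensorArgs, and forwards to cudnn_convolution_forward_out, which
// performs the type/GPU/shape checks and finally calls
// raw_cudnn_convolution_forward_out (the v7 or v8 backend).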

namespace at {
namespace native {

// ---------------------------------------------------------------------
//
// ConvolutionParams
//
// ---------------------------------------------------------------------

std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params) {
  out << "ConvolutionParams \n"
      << " memory_format = " << params.memory_format << "\n"
      << " data_type = " << cudnnTypeToString(params.dataType) << "\n"
      << " padding = " << ArrayRef<int>{params.padding} << "\n"
      << " stride = " << ArrayRef<int>{params.stride} << "\n"
      << " dilation = " << ArrayRef<int>{params.dilation} << "\n"
      << " groups = " << params.groups << "\n"
      << " deterministic = " << (params.deterministic ? "true" : "false")
      << "\n"
      << " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n";

  return out;
}

// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need.  (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing)
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool deterministic,
    bool allow_tf32,
    at::MemoryFormat memory_format) {
  cudnnDataType_t dataType = getCudnnDataType(input);
  memset(params, 0, sizeof(ConvolutionParams));
  params->device_id = at::cuda::current_device();
  params->dataType = dataType;
  // ASSERT(weight.dim() == input.dim())
  params->input_dim = input.dim();
  params->memory_format = memory_format;
  for (int i = 0; i != params->input_dim; ++i) {
    params->input_size[i] = (int)input.sizes()[i];
    params->weight_size[i] = (int)weight.sizes()[i];
  }
  // ASSERT(padding.size() == stride.size())
  // ASSERT(padding.size() == dilation.size())
  for (size_t i = 0; i != padding.size(); ++i) {
    params->padding[i] = padding[i];
    params->stride[i] = stride[i];
    params->dilation[i] = dilation[i];
  }
  // In principle, we shouldn't parametrize by groups for legacy
  // CuDNN, but it doesn't seem worth the effort to actually do this.
  params->groups = groups;
  params->deterministic = deterministic;
  params->allow_tf32 = allow_tf32;
}
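
// Keeping ConvolutionParams a POD and zeroing it with memset above gives the
// whole struct (including padding bytes) a deterministic byte representation,
// so it can be hashed and compared byte-wise, e.g. as a key in an algorithm
// cache. A minimal sketch of that idea (hypothetical helper, not the actual
// cache used by the v7/v8 backends):
//
//   struct ParamsBytesHash {
//     size_t operator()(const ConvolutionParams& p) const {
//       const auto* bytes = reinterpret_cast<const uint8_t*>(&p);
//       size_t h = 14695981039346656037ULL; // FNV-1a offset basis
//       for (size_t i = 0; i < sizeof(ConvolutionParams); ++i) {
//         h = (h ^ bytes[i]) * 1099511628211ULL; // FNV-1a prime
//       }
//       return h;
//     }
//   };
//   // e.g. std::unordered_map<ConvolutionParams, Algo, ParamsBytesHash, ...>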

std::string repro_from_args(const ConvolutionParams& params) {
  auto pybool = [](bool b) -> const char* { return b ? "True" : "False"; };
  std::string partial_dtype;
  switch (params.dataType) {
    case CUDNN_DATA_FLOAT:
      partial_dtype = "float";
      break;
    case CUDNN_DATA_DOUBLE:
      partial_dtype = "double";
      break;
    case CUDNN_DATA_HALF:
      partial_dtype = "half";
      break;
    default:
      partial_dtype = "unsupported";
  }
  const std::string full_dtype = "torch." + partial_dtype;
  const int out_channels = params.weight_size[0];
  const int in_channels = params.weight_size[1] * params.groups;
  const size_t dim = params.input_dim;
  const std::string channels_last_xd =
      dim == 4 ? "channels_last" : "channels_last_3d";
  const std::string to_channels_last =
      ((params.memory_format == at::MemoryFormat::ChannelsLast) ||
       (params.memory_format == at::MemoryFormat::ChannelsLast3d))
      ? ".to(memory_format=torch." + channels_last_xd + ")"
      : "";

  std::ostringstream ss;
  ss << "You can try to repro this exception using the following code snippet. ";
  ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
  ss << "import torch\n";
  ss << "torch.backends.cuda.matmul.allow_tf32 = "
     << pybool(at::globalContext().allowTF32CuBLAS()) << "\n";
  ss << "torch.backends.cudnn.benchmark = "
     << pybool(at::globalContext().benchmarkCuDNN()) << "\n";
  ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
     << "\n";
  ss << "torch.backends.cudnn.allow_tf32 = " << pybool(params.allow_tf32)
     << "\n";
  ss << "data = torch.randn(" << ArrayRef<int>(params.input_size, dim)
     << ", dtype=" << full_dtype << ", ";
  ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n";
  ss << "net = torch.nn.Conv" << dim - 2 << "d(" << in_channels << ", "
     << out_channels << ", ";
  ss << "kernel_size=" << ArrayRef<int>(&params.weight_size[2], dim - 2)
     << ", ";
  ss << "padding=" << ArrayRef<int>(params.padding, dim - 2) << ", ";
  ss << "stride=" << ArrayRef<int>(params.stride, dim - 2) << ", ";
  ss << "dilation=" << ArrayRef<int>(params.dilation, dim - 2) << ", ";
  ss << "groups=" << params.groups << ")\n";
  ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last
     << "\n";
  ss << "out = net(data)\n";
  ss << "out.backward(torch.randn_like(out))\n";
  ss << "torch.cuda.synchronize()\n\n";

  return ss.str();
}
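
// For reference, the snippet emitted above looks roughly like the following
// (shape, dtype, and flags here are hypothetical; the real values come from
// ConvolutionParams and the global context):
//
//   import torch
//   torch.backends.cuda.matmul.allow_tf32 = True
//   torch.backends.cudnn.benchmark = False
//   torch.backends.cudnn.deterministic = False
//   torch.backends.cudnn.allow_tf32 = True
//   data = torch.randn([1, 3, 224, 224], dtype=torch.float, device='cuda', requires_grad=True)
//   net = torch.nn.Conv2d(3, 64, kernel_size=[7, 7], padding=[3, 3], stride=[2, 2], dilation=[1, 1], groups=1)
//   net = net.cuda().float()
//   out = net(data)
//   out.backward(torch.randn_like(out))
//   torch.cuda.synchronize()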

// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------

void cudnn_convolution_forward_out(
    TensorArg& output,
    CheckedFrom c,
    const TensorArg& input,
    const TensorArg& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto memory_format = output->suggest_memory_format();
  convolution_shape_check(
      c, input, weight, output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous(memory_format);
  Tensor input_contig = input->contiguous(memory_format);

  raw_cudnn_convolution_forward_out(
      *output,
      input_contig,
      weight_contig,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

Tensor cudnn_convolution(
    const Tensor& input_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  CheckedFrom c = "cudnn_convolution";
  auto memory_format = cudnn_conv_suggest_memory_format(input_t, weight_t);
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
      input->options().memory_format(memory_format));
  if (output_t.numel() == 0) {
    return output_t;
  }
  // Avoid ambiguity of "output" when this is being used as backwards
  TensorArg output{output_t, "result", 0};
  cudnn_convolution_forward_out(
      output,
      c,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return *output;
}

at::Tensor& cudnn_convolution_out(
    const Tensor& input_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    Tensor& output_t) {
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  CheckedFrom c = "cudnn_convolution";
  if (output_t.numel() == 0) {
    return output_t;
  }
  TensorArg output{output_t, "result", 0};
  cudnn_convolution_forward_out(
      output,
      c,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return output_t;
}

// NB: output_padding not needed here, as there is no ambiguity to resolve
Tensor cudnn_convolution_transpose_backward_input(
    const Tensor& grad_output_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg grad_output{grad_output_t, "grad_output", 1},
      weight{weight_t, "weight", 2};
  auto memory_format =
      cudnn_conv_suggest_memory_format(grad_output_t, weight_t);
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          grad_output_t.sizes(), weight_t.sizes(), padding, stride, dilation),
      grad_output_t.options().memory_format(memory_format));

  if (output_t.numel() == 0) {
    return output_t;
  }
  TensorArg output{output_t, "result", 0};
  cudnn_convolution_forward_out(
      output,
      "cudnn_convolution_transpose_backward_input",
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return *output;
}

// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
//
// ---------------------------------------------------------------------

// NOTE [ Backward vs transpose convolutions ]
//
// Backward and transpose are algorithmically equivalent, but they
// compute their geometry differently. In a backward pass, you know
// what the original size of the input tensor was, so you can cache
// that geometry and fill it in directly. In a transposed convolution,
// it is more conventional not to explicitly specify the output
// (previously input) size and to compute it instead. This, however,
// leaves a degree of freedom; that degree of freedom is resolved using
// the output_padding parameter. Both of these interfaces are
// equivalent, but they are differently convenient depending on the
// use case.
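//
// Concretely, a forward convolution maps input size i to
//   o = floor((i + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1,
// and because of the floor, several input sizes can map to the same o.
// Inverting this gives
//   i = (o - 1) * stride - 2 * padding + dilation * (kernel - 1)
//       + output_padding + 1,
// with output_padding selecting among the otherwise indistinguishable
// input sizes. (This is a sketch of the relationship; see
// conv_input_size in ConvUtils.h for the authoritative computation.)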

Tensor cudnn_convolution_backward_input(
    CheckedFrom c,
    IntArrayRef input_size,
    const TensorArg& grad_output,
    const TensorArg& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});

  auto memory_format = cudnn_conv_suggest_memory_format(*grad_output, *weight);
  Tensor grad_input_t = at::detail::empty_cuda(
      input_size, grad_output->options().memory_format(memory_format));

  // Avoid "grad_input" when this is being used as transposed convolution
  TensorArg grad_input{grad_input_t, "result", 0};
  convolution_shape_check(
      c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous(memory_format);
  Tensor grad_output_contig = grad_output->contiguous(memory_format);

  raw_cudnn_convolution_backward_input_out(
      *grad_input,
      grad_output_contig,
      weight_contig,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);

  return *grad_input;
}

Tensor cudnn_convolution_transpose_forward(
    CheckedFrom c,
    const TensorArg& grad_output,
    const TensorArg& weight,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto input_size = conv_input_size(
      grad_output->sizes(),
      weight->sizes(),
      padding,
      output_padding,
      stride,
      dilation,
      groups);
  return cudnn_convolution_backward_input(
      c,
      input_size,
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size,
    const Tensor& grad_output_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg grad_output{grad_output_t, "grad_output", 1},
      weight{weight_t, "weight", 2};
  return cudnn_convolution_backward_input(
      "cudnn_convolution_backward_input",
      input_size,
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

Tensor cudnn_convolution_transpose(
    const Tensor& input_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  CheckedFrom c = "cudnn_convolution_transpose";
  auto output_t = cudnn_convolution_transpose_forward(
      c,
      input,
      weight,
      padding,
      output_padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return output_t;
}

// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------

Tensor cudnn_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto layout = cudnn_conv_suggest_memory_format(input_t, grad_output_t);

  Tensor grad_output_contig_t = grad_output_t.contiguous(layout);
  TensorArg grad_output_contig{grad_output_contig_t, "grad_output", 1};

  Tensor input_contig_t = input_t.contiguous(layout);
  TensorArg input{input_contig_t, "input", 2};

  checkAllSameType(c, {grad_output_contig, input});
  checkAllSameGPU(c, {grad_output_contig, input});

  auto grad_weight_t =
      at::empty(weight_size, grad_output_contig->options(), layout);

  // For uniformity with everything else, although it seems grad_weight
  // would be unambiguous too.
  TensorArg grad_weight{grad_weight_t, "result", 0};
  convolution_shape_check(
      c,
      input,
      grad_weight,
      grad_output_contig,
      padding,
      stride,
      dilation,
      groups);

  raw_cudnn_convolution_backward_weight_out(
      *grad_weight,
      *grad_output_contig,
      *input,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);

  return grad_weight_t;
}

Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  return cudnn_convolution_backward_weight(
      "cudnn_convolution_backward_weight",
      weight_size,
      grad_output_t,
      input_t,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

std::tuple<at::Tensor, at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input,
    const at::Tensor& grad_output_t,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    std::array<bool, 2> output_mask) {
  Tensor grad_output = grad_output_t.to(input.suggest_memory_format());

  Tensor grad_input, grad_weight;
  if (input.numel() == 0) {
    if (output_mask[0]) {
      grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
    if (output_mask[1]) {
      grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
  } else {
    if (output_mask[0]) {
      grad_input = cudnn_convolution_backward_input(
          input.sizes(),
          grad_output,
          weight,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
    }
    if (output_mask[1]) {
      grad_weight = cudnn_convolution_backward_weight(
          weight.sizes(),
          grad_output,
          input,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
    }
  }

  return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}

Tensor cudnn_convolution_transpose_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  return cudnn_convolution_backward_weight(
      "cudnn_convolution_backward_weight",
      weight_size,
      input_t,
      grad_output_t,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

std::tuple<at::Tensor, at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input,
    const at::Tensor& grad_output_t,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    std::array<bool, 2> output_mask) {
  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

  Tensor grad_input, grad_weight;
  if (output_mask[0]) {
    grad_input = cudnn_convolution_transpose_backward_input(
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
  if (output_mask[1]) {
    grad_weight = cudnn_convolution_transpose_backward_weight(
        weight.sizes(),
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }

  return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}

Tensor cudnn_convolution_relu(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_t,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  auto memory_format = cudnn_conv_suggest_memory_format(input_t, weight_t);
  const Tensor input = input_t.contiguous(memory_format);
  const Tensor weight = weight_t.contiguous(memory_format);

  // FuseFrozenConvAddRelu performs some tensor shape checking
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          input.sizes(), weight.sizes(), padding, stride, dilation),
      input.options().memory_format(memory_format));
  if (output_t.numel() == 0) {
    return output_t;
  }

  auto& ctx = at::globalContext();
  bool benchmark = ctx.benchmarkCuDNN();
  bool allow_tf32 = ctx.allowTF32CuDNN();
  auto _bias = bias_t.has_value()
      ? bias_t.value()
      : at::zeros(
            {output_t.size(1)},
            optTypeMetaToScalarType(output_t.options().dtype_opt()),
            output_t.options().layout_opt(),
            output_t.options().device_opt(),
            output_t.options().pinned_memory_opt());

  raw_cudnn_convolution_add_relu_out(
      output_t,
      input,
      weight,
      output_t, // use output_t as z to satisfy CUDNN API
      0, // alpha
      _bias,
      stride,
      padding,
      dilation,
      groups,
      benchmark, // benchmark
      false, // deterministic
      allow_tf32 // allow_tf32
  );

  return output_t;
}

Tensor cudnn_convolution_add_relu(
    const Tensor& input_t,
    const Tensor& weight_t,
    const Tensor& z_t,
    const std::optional<Scalar>& alpha,
    const std::optional<Tensor>& bias_t,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  auto memory_format = cudnn_conv_suggest_memory_format(input_t, weight_t);
  const Tensor input = input_t.contiguous(memory_format);
  const Tensor weight = weight_t.contiguous(memory_format);
  Tensor z = z_t;
  if (z.suggest_memory_format() != memory_format) {
    z = z.to(memory_format);
  }
  z = z.contiguous(memory_format);

  // FuseFrozenConvAddRelu performs some tensor shape checking
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          input.sizes(), weight.sizes(), padding, stride, dilation),
      input.options().memory_format(memory_format));
  if (output_t.numel() == 0) {
    return output_t;
  }

  auto& ctx = at::globalContext();
  bool allow_tf32 = ctx.allowTF32CuDNN();
  bool benchmark = ctx.benchmarkCuDNN();
  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
  auto _bias = bias_t.has_value()
      ? bias_t.value()
      : at::zeros(
            {output_t.size(1)},
            optTypeMetaToScalarType(output_t.options().dtype_opt()),
            output_t.options().layout_opt(),
            output_t.options().device_opt(),
            output_t.options().pinned_memory_opt());

  raw_cudnn_convolution_add_relu_out(
      output_t,
      input,
      weight,
      z,
      _alpha,
      _bias,
      stride,
      padding,
      dilation,
      groups,
      benchmark,
      false, // deterministic
      allow_tf32 // allow_tf32
  );

  return output_t;
}

REGISTER_CUDA_DISPATCH(
    cudnn_convolution_backward_stub,
    &cudnn_convolution_backward);
REGISTER_CUDA_DISPATCH(
    cudnn_convolution_transpose_backward_stub,
    &cudnn_convolution_transpose_backward);

} // namespace native
} // namespace at

#endif // AT_CUDNN_ENABLED