#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/TensorGeometry.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/EmptyTensor.h>
#include <ATen/native/ConvUtils.h>

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/ConvShared.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cudnn_convolution_add_relu_native.h>
#include <ATen/ops/cudnn_convolution_native.h>
#include <ATen/ops/cudnn_convolution_relu_native.h>
#include <ATen/ops/cudnn_convolution_transpose_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

// NOTE [cuDNN API version]
//
// ConvPlaceholders.cpp contains placeholder implementations of cuDNN
// convolution for when cuDNN is not enabled. Those operators only raise
// errors and do no real computation; they are implemented using the
// current operators.
//
// cuDNN v7 and v8 have different APIs. ConvShared.{cpp, h} contains
// code shared by v7 and v8. Conv_v7.cpp contains the implementation of
// convolution using the cuDNN v7 API; Conv_v8.cpp contains the
// implementation using the v8 API.
//
// NOTE [ Convolution design ]
//
// cuDNN convolutions do not handle bias. Bias is handled outside.
//
// The general strategy (see the call-flow sketch below):
//
//    - cudnn_convolution (Tensor)
//      Entry points for clients
//
//    - cudnn_convolution_forward (TensorArg)
//      Entry point, which may be reused between regular
//      convolution and transposed convolution.
//
//    - raw_cudnn_convolution_forward_out (Tensor)
//      Function that has different implementations in Conv_v7.cpp
//      and Conv_v8.cpp
//
// The raw API directly invokes cuDNN and is implemented differently
// for cuDNN v7 and cuDNN v8.
//
// There are a few reasons the raw API should never be directly exposed
// via ATen:
//
//    - It takes output as a parameter (this should be computed!)
//    - It doesn't do input checking
//    - It doesn't resize output (it is assumed to be correctly sized)
//
// Where does argument checking happen?  Here's the division of
// responsibility:
//  - Things that happen in at::Tensor
//    - TensorArg allocation
//  - Things that happen in TensorArg
//    - Check arguments (type, GPU, shape)

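// As a rough sketch of the forward path (backward-input and backward-weight
// follow the same layering):
//
//   at::native::cudnn_convolution(input, weight, ...)   // shape inference via conv_output_size, output allocation
//     -> cudnn_convolution_forward_out(output, ...)     // TensorArg checks + contiguity
//       -> raw_cudnn_convolution_forward_out(...)       // cuDNN v7 or v8 call
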
namespace at {
namespace native {

// ---------------------------------------------------------------------
//
// ConvolutionParams
//
// ---------------------------------------------------------------------

std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params) {
  out << "ConvolutionParams \n"
      << "    memory_format = " << params.memory_format << "\n"
      << "    data_type = " << cudnnTypeToString(params.dataType) << "\n"
      << "    padding = " << ArrayRef<int>{params.padding} << "\n"
      << "    stride = " << ArrayRef<int>{params.stride} << "\n"
      << "    dilation = " << ArrayRef<int>{params.dilation} << "\n"
      << "    groups = " << params.groups << "\n"
      << "    deterministic = " << (params.deterministic ? "true" : "false")
      << "\n"
      << "    allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n";

  return out;
}

// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need.  (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing)
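// NB: ConvolutionParams stays trivially copyable presumably so it can be
// hashed / compared byte-wise when caching cuDNN algorithm choices; the
// memset below zeroes any padding bytes so logically equal params produce
// identical byte patterns.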
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool deterministic,
    bool allow_tf32,
    at::MemoryFormat memory_format) {
  cudnnDataType_t dataType = getCudnnDataType(input);
  memset(params, 0, sizeof(ConvolutionParams));
  params->device_id = at::cuda::current_device();
  params->dataType = dataType;
  // ASSERT(weight.dim() == input.dim())
  params->input_dim = input.dim();
  params->memory_format = memory_format;
  for (int i = 0; i != params->input_dim; ++i) {
    params->input_size[i] = (int)input.sizes()[i];
    params->weight_size[i] = (int)weight.sizes()[i];
  }
  // ASSERT(padding.size() == stride.size())
  // ASSERT(padding.size() == dilation.size())
  for (size_t i = 0; i != padding.size(); ++i) {
    params->padding[i] = padding[i];
    params->stride[i] = stride[i];
    params->dilation[i] = dilation[i];
  }
  // In principle, we shouldn't parametrize by groups for legacy
  // CuDNN, but it doesn't seem worth the effort to actually do this.
  params->groups = groups;
  params->deterministic = deterministic;
  params->allow_tf32 = allow_tf32;
}

std::string repro_from_args(const ConvolutionParams& params) {
  auto pybool = [](bool b) -> const char* { return b ? "True" : "False"; };
  std::string partial_dtype;
  switch (params.dataType) {
    case CUDNN_DATA_FLOAT:
      partial_dtype = "float";
      break;
    case CUDNN_DATA_DOUBLE:
      partial_dtype = "double";
      break;
    case CUDNN_DATA_HALF:
      partial_dtype = "half";
      break;
    default:
      partial_dtype = "unsupported";
  }
  const std::string full_dtype = "torch." + partial_dtype;
  const int out_channels = params.weight_size[0];
  const int in_channels = params.weight_size[1] * params.groups;
  const size_t dim = params.input_dim;
  const std::string channels_last_xd =
      dim == 4 ? "channels_last" : "channels_last_3d";
  const std::string to_channels_last =
      ((params.memory_format == at::MemoryFormat::ChannelsLast) ||
       (params.memory_format == at::MemoryFormat::ChannelsLast3d))
      ? ".to(memory_format=torch." + channels_last_xd + ")"
      : "";

  std::ostringstream ss;
  ss << "You can try to repro this exception using the following code snippet. ";
  ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
  ss << "import torch\n";
  ss << "torch.backends.cuda.matmul.allow_tf32 = "
     << pybool(at::globalContext().allowTF32CuBLAS()) << "\n";
  ss << "torch.backends.cudnn.benchmark = "
     << pybool(at::globalContext().benchmarkCuDNN()) << "\n";
  ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
     << "\n";
  ss << "torch.backends.cudnn.allow_tf32 = " << pybool(params.allow_tf32)
     << "\n";
  ss << "data = torch.randn(" << ArrayRef<int>(params.input_size, dim)
     << ", dtype=" << full_dtype << ", ";
  ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n";
  ss << "net = torch.nn.Conv" << dim - 2 << "d(" << in_channels << ", "
     << out_channels << ", ";
  ss << "kernel_size=" << ArrayRef<int>(&params.weight_size[2], dim - 2)
     << ", ";
  ss << "padding=" << ArrayRef<int>(params.padding, dim - 2) << ", ";
  ss << "stride=" << ArrayRef<int>(params.stride, dim - 2) << ", ";
  ss << "dilation=" << ArrayRef<int>(params.dilation, dim - 2) << ", ";
  ss << "groups=" << params.groups << ")\n";
  ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last
     << "\n";
  ss << "out = net(data)\n";
  ss << "out.backward(torch.randn_like(out))\n";
  ss << "torch.cuda.synchronize()\n\n";

  return ss.str();
}
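
// For illustration only, the generated snippet looks roughly like the
// following (sizes and flags depend on the failing call):
//
//   import torch
//   torch.backends.cuda.matmul.allow_tf32 = True
//   torch.backends.cudnn.benchmark = True
//   torch.backends.cudnn.deterministic = False
//   torch.backends.cudnn.allow_tf32 = True
//   data = torch.randn([1, 3, 224, 224], dtype=torch.float, device='cuda', requires_grad=True)
//   net = torch.nn.Conv2d(3, 64, kernel_size=[7, 7], padding=[3, 3], stride=[2, 2], dilation=[1, 1], groups=1)
//   net = net.cuda().float()
//   out = net(data)
//   out.backward(torch.randn_like(out))
//   torch.cuda.synchronize()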

// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------

void cudnn_convolution_forward_out(
    TensorArg& output,
    CheckedFrom c,
    const TensorArg& input,
    const TensorArg& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto memory_format = output->suggest_memory_format();
  convolution_shape_check(
      c, input, weight, output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous(memory_format);
  Tensor input_contig = input->contiguous(memory_format);

  raw_cudnn_convolution_forward_out(
      *output,
      input_contig,
      weight_contig,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

Tensor cudnn_convolution(
    const Tensor& input_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  CheckedFrom c = "cudnn_convolution";
  auto memory_format = cudnn_conv_suggest_memory_format(input_t, weight_t);
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          input_t.sizes(), weight_t.sizes(), padding, stride, dilation),
      input->options().memory_format(memory_format));
  if (output_t.numel() == 0) {
    return output_t;
  }
  // Avoid ambiguity of "output" when this is being used as a backward pass
  TensorArg output{output_t, "result", 0};
  cudnn_convolution_forward_out(
      output,
      c,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return *output;
}
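
// Out variant: the caller supplies a correctly-sized output tensor, so no
// shape inference or allocation is performed here; an empty output is
// returned unchanged.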
at::Tensor& cudnn_convolution_out(
    const Tensor& input_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    Tensor& output_t) {
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  CheckedFrom c = "cudnn_convolution";
  if (output_t.numel() == 0) {
    return output_t;
  }
  TensorArg output{output_t, "result", 0};
  cudnn_convolution_forward_out(
      output,
      c,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return output_t;
}

// NB: output_padding not needed here, as there is no ambiguity to resolve
Tensor cudnn_convolution_transpose_backward_input(
    const Tensor& grad_output_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg grad_output{grad_output_t, "grad_output", 1},
      weight{weight_t, "weight", 2};
  auto memory_format =
      cudnn_conv_suggest_memory_format(grad_output_t, weight_t);
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          grad_output_t.sizes(), weight_t.sizes(), padding, stride, dilation),
      grad_output_t.options().memory_format(memory_format));

  if (output_t.numel() == 0) {
    return output_t;
  }
  TensorArg output{output_t, "result", 0};
  cudnn_convolution_forward_out(
      output,
      "cudnn_convolution_transpose_backward_input",
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return *output;
}

// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
//
// ---------------------------------------------------------------------

// NOTE [ Backward vs transpose convolutions ]
//
// Backward and transpose are algorithmically equivalent, but they
// compute their geometry differently.  In a backward pass, you know what
// the original size of the input tensor was, so you can cache that
// geometry and fill it in directly.  In a transposed convolution, it is
// more conventional not to specify the output (previously input) size
// explicitly, but to compute it.  This, however, leaves a degree of
// freedom, which is resolved using the output_padding parameter.  Both
// interfaces are equivalent, but they are differently convenient
// depending on the use case.  (See the concrete example below.)

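// To make the ambiguity concrete: with kernel_size=3, stride=2, padding=1,
// forward inputs of size 7 and size 8 both produce an output of size 4, so a
// transposed convolution starting from size 4 could legitimately produce
// either 7 or 8; output_padding selects between them.
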
Tensor cudnn_convolution_backward_input(
    CheckedFrom c,
    IntArrayRef input_size,
    const TensorArg& grad_output,
    const TensorArg& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});

  auto memory_format = cudnn_conv_suggest_memory_format(*grad_output, *weight);
  Tensor grad_input_t = at::detail::empty_cuda(
      input_size, grad_output->options().memory_format(memory_format));

  // Avoid "grad_input" when this is being used as a transposed convolution
  TensorArg grad_input{grad_input_t, "result", 0};
  convolution_shape_check(
      c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous(memory_format);
  Tensor grad_output_contig = grad_output->contiguous(memory_format);

  raw_cudnn_convolution_backward_input_out(
      *grad_input,
      grad_output_contig,
      weight_contig,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);

  return *grad_input;
}

Tensor cudnn_convolution_transpose_forward(
    CheckedFrom c,
    const TensorArg& grad_output,
    const TensorArg& weight,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto input_size = conv_input_size(
      grad_output->sizes(),
      weight->sizes(),
      padding,
      output_padding,
      stride,
      dilation,
      groups);
  return cudnn_convolution_backward_input(
      c,
      input_size,
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size,
    const Tensor& grad_output_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg grad_output{grad_output_t, "grad_output", 1},
      weight{weight_t, "weight", 2};
  return cudnn_convolution_backward_input(
      "cudnn_convolution_backward_input",
      input_size,
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

Tensor cudnn_convolution_transpose(
    const Tensor& input_t,
    const Tensor& weight_t,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  CheckedFrom c = "cudnn_convolution_transpose";
  auto output_t = cudnn_convolution_transpose_forward(
      c,
      input,
      weight,
      padding,
      output_padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  return output_t;
}

// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------

Tensor cudnn_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto layout = cudnn_conv_suggest_memory_format(input_t, grad_output_t);

  Tensor grad_output_contig_t = grad_output_t.contiguous(layout);
  TensorArg grad_output_contig{grad_output_contig_t, "grad_output", 1};

  Tensor input_contig_t = input_t.contiguous(layout);
  TensorArg input{input_contig_t, "input", 2};

  checkAllSameType(c, {grad_output_contig, input});
  checkAllSameGPU(c, {grad_output_contig, input});

  auto grad_weight_t =
      at::empty(weight_size, grad_output_contig->options(), layout);

  // For uniformity with everything else, although it seems grad_weight
  // would be unambiguous too.
  TensorArg grad_weight{grad_weight_t, "result", 0};
  convolution_shape_check(
      c,
      input,
      grad_weight,
      grad_output_contig,
      padding,
      stride,
      dilation,
      groups);

  raw_cudnn_convolution_backward_weight_out(
      *grad_weight,
      *grad_output_contig,
      *input,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);

  return grad_weight_t;
}

Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  return cudnn_convolution_backward_weight(
      "cudnn_convolution_backward_weight",
      weight_size,
      grad_output_t,
      input_t,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

std::tuple<at::Tensor, at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input,
    const at::Tensor& grad_output_t,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    std::array<bool, 2> output_mask) {
  Tensor grad_output = grad_output_t.to(input.suggest_memory_format());

  Tensor grad_input, grad_weight;
  if (input.numel() == 0) {
    if (output_mask[0]) {
      grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
    if (output_mask[1]) {
      grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
  } else {
    if (output_mask[0]) {
      grad_input = cudnn_convolution_backward_input(
          input.sizes(),
          grad_output,
          weight,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
    }
    if (output_mask[1]) {
      grad_weight = cudnn_convolution_backward_weight(
          weight.sizes(),
          grad_output,
          input,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
    }
  }

  return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}

Tensor cudnn_convolution_transpose_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
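  // For a transposed convolution, the roles of "input" and "grad_output" are
  // exchanged relative to a regular convolution, so the weight gradient can
  // be computed by the same routine with the two tensor arguments swapped.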
  return cudnn_convolution_backward_weight(
      "cudnn_convolution_backward_weight",
      weight_size,
      input_t,
      grad_output_t,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
}

std::tuple<at::Tensor, at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input,
    const at::Tensor& grad_output_t,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    std::array<bool, 2> output_mask) {
  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

  Tensor grad_input, grad_weight;
  if (output_mask[0]) {
    grad_input = cudnn_convolution_transpose_backward_input(
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
  if (output_mask[1]) {
    grad_weight = cudnn_convolution_transpose_backward_weight(
        weight.sizes(),
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }

  return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}

Tensor cudnn_convolution_relu(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_t,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  auto memory_format = cudnn_conv_suggest_memory_format(input_t, weight_t);
  const Tensor input = input_t.contiguous(memory_format);
  const Tensor weight = weight_t.contiguous(memory_format);

  // FuseFrozenConvAddRelu performs some tensor shape checking
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          input.sizes(), weight.sizes(), padding, stride, dilation),
      input.options().memory_format(memory_format));
  if (output_t.numel() == 0) {
    return output_t;
  }

  auto& ctx = at::globalContext();
  bool benchmark = ctx.benchmarkCuDNN();
  bool allow_tf32 = ctx.allowTF32CuDNN();
  auto _bias = bias_t.has_value()
      ? bias_t.value()
      : at::zeros(
            {output_t.size(1)},
            optTypeMetaToScalarType(output_t.options().dtype_opt()),
            output_t.options().layout_opt(),
            output_t.options().device_opt(),
            output_t.options().pinned_memory_opt());

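  // The fused kernel computes (roughly) out = relu(conv(input, weight) + alpha * z + bias).
  // Passing alpha = 0 below zeroes the z contribution, so output_t can be
  // reused as z purely to satisfy the API; its initial contents do not matter.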
  raw_cudnn_convolution_add_relu_out(
      output_t,
      input,
      weight,
      output_t, // use output_t as z to satisfy CUDNN API
      0, // alpha
      _bias,
      stride,
      padding,
      dilation,
      groups,
      benchmark, // benchmark
      false, // deterministic
      allow_tf32 // allow_tf32
  );

  return output_t;
}

Tensor cudnn_convolution_add_relu(
    const Tensor& input_t,
    const Tensor& weight_t,
    const Tensor& z_t,
    const std::optional<Scalar>& alpha,
    const std::optional<Tensor>& bias_t,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  auto memory_format = cudnn_conv_suggest_memory_format(input_t, weight_t);
  const Tensor input = input_t.contiguous(memory_format);
  const Tensor weight = weight_t.contiguous(memory_format);
  Tensor z = z_t;
  if (z.suggest_memory_format() != memory_format) {
    z = z.to(memory_format);
  }
  z = z.contiguous(memory_format);

  // FuseFrozenConvAddRelu performs some tensor shape checking
  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(
          input.sizes(), weight.sizes(), padding, stride, dilation),
      input.options().memory_format(memory_format));
  if (output_t.numel() == 0) {
    return output_t;
  }

  auto& ctx = at::globalContext();
  bool allow_tf32 = ctx.allowTF32CuDNN();
  bool benchmark = ctx.benchmarkCuDNN();
  auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
  auto _bias = bias_t.has_value()
      ? bias_t.value()
      : at::zeros(
            {output_t.size(1)},
            optTypeMetaToScalarType(output_t.options().dtype_opt()),
            output_t.options().layout_opt(),
            output_t.options().device_opt(),
            output_t.options().pinned_memory_opt());

  raw_cudnn_convolution_add_relu_out(
      output_t,
      input,
      weight,
      z,
      _alpha,
      _bias,
      stride,
      padding,
      dilation,
      groups,
      benchmark,
      false, // deterministic
      allow_tf32 // allow_tf32
  );

  return output_t;
}

REGISTER_CUDA_DISPATCH(
    cudnn_convolution_backward_stub,
    &cudnn_convolution_backward);
REGISTER_CUDA_DISPATCH(
    cudnn_convolution_transpose_backward_stub,
    &cudnn_convolution_transpose_backward);

} // namespace native
} // namespace at

#endif // AT_CUDNN_ENABLED