#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

#include <ATen/Config.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <limits>
#include <vector>

#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/utils/ParamsHash.h>

#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

#include <stdint.h>
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <unordered_map>

// Note [behavior of cudnnFind and cudnnGet]
// You'll notice that by default, in the ConvolutionDescriptor, we do the
// following:
//
//     AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
//     if (dataType == CUDNN_DATA_HALF)
//       AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
//
//     Update: AT_CUDNN_CHECK has been replaced with AT_CUDNN_CHECK_WITH_SHAPES,
//        which automatically prints tensor shapes and convolution parameters
//        if a cuDNN exception is thrown.
//
// When cudnnSetConvolutionMathType is called before cudnnGet/cudnnFind, it
// informs cudnnGet/cudnnFind to iterate over/take into account both tensor
// core and non-tensor-core algos. If you don't call cudnnSetConvolutionMathType
// before calling cudnnGet/cudnnFind, cudnnGet/cudnnFind may not pick tensor
// core algos.
//
// After its run, cudnnGet/cudnnFind comes up with the best pair of
// algo+mathType given all the initial knowledge it has. It then becomes the
// user's responsibility to update the mathType of the convolution descriptor
// and make the subsequent cuDNN calls with the best algo and the updated
// descriptor. If we don't update the descriptor but just run with the best
// algo, cuDNN will run a slower kernel under the hood, since it sees the
// fastest algorithm combination with a sub-optimal mathType.

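// A minimal sketch of the resulting call pattern (illustrative only; the
// variable names below are not from this file, and the real flow lives in
// AlgoIterator::try_all further down):
//
//     cudnnConvolutionFwdAlgoPerf_t best = /* picked by cudnnGet/cudnnFind */;
//     AT_CUDNN_CHECK(cudnnSetConvolutionMathType(cdesc.mut_desc(), best.mathType));
//     AT_CUDNN_CHECK(cudnnConvolutionForward(handle, ..., best.algo, ...));
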
constexpr size_t operator"" _TiB(unsigned long long n) {
  return size_t(n) * 1024 * 1024 * 1024 * 1024;
}
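
// Illustrative note: 1_TiB == 1099511627776 bytes (1024^4). The
// TORCH_CHECK_WITH(OutOfMemoryError, size < 1_TiB, ...) calls below use this
// literal to reject absurdly large workspace sizes reported by cuDNN.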

namespace at {
namespace native {

// Convenience struct for passing around descriptors and data
// pointers
struct ConvolutionArgs {
  cudnnHandle_t handle;
  ConvolutionParams params;
  TensorDescriptor idesc, odesc;
  FilterDescriptor wdesc;
  const Tensor &input, output, weight;
  ConvolutionDescriptor cdesc;

  ConvolutionArgs(
      const Tensor& input,
      const Tensor& output,
      const Tensor& weight)
      : input(input), output(output), weight(weight) {}
};

std::ostream& operator<<(std::ostream& out, const ConvolutionArgs& args) {
  out << repro_from_args(args.params) // already has a trailing newline
      << args.params // already has a trailing newline
      << "input: " << args.idesc // already has a trailing newline
      << "output: " << args.odesc // already has a trailing newline
      << "weight: " << args.wdesc // already has a trailing newline
      << "Pointer addresses: "
      << "\n"
      << "    input: " << args.input.const_data_ptr() << "\n"
      << "    output: " << args.output.const_data_ptr() << "\n"
      << "    weight: " << args.weight.const_data_ptr() << "\n";

  return out;
}

// ---------------------------------------------------------------------
//
// Benchmarking
//
// ---------------------------------------------------------------------

// TODO: Use something less heavy duty than a big honking mutex
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<
      ConvolutionParams,
      T,
      ParamsHash<ConvolutionParams>,
      ParamsEqual<ConvolutionParams>>
      map;

  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
    *results = it->second;
    return true;
  }

  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> guard(mutex);
    map[params] = results;
  }
};
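
// A minimal usage sketch of the cache above (hypothetical snippet; the real
// lookups happen inside AlgoIterator::try_all below):
//
//     cudnnConvolutionFwdAlgoPerf_t perf;
//     if (!fwd_algos.find(args.params, &perf)) {
//       // ... run cudnnGet/cudnnFind to obtain `perf` ...
//       fwd_algos.insert(args.params, perf);
//     }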

BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;

// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    // Sometimes cuDNN returns a workspace size > 2^63; this can make the
    // allocation of the workspace fail with a 64-bit indexing error instead
    // of an OOM error. In such a case, we manually fail with OOM.
    TORCH_CHECK_WITH(
        OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
    data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      c10::cuda::CUDACachingAllocator::raw_delete(data);
    }
  }

  size_t size;
  void* data;
};
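
// Workspace is an RAII helper: the constructor grabs raw device memory from
// the caching allocator and the destructor returns it. A hedged usage sketch
// (mirroring the benchmarking paths below):
//
//     size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
//     Workspace ws(max_ws_size);  // may throw OutOfMemoryError
//     // ... pass ws.data / ws.size to cudnnFind*AlgorithmEx ...
//     // the memory is released when `ws` goes out of scope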

template <typename perf_t>
struct algorithm_search {};

cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionFwdAlgo_t algo,
    size_t* sz) {
  return cudnnGetConvolutionForwardWorkspaceSize(
      args.handle,
      args.idesc.desc(),
      args.wdesc.desc(),
      args.cdesc.desc(),
      args.odesc.desc(),
      algo,
      sz);
}
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdDataAlgo_t algo,
    size_t* sz) {
  return cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle,
      args.wdesc.desc(),
      args.odesc.desc(),
      args.cdesc.desc(),
      args.idesc.desc(),
      algo,
      sz);
}
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdFilterAlgo_t algo,
    size_t* sz) {
  return cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle,
      args.idesc.desc(),
      args.odesc.desc(),
      args.cdesc.desc(),
      args.wdesc.desc(),
      algo,
      sz);
}

template <typename algo_t>
size_t getMaxWorkspaceSize(
    const ConvolutionArgs& args,
    const algo_t* algo,
    int n_algo) {
  size_t max_ws_size = 0;
  size_t max_block_size = 0;

  const auto device = c10::cuda::current_device();
  // For the native allocator, retrieves the size of the largest unused block.
  // For cudaMallocAsync, see c10/cuda/CUDAMallocAsync.cpp:cacheInfo for
  // details.
  c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);

  for (const auto i : c10::irange(n_algo)) {
    cudnnStatus_t err;
    size_t sz;
    err = getWorkspaceSize(args, algo[i], &sz);
    // Keep the largest workspace size that was queried successfully, is
    // non-zero, and still fits in the largest unused cached block.
    if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size ||
        sz > max_block_size)
      continue;
    max_ws_size = sz;
  }
  return max_ws_size;
}

template <typename perf_t>
std::vector<perf_t> getValidAlgorithms(
    perf_t* perfResults,
    const ConvolutionArgs& args,
    int n_algo) {
  std::vector<perf_t> result;
  result.reserve(n_algo);
  for (const auto i : c10::irange(n_algo)) {
    perf_t perf = perfResults[i];

    // TODO: Shouldn't all returned results be successful?
    // Double check documentation for cudnnFindConvolutionForwardAlgorithmEx
    if (perf.status == CUDNN_STATUS_SUCCESS) {
      if (!args.params.deterministic ||
          perf.determinism == CUDNN_DETERMINISTIC) {
        result.push_back(perf);
      }
    }
  }
  TORCH_CHECK(
      result.size() > 0, "no valid convolution algorithms available in CuDNN");
  return result;
}

template <>
struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
  using perf_t = cudnnConvolutionFwdAlgoPerf_t;
  using algo_t = cudnnConvolutionFwdAlgo_t;

  static constexpr auto DEFAULT_ALGO =
      CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  static BenchmarkCache<perf_t>& cache() {
    return fwd_algos;
  }

  static std::vector<perf_t> findAlgorithms(
      const ConvolutionArgs& args,
      bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_FFT,
        CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
        CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
        CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
    };
    static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
    static_assert(
        sizeof(algos) / sizeof(algos[0]) == num_algos,
        "Missing cuDNN convolution forward algorithms");
    int perf_count;
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    if (!benchmark) {
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnGetConvolutionForwardAlgorithm_v7(
              args.handle,
              args.idesc.desc(),
              args.wdesc.desc(),
              args.cdesc.desc(),
              args.odesc.desc(),
              num_algos,
              &perf_count,
              perf_results.get()),
          args);
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnFindConvolutionForwardAlgorithmEx(
              args.handle,
              args.idesc.desc(),
              args.input.const_data_ptr(),
              args.wdesc.desc(),
              args.weight.const_data_ptr(),
              args.cdesc.desc(),
              args.odesc.desc(),
              args.output.data_ptr(),
              num_algos,
              &perf_count,
              perf_results.get(),
              ws.data,
              ws.size),
          args);

      // Free the cached blocks in our caching allocator. This is needed here
      // because the benchmarking above uses a huge amount of memory, e.g. a
      // few GBs.
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      algo_t algo,
      size_t* workspaceSize) {
    AT_CUDNN_CHECK_WITH_SHAPES(
        cudnnGetConvolutionForwardWorkspaceSize(
            args.handle,
            args.idesc.desc(),
            args.wdesc.desc(),
            args.cdesc.desc(),
            args.odesc.desc(),
            algo,
            workspaceSize),
        args);
  }
};

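// A brief note on the two paths above (general cuDNN behavior, kept here as a
// reading aid): cudnnGetConvolution*Algorithm_v7 ranks algorithms with a
// heuristic and does not execute them, so it needs no workspace, while
// cudnnFindConvolution*AlgorithmEx actually runs every candidate with the
// provided workspace and times it, which is why the benchmark path allocates
// a Workspace first and calls emptyCache() afterwards.
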
template <>
struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdDataAlgo_t;

  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  static BenchmarkCache<perf_t>& cache() {
    return bwd_data_algos;
  }

  static std::vector<perf_t> findAlgorithms(
      const ConvolutionArgs& args,
      bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
    static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
    static_assert(
        sizeof(algos) / sizeof(algos[0]) == num_algos,
        "Missing cuDNN convolution backward data algorithms.");
    int perf_count;
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    if (!benchmark) {
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnGetConvolutionBackwardDataAlgorithm_v7(
              args.handle,
              args.wdesc.desc(),
              args.odesc.desc(),
              args.cdesc.desc(),
              args.idesc.desc(),
              num_algos,
              &perf_count,
              perf_results.get()),
          args);
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnFindConvolutionBackwardDataAlgorithmEx(
              args.handle,
              args.wdesc.desc(),
              args.weight.const_data_ptr(),
              args.odesc.desc(),
              args.output.const_data_ptr(),
              args.cdesc.desc(),
              args.idesc.desc(),
              args.input.data_ptr(),
              num_algos,
              &perf_count,
              perf_results.get(),
              ws.data,
              ws.size),
          args);

      // Free the cached blocks in our caching allocator. This is needed here
      // because the benchmarking above uses a huge amount of memory, e.g. a
      // few GBs.
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      cudnnConvolutionBwdDataAlgo_t algo,
      size_t* workspaceSize) {
    AT_CUDNN_CHECK_WITH_SHAPES(
        cudnnGetConvolutionBackwardDataWorkspaceSize(
            args.handle,
            args.wdesc.desc(),
            args.odesc.desc(),
            args.cdesc.desc(),
            args.idesc.desc(),
            algo,
            workspaceSize),
        args);
  }
};

template <>
struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdFilterAlgo_t;

  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;

  static BenchmarkCache<perf_t>& cache() {
    return bwd_filter_algos;
  }

  static std::vector<perf_t> findAlgorithms(
      const ConvolutionArgs& args,
      bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
    };
    // NOTE: - 1 because ALGO_WINOGRAD is not implemented
    static constexpr int num_algos =
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
    static_assert(
        sizeof(algos) / sizeof(algos[0]) == num_algos,
        "Missing cuDNN convolution backward filter algorithms.");
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    int perf_count;
    if (!benchmark) {
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnGetConvolutionBackwardFilterAlgorithm_v7(
              args.handle,
              args.idesc.desc(),
              args.odesc.desc(),
              args.cdesc.desc(),
              args.wdesc.desc(),
              num_algos,
              &perf_count,
              perf_results.get()),
          args);
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnFindConvolutionBackwardFilterAlgorithmEx(
              args.handle,
              args.idesc.desc(),
              args.input.const_data_ptr(),
              args.odesc.desc(),
              args.output.const_data_ptr(),
              args.cdesc.desc(),
              args.wdesc.desc(),
              args.weight.data_ptr(),
              num_algos,
              &perf_count,
              perf_results.get(),
              ws.data,
              ws.size),
          args);

      // Free the cached blocks in our caching allocator. This is needed here
      // because the benchmarking above uses a huge amount of memory, e.g. a
      // few GBs.
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      algo_t algo,
      size_t* workspaceSize) {
    AT_CUDNN_CHECK_WITH_SHAPES(
        cudnnGetConvolutionBackwardFilterWorkspaceSize(
            args.handle,
            args.idesc.desc(),
            args.odesc.desc(),
            args.cdesc.desc(),
            args.wdesc.desc(),
            algo,
            workspaceSize),
        args);
  }
};

template <typename perf_t>
class AlgoIterator {
  using search = algorithm_search<perf_t>;
  const ConvolutionArgs& args;
  bool benchmark;

 public:
  AlgoIterator(const ConvolutionArgs& args, bool benchmark)
      : args(args), benchmark(benchmark) {}

  static std::vector<perf_t> onlyDefaultAlgorithm(const ConvolutionArgs& args) {
    std::vector<perf_t> perfResults(1);
    perfResults[0].algo = search::DEFAULT_ALGO;
    if (args.params.dataType == CUDNN_DATA_HALF) {
      perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
    } else {
      perfResults[0].mathType = CUDNN_DEFAULT_MATH;
      if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
        perfResults[0].mathType = CUDNN_FMA_MATH;
      }
    }
    search::getWorkspaceSize(
        args, perfResults[0].algo, &(perfResults[0].memory));
    return perfResults;
  }

  void try_all(std::function<void(const perf_t& perf)> f) {
    bool only_use_default = args.params.deterministic && !benchmark;

    auto& cache = search::cache();
    perf_t algoPerf;
    if (!only_use_default && cache.find(args.params, &algoPerf)) {
      try {
        f(algoPerf);
        return;
      } catch (c10::OutOfMemoryError& e) {
        cudaGetLastError(); // clear CUDA error
      }
    }

    auto perfResults = only_use_default
        ? onlyDefaultAlgorithm(args)
        : search::findAlgorithms(args, benchmark);
    for (auto& algoPerf : perfResults) {
      try {
        f(algoPerf);
        cache.insert(args.params, algoPerf);
        return;
      } catch (c10::OutOfMemoryError& e) {
        cudaGetLastError(); // clear CUDA error
      } catch (c10::CuDNNError& e) {
        cudaGetLastError(); // clear CUDA error
      }
    }
    TORCH_CHECK(
        false, "Unable to find a valid cuDNN algorithm to run convolution");
  }
};
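
// A hedged usage sketch of AlgoIterator (mirroring the callers below; the
// lambda body here is illustrative, not a drop-in implementation):
//
//     AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
//         .try_all([&](const cudnnConvolutionFwdAlgoPerf_t& perf) {
//           Tensor workspace = allocate_workspace(perf.memory, input);
//           // set perf.mathType on the descriptor, then run the convolution;
//           // if this throws OOM or a cuDNN error, try_all moves on to the
//           // next candidate algorithm.
//         });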

inline Tensor allocate_workspace(size_t size, const Tensor& other) {
  // Sometimes cuDNN returns a workspace size > 2^63; this can make the
  // allocation of the workspace fail with a 64-bit indexing error instead
  // of an OOM error. In such a case, we manually fail with OOM.
  TORCH_CHECK_WITH(
      OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
  return at::empty({static_cast<int64_t>(size)}, other.options().dtype(kByte));
}

// NOTE [ raw_cudnn_convolution_forward_out ]
//
//    - raw_cudnn_convolution_forward_out (Tensor)
//      Function that handles tensors that are too large for 32-bit indexing.
//      It just splits the tensor and dispatches to
//      `raw_cudnn_convolution_forward_out_32bit`.
//
//    - raw_cudnn_convolution_forward_out_32bit (Tensor)
//      Low level function which invokes cuDNN, and takes an output
//      tensor which is directly written to (thus _out).
//

// ---------------------------------------------------------------------
//
// Splitting to 32bit
//
// ---------------------------------------------------------------------

template <typename func_t>
static inline void split_batch_dim_to_32bit_out(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    int64_t max_worksize,
    func_t func_32bit) {
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t ni = input.numel();
  const int64_t no = output.numel();
  // Assume the shape of the tensor is (N, C, D1, D2, ...)
  // if N * C * D1 * D2 * ... <= int_max, then no need to split at all
  if (ni <= int_max && no <= int_max) {
    func_32bit(
        output,
        input,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
    return;
  }
  // else, if C * D1 * D2 * ... <= int_max, then we just need to split across
  // the N dimension
  //
  // Here we use a simple heuristic to determine the size of each split.
  // We don't max out the 2^31 address space because a number that large is
  // very likely to cause an OOM.
  int64_t n = output.size(0);
  int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
  int64_t split_size = std::max<int64_t>(max_worksize / max_inner_size, 1L);
  int64_t num_splits = (n + split_size - 1) / split_size;
  if (split_size * max_inner_size < int_max) {
    for (const auto i : c10::irange(num_splits)) {
      int64_t start = split_size * i;
      int64_t split_size_ = std::min<int64_t>(split_size, n - start);
      Tensor input_ = input.narrow(0, start, split_size_);
      Tensor output_ = output.narrow(0, start, split_size_);
      func_32bit(
          output_,
          input_,
          weight,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
    }
    return;
  }
  // If control flow reaches here, even splitting N is not enough; then things
  // start to become complicated. For example, for conv2d, the following
  // questions need to be considered:
  // - Is the memory layout NCHW or NHWC?
  // - If the conv is NCHW -> NC'H'W', then should we
  //   - split only NC?
  //   - split only N'C'?
  //   - split both?
  // - If the conv is NHWC, then we need to split across H, and we need to be
  //   very careful about the boundary condition to make sure that the
  //   boundary is handled correctly.
  // - If we decide to make these splits, is the memory contiguous? Do we need
  //   to copy the memory? Considering the complexity of this issue, it is
  //   better not to use cuDNN for this case.
  TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
}
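
// A worked example of the split heuristic above (illustrative numbers only,
// assuming for simplicity that input and output have the same element count):
// with max_worksize = 1024 * 1024 * 256 (the forward path below) and a tensor
// of shape (N=1024, C=64, H=256, W=256), max_inner_size = 64 * 256 * 256 =
// 4,194,304 elements per sample, so split_size = 268,435,456 / 4,194,304 = 64
// samples per chunk and num_splits = 16 calls into the 32-bit kernel, each
// safely below the 2^31 indexing limit.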

#define ASSERT_CORRECT_PRECISION(math_type)                     \
  if (args.params.dataType == CUDNN_DATA_FLOAT) {               \
    TORCH_INTERNAL_ASSERT(                                      \
        args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \
  }

// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_forward_out_32bit(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{input, output, weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(input, weight);
  setConvolutionParams(
      &args.params,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(output, memory_format);
  args.cdesc.set(
      dataType,
      input.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  // TODO: when we do legacy group convolution support, we'll repeatedly
  // reinitialize the workspace for each convolution we do.  This is
  // wasteful; we'd rather reuse the workspace.  OTOH, legacy group
  // convolution support is already pretty slow, so this might not
  // matter.  (This applies to raw_cudnn_convolution_backward_input as well.)
  AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
      .try_all([&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
        Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);

        // Update convDesc mathType since cuDNN 7.4+ now requires both algo +
        // mathType to figure out whether to use Tensor Core kernels or not.
        // See Note [behavior of cudnnFind and cudnnGet].
        ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnSetConvolutionMathType(
                args.cdesc.mut_desc(), fwdAlgPerf.mathType),
            args);

        Constant one(dataType, 1);
        Constant zero(dataType, 0);

        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnConvolutionForward(
                args.handle,
                &one,
                args.idesc.desc(),
                input.const_data_ptr(),
                args.wdesc.desc(),
                weight.const_data_ptr(),
                args.cdesc.desc(),
                fwdAlgPerf.algo,
                workspace.data_ptr(),
                fwdAlgPerf.memory,
                &zero,
                args.odesc.desc(),
                output.data_ptr()),
            args,
            "Forward algorithm: ",
            static_cast<int>(fwdAlgPerf.algo),
            "\n");
      });
}

void raw_cudnn_convolution_forward_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  split_batch_dim_to_32bit_out(
      output,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32,
      1024 * 1024 * 256,
      raw_cudnn_convolution_forward_out_32bit);
}

// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_backward_input_out_32bit(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(grad_output);

  ConvolutionArgs args{grad_input, grad_output, weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(grad_input, weight);
  setConvolutionParams(
      &args.params,
      grad_input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(grad_input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(grad_output, memory_format);
  args.cdesc.set(
      dataType,
      grad_output.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  AlgoIterator<cudnnConvolutionBwdDataAlgoPerf_t>(args, benchmark)
      .try_all([&](const cudnnConvolutionBwdDataAlgoPerf_t& bwdDataAlgPerf) {
        Tensor workspace =
            allocate_workspace(bwdDataAlgPerf.memory, grad_output);

        // Update convDesc mathType since cuDNN 7.4+ now requires both algo +
        // mathType to figure out whether to use Tensor Core kernels or not.
        // See Note [behavior of cudnnFind and cudnnGet].
        ASSERT_CORRECT_PRECISION(bwdDataAlgPerf.mathType);
        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnSetConvolutionMathType(
                args.cdesc.mut_desc(), bwdDataAlgPerf.mathType),
            args);

        Constant one(dataType, 1);
        Constant zero(dataType, 0);

        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnConvolutionBackwardData(
                args.handle,
                &one,
                args.wdesc.desc(),
                weight.const_data_ptr(),
                args.odesc.desc(),
                grad_output.const_data_ptr(),
                args.cdesc.desc(),
                bwdDataAlgPerf.algo,
                workspace.data_ptr(),
                bwdDataAlgPerf.memory,
                &zero,
                args.idesc.desc(),
                grad_input.mutable_data_ptr()),
            args,
            "Additional pointer addresses: \n",
            "    grad_output: ",
            grad_output.const_data_ptr(),
            "\n",
            "    grad_input: ",
            grad_input.mutable_data_ptr(),
            "\n",
            "Backward data algorithm: ",
            static_cast<int>(bwdDataAlgPerf.algo),
            "\n");
      });
}

void raw_cudnn_convolution_backward_input_out_v7(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  split_batch_dim_to_32bit_out(
      grad_input,
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32,
      1024 * 1024 * 128,
      raw_cudnn_convolution_backward_input_out_32bit);
}

// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_backward_weight_out_32bit(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{input, grad_output, grad_weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(input, grad_weight);
  setConvolutionParams(
      &args.params,
      input,
      grad_weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(grad_weight, memory_format, 0);
  args.odesc.set(grad_output, memory_format);
  args.cdesc.set(
      dataType,
      input.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  AlgoIterator<cudnnConvolutionBwdFilterAlgoPerf_t>(args, benchmark)
      .try_all(
          [&](const cudnnConvolutionBwdFilterAlgoPerf_t& bwdFilterAlgPerf) {
            Tensor workspace =
                allocate_workspace(bwdFilterAlgPerf.memory, input);

            // Update convDesc mathType since cuDNN 7.4+ now requires both
            // algo + mathType to figure out whether to use Tensor Core
            // kernels or not. See Note [behavior of cudnnFind and cudnnGet].
            ASSERT_CORRECT_PRECISION(bwdFilterAlgPerf.mathType);
            AT_CUDNN_CHECK_WITH_SHAPES(
                cudnnSetConvolutionMathType(
                    args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType),
                args);

            Constant one(dataType, 1);
            Constant zero(dataType, 0);

            AT_CUDNN_CHECK_WITH_SHAPES(
                cudnnConvolutionBackwardFilter(
                    args.handle,
                    &one,
                    args.idesc.desc(),
                    input.const_data_ptr(),
                    args.odesc.desc(),
                    grad_output.const_data_ptr(),
                    args.cdesc.desc(),
                    bwdFilterAlgPerf.algo,
                    workspace.data_ptr(),
                    bwdFilterAlgPerf.memory,
                    &zero,
                    args.wdesc.desc(),
                    grad_weight.data_ptr()),
                args,
                "Additional pointer addresses: \n",
                "    grad_output: ",
                grad_output.const_data_ptr(),
                "\n",
                "    grad_weight: ",
                grad_weight.data_ptr(),
                "\n",
                "Backward filter algorithm: ",
                static_cast<int>(bwdFilterAlgPerf.algo),
                "\n");
          });
}

void raw_cudnn_convolution_backward_weight_out_v7(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t ni = input.numel();
  const int64_t no = grad_output.numel();
  // Assume the shape of the tensor is (N, C, D1, D2, ...)
  // if N * C * D1 * D2 * ... <= int_max, then no need to split at all
  if (ni <= int_max && no <= int_max) {
    raw_cudnn_convolution_backward_weight_out_32bit(
        grad_weight,
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
    return;
  }
  // else, if C * D1 * D2 * ... <= int_max, then we just need to split across
  // the N dimension
  //
  // Here we use a simple heuristic to determine the size of each split.
  // We don't max out the 2^31 address space because a number that large is
  // very likely to cause an OOM.
  int64_t n = grad_output.size(0);
  int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
  int64_t split_size =
      std::max<int64_t>(1024 * 1024 * 512 / max_inner_size, 1L);
  int64_t num_splits = (n + split_size - 1) / split_size;
  if (split_size * max_inner_size < int_max) {
    const auto kAccType = (grad_weight.scalar_type() == kHalf ||
                           grad_weight.scalar_type() == kBFloat16)
        ? kFloat
        : grad_weight.scalar_type();
    Tensor grad_weight_accumulator =
        at::zeros(grad_weight.sizes(), grad_weight.options().dtype(kAccType));
    for (const auto i : c10::irange(num_splits)) {
      int64_t start = split_size * i;
      int64_t split_size_ = std::min<int64_t>(split_size, n - start);
      Tensor input_ = input.narrow(0, start, split_size_);
      Tensor grad_output_ = grad_output.narrow(0, start, split_size_);
      Tensor grad_weight_ = at::empty_like(grad_weight);
      raw_cudnn_convolution_backward_weight_out_32bit(
          grad_weight_,
          grad_output_,
          input_,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
      grad_weight_accumulator.add_(grad_weight_);
    }
    grad_weight.copy_(grad_weight_accumulator);
    return;
  }
  // If control flow reaches here, even splitting N is not enough; then things
  // start to become complicated. For example, for conv2d, the following
  // questions need to be considered:
  // - Is the memory layout NCHW or NHWC?
  // - If the conv is NCHW -> NC'H'W', then should we
  //   - split only NC?
  //   - split only N'C'?
  //   - split both?
  // - If the conv is NHWC, then we need to split across H, and we need to be
  //   very careful about the boundary condition to make sure that the
  //   boundary is handled correctly.
  // - If we decide to make these splits, is the memory contiguous? Do we need
  //   to copy the memory? Considering the complexity of this issue, it is
  //   better not to use cuDNN for this case.
  TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
}

void raw_cudnn_convolution_add_relu_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(input);
  ConvolutionArgs args{input, output, weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(input, weight);
  setConvolutionParams(
      &args.params,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(output, memory_format);
  args.cdesc.set(
      dataType,
      input.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  TensorDescriptor zdesc;
  zdesc.set(z, memory_format);

  TensorDescriptor bdesc;
  bdesc.set(bias.expand({1, bias.size(0)}), memory_format, output.dim());

  ActivationDescriptor adesc;
  adesc.set(CUDNN_ACTIVATION_RELU);

  AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
      .try_all([&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
        Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);

        // Update convDesc mathType since cuDNN 7.4+ now requires both algo +
        // mathType to figure out whether to use Tensor Core kernels or not.
        // See Note [behavior of cudnnFind and cudnnGet].
        ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnSetConvolutionMathType(
                args.cdesc.mut_desc(), fwdAlgPerf.mathType),
            args);

        Constant one(dataType, 1);
        Constant alpha_(dataType, alpha);

        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnConvolutionBiasActivationForward(
                args.handle,
                &one,
                args.idesc.desc(),
                input.const_data_ptr(),
                args.wdesc.desc(),
                weight.const_data_ptr(),
                args.cdesc.desc(),
                fwdAlgPerf.algo,
                workspace.data_ptr(),
                fwdAlgPerf.memory,
                &alpha_,
                zdesc.desc(),
                z.const_data_ptr(),
                bdesc.desc(),
                bias.const_data_ptr(),
                adesc.desc(),
                args.odesc.desc(),
                output.data_ptr()),
            args,
            "zdesc: ",
            zdesc,
            "bdesc: ",
            bdesc,
            "cudnnConvolutionBiasActivationForward: ",
            static_cast<int>(fwdAlgPerf.algo),
            "\n");
      });
}

void raw_cudnn_convolution_add_relu_fallback_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  // cuDNN Conv-Bias-Activation:
  // y = act ( alpha1 * conv(x) + alpha2 * z + bias )
  // In the PyTorch function `raw_cudnn_convolution_add_relu_out`: alpha1 is 1
  // and alpha2 is `float alpha`.

  raw_cudnn_convolution_forward_out(
      output,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  at::Tensor alpha_mul_z_add_bias =
      at::native::reshape_bias(input.dim(), bias).add(z, alpha);
  output.add_(alpha_mul_z_add_bias);
  output.relu_();
}

} // namespace native
} // namespace at

#endif