#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

#include <ATen/Config.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <limits>
#include <vector>

#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/utils/ParamsHash.h>

#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

#include <stdint.h>
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <mutex>
#include <sstream>
#include <unordered_map>

// Note [behavior of cudnnFind and cudnnGet]
// You'll notice that by default, in the ConvolutionDescriptor, we do the
// following:
//
//   AT_CUDNN_CHECK(
//       cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
//   if (dataType == CUDNN_DATA_HALF)
//     AT_CUDNN_CHECK(
//         cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
//
// Update: AT_CUDNN_CHECK is updated with AT_CUDNN_CHECK_WITH_SHAPES, which
// automatically prints tensor shapes and convolution parameters if there
// is a cuDNN exception thrown.
//
// When cudnnSetConvolutionMathType is called before cudnnGet/cudnnFind, it
// informs cudnnGet/cudnnFind to iterate/take into account both tensor core and
// non-tensor-core algos. If you don't call cudnnSetConvolutionMathType before
// calling cudnnGet/cudnnFind, cudnnGet/cudnnFind may not pick tensor core
// algos.
//
// Now, after its run, cudnnGet/cudnnFind comes up with the best pair of
// algo+mathType given all the initial knowledge it's given. It then becomes
// the user's responsibility to update the mathType of the convolution
// descriptor and make the subsequent cudnn calls with the best algo and the
// updated descriptor. If we don't update the descriptor but just run with the
// best algo, cudnn will, under the hood, run with the slower kernel, since it
// sees the fastest algorithm combination with a sub-optimal mathType.
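//
// A minimal illustrative sketch of that protocol (not code from this file;
// the handle and the idesc/wdesc/cdesc/odesc descriptors are assumed to have
// been set up elsewhere):
//
//   cudnnConvolutionFwdAlgoPerf_t perf;
//   int returned = 0;
//   // 1. Hint that both tensor-core and non-tensor-core algos be considered.
//   AT_CUDNN_CHECK(
//       cudnnSetConvolutionMathType(cdesc.mut_desc(), CUDNN_DEFAULT_MATH));
//   // 2. Ask cuDNN for its best algo+mathType pair.
//   AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
//       handle, idesc.desc(), wdesc.desc(), cdesc.desc(), odesc.desc(),
//       /*requestedAlgoCount=*/1, &returned, &perf));
//   // 3. Propagate the returned mathType back into the descriptor before
//   //    running the convolution with perf.algo.
//   AT_CUDNN_CHECK(
//       cudnnSetConvolutionMathType(cdesc.mut_desc(), perf.mathType));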
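// User-defined literal for tebibytes: 1_TiB == 2^40 bytes. Used below to
// reject the absurd workspace sizes that cuDNN occasionally reports.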
constexpr size_t operator"" _TiB(unsigned long long n) {
  return size_t(n) * 1024 * 1024 * 1024 * 1024;
}

namespace at {
namespace native {

// Convenience struct for passing around descriptors and data
// pointers
struct ConvolutionArgs {
  cudnnHandle_t handle;
  ConvolutionParams params;
  TensorDescriptor idesc, odesc;
  FilterDescriptor wdesc;
  const Tensor &input, output, weight;
  ConvolutionDescriptor cdesc;

  ConvolutionArgs(
      const Tensor& input,
      const Tensor& output,
      const Tensor& weight)
      : input(input), output(output), weight(weight) {}
};

std::ostream& operator<<(std::ostream& out, const ConvolutionArgs& args) {
  out << repro_from_args(args.params) // already has a trailing newline
      << args.params // already has a trailing newline
      << "input: " << args.idesc // already has a trailing newline
      << "output: " << args.odesc // already has a trailing newline
      << "weight: " << args.wdesc // already has a trailing newline
      << "Pointer addresses: "
      << "\n"
      << "    input: " << args.input.const_data_ptr() << "\n"
      << "    output: " << args.output.const_data_ptr() << "\n"
      << "    weight: " << args.weight.const_data_ptr() << "\n";

  return out;
}

// ---------------------------------------------------------------------
//
// Benchmarking
//
// ---------------------------------------------------------------------

// TODO: Use something less heavy duty than a big honking mutex
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<
      ConvolutionParams,
      T,
      ParamsHash<ConvolutionParams>,
      ParamsEqual<ConvolutionParams>>
      map;

  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
    *results = it->second;
    return true;
  }

  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> guard(mutex);
    map[params] = results;
  }
};

BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
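// These caches are process-wide and shared across threads; lookups and
// insertions are serialized by the mutex inside BenchmarkCache.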

// TODO: Stop manually allocating CUDA memory; allocate an ATen byte
// tensor instead.
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    // Sometimes cuDNN returns a workspace size > 2^63; this could make the
    // allocation of the workspace fail with a 64-bit indexing error instead
    // of an OOM error. In such cases, we manually fail with OOM.
    TORCH_CHECK_WITH(
        OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
    data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      c10::cuda::CUDACachingAllocator::raw_delete(data);
    }
  }

  size_t size;
  void* data;
};

template <typename perf_t>
struct algorithm_search {};

cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionFwdAlgo_t algo,
    size_t* sz) {
  return cudnnGetConvolutionForwardWorkspaceSize(
      args.handle,
      args.idesc.desc(),
      args.wdesc.desc(),
      args.cdesc.desc(),
      args.odesc.desc(),
      algo,
      sz);
}
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdDataAlgo_t algo,
    size_t* sz) {
  return cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle,
      args.wdesc.desc(),
      args.odesc.desc(),
      args.cdesc.desc(),
      args.idesc.desc(),
      algo,
      sz);
}
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdFilterAlgo_t algo,
    size_t* sz) {
  return cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle,
      args.idesc.desc(),
      args.odesc.desc(),
      args.cdesc.desc(),
      args.wdesc.desc(),
      algo,
      sz);
}

template <typename algo_t>
size_t getMaxWorkspaceSize(
    const ConvolutionArgs& args,
    const algo_t* algo,
    int n_algo) {
  size_t max_ws_size = 0;
  size_t max_block_size = 0;

  const auto device = c10::cuda::current_device();
  // For the native allocator, retrieves the size of the largest unused block.
  // For cudaMallocAsync, see c10/cuda/CUDAMallocAsync.cpp:cacheInfo for
  // details.
  c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
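  // Only algorithms whose workspace fits within that block are considered
  // below, so (in the common case) the benchmark workspace can be served from
  // already-cached memory rather than forcing a fresh device allocation.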

  for (const auto i : c10::irange(n_algo)) {
    cudnnStatus_t err;
    size_t sz;
    err = getWorkspaceSize(args, algo[i], &sz);
    if (CUDNN_STATUS_SUCCESS != err || sz == 0 || sz < max_ws_size ||
        sz > max_block_size)
      continue;
    max_ws_size = sz;
  }
  return max_ws_size;
}

template <typename perf_t>
std::vector<perf_t> getValidAlgorithms(
    perf_t* perfResults,
    const ConvolutionArgs& args,
    int n_algo) {
  std::vector<perf_t> result;
  result.reserve(n_algo);
  for (const auto i : c10::irange(n_algo)) {
    perf_t perf = perfResults[i];

    // TODO: Shouldn't all returned results be successful?
    // Double check documentation for cudnnFindConvolutionForwardAlgorithmEx
    if (perf.status == CUDNN_STATUS_SUCCESS) {
      if (!args.params.deterministic ||
          perf.determinism == CUDNN_DETERMINISTIC) {
        result.push_back(perf);
      }
    }
  }
  TORCH_CHECK(
      result.size() > 0, "no valid convolution algorithms available in CuDNN");
  return result;
}

template <>
struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
  using perf_t = cudnnConvolutionFwdAlgoPerf_t;
  using algo_t = cudnnConvolutionFwdAlgo_t;

  static constexpr auto DEFAULT_ALGO =
      CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  static BenchmarkCache<perf_t>& cache() {
    return fwd_algos;
  }

  static std::vector<perf_t> findAlgorithms(
      const ConvolutionArgs& args,
      bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_FFT,
        CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
        CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
        CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
        CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
    };
    static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
    static_assert(
        sizeof(algos) / sizeof(algos[0]) == num_algos,
        "Missing cuDNN convolution forward algorithms");
    int perf_count;
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    if (!benchmark) {
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnGetConvolutionForwardAlgorithm_v7(
              args.handle,
              args.idesc.desc(),
              args.wdesc.desc(),
              args.cdesc.desc(),
              args.odesc.desc(),
              num_algos,
              &perf_count,
              perf_results.get()),
          args);
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnFindConvolutionForwardAlgorithmEx(
              args.handle,
              args.idesc.desc(),
              args.input.const_data_ptr(),
              args.wdesc.desc(),
              args.weight.const_data_ptr(),
              args.cdesc.desc(),
              args.odesc.desc(),
              args.output.data_ptr(),
              num_algos,
              &perf_count,
              perf_results.get(),
              ws.data,
              ws.size),
          args);

      // Free the cached blocks in our caching allocator. This is needed here
      // because the benchmarking above uses a huge amount of memory, e.g. a
      // few GBs.
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      algo_t algo,
      size_t* workspaceSize) {
    AT_CUDNN_CHECK_WITH_SHAPES(
        cudnnGetConvolutionForwardWorkspaceSize(
            args.handle,
            args.idesc.desc(),
            args.wdesc.desc(),
            args.cdesc.desc(),
            args.odesc.desc(),
            algo,
            workspaceSize),
        args);
  }
};

template <>
struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdDataAlgo_t;

  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  static BenchmarkCache<perf_t>& cache() {
    return bwd_data_algos;
  }

  static std::vector<perf_t> findAlgorithms(
      const ConvolutionArgs& args,
      bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED};
    static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
    static_assert(
        sizeof(algos) / sizeof(algos[0]) == num_algos,
        "Missing cuDNN convolution backward data algorithms.");
    int perf_count;
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    if (!benchmark) {
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnGetConvolutionBackwardDataAlgorithm_v7(
              args.handle,
              args.wdesc.desc(),
              args.odesc.desc(),
              args.cdesc.desc(),
              args.idesc.desc(),
              num_algos,
              &perf_count,
              perf_results.get()),
          args);
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnFindConvolutionBackwardDataAlgorithmEx(
              args.handle,
              args.wdesc.desc(),
              args.weight.const_data_ptr(),
              args.odesc.desc(),
              args.output.const_data_ptr(),
              args.cdesc.desc(),
              args.idesc.desc(),
              args.input.data_ptr(),
              num_algos,
              &perf_count,
              perf_results.get(),
              ws.data,
              ws.size),
          args);

      // Free the cached blocks in our caching allocator. This is needed here
      // because the benchmarking above uses a huge amount of memory, e.g. a
      // few GBs.
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      cudnnConvolutionBwdDataAlgo_t algo,
      size_t* workspaceSize) {
    AT_CUDNN_CHECK_WITH_SHAPES(
        cudnnGetConvolutionBackwardDataWorkspaceSize(
            args.handle,
            args.wdesc.desc(),
            args.odesc.desc(),
            args.cdesc.desc(),
            args.idesc.desc(),
            algo,
            workspaceSize),
        args);
  }
};

template <>
struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdFilterAlgo_t;

  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;

  static BenchmarkCache<perf_t>& cache() {
    return bwd_filter_algos;
  }

  static std::vector<perf_t> findAlgorithms(
      const ConvolutionArgs& args,
      bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
    };
    // NOTE: - 1 because ALGO_WINOGRAD is not implemented
    static constexpr int num_algos =
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
    static_assert(
        sizeof(algos) / sizeof(algos[0]) == num_algos,
        "Missing cuDNN convolution backward filter algorithms.");
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    int perf_count;
    if (!benchmark) {
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnGetConvolutionBackwardFilterAlgorithm_v7(
              args.handle,
              args.idesc.desc(),
              args.odesc.desc(),
              args.cdesc.desc(),
              args.wdesc.desc(),
              num_algos,
              &perf_count,
              perf_results.get()),
          args);
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      at::cuda::errorIfCapturingCudnnBenchmark("cudnnFind");
      AT_CUDNN_CHECK_WITH_SHAPES(
          cudnnFindConvolutionBackwardFilterAlgorithmEx(
              args.handle,
              args.idesc.desc(),
              args.input.const_data_ptr(),
              args.odesc.desc(),
              args.output.const_data_ptr(),
              args.cdesc.desc(),
              args.wdesc.desc(),
              args.weight.data_ptr(),
              num_algos,
              &perf_count,
              perf_results.get(),
              ws.data,
              ws.size),
          args);

      // Free the cached blocks in our caching allocator. This is needed here
      // because the benchmarking above uses a huge amount of memory, e.g. a
      // few GBs.
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      algo_t algo,
      size_t* workspaceSize) {
    AT_CUDNN_CHECK_WITH_SHAPES(
        cudnnGetConvolutionBackwardFilterWorkspaceSize(
            args.handle,
            args.idesc.desc(),
            args.odesc.desc(),
            args.cdesc.desc(),
            args.wdesc.desc(),
            algo,
            workspaceSize),
        args);
  }
};

template <typename perf_t>
class AlgoIterator {
  using search = algorithm_search<perf_t>;
  const ConvolutionArgs& args;
  bool benchmark;

 public:
  AlgoIterator(const ConvolutionArgs& args, bool benchmark)
      : args(args), benchmark(benchmark) {}

  static std::vector<perf_t> onlyDefaultAlgorithm(const ConvolutionArgs& args) {
    std::vector<perf_t> perfResults(1);
    perfResults[0].algo = search::DEFAULT_ALGO;
    if (args.params.dataType == CUDNN_DATA_HALF) {
      perfResults[0].mathType = CUDNN_TENSOR_OP_MATH;
    } else {
      perfResults[0].mathType = CUDNN_DEFAULT_MATH;
      if (args.params.dataType == CUDNN_DATA_FLOAT && !args.params.allow_tf32) {
        perfResults[0].mathType = CUDNN_FMA_MATH;
      }
    }
    search::getWorkspaceSize(
        args, perfResults[0].algo, &(perfResults[0].memory));
    return perfResults;
  }

  void try_all(std::function<void(const perf_t& perf)> f) {
    bool only_use_default = args.params.deterministic && !benchmark;

    auto& cache = search::cache();
    perf_t algoPerf;
    if (!only_use_default && cache.find(args.params, &algoPerf)) {
      try {
        f(algoPerf);
        return;
      } catch (c10::OutOfMemoryError& e) {
        cudaGetLastError(); // clear CUDA error
      }
    }

    auto perfResults = only_use_default
        ? onlyDefaultAlgorithm(args)
        : search::findAlgorithms(args, benchmark);
    for (auto& algoPerf : perfResults) {
      try {
        f(algoPerf);
        cache.insert(args.params, algoPerf);
        return;
      } catch (c10::OutOfMemoryError& e) {
        cudaGetLastError(); // clear CUDA error
      } catch (c10::CuDNNError& e) {
        cudaGetLastError(); // clear CUDA error
      }
    }
    TORCH_CHECK(
        false, "Unable to find a valid cuDNN algorithm to run convolution");
  }
};
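
// Typical usage of AlgoIterator mirrors the convolution entry points below:
//
//   AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
//       .try_all([&](const cudnnConvolutionFwdAlgoPerf_t& perf) {
//         // allocate perf.memory bytes of workspace, set perf.mathType on
//         // the convolution descriptor, then run with perf.algo
//       });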

inline Tensor allocate_workspace(size_t size, const Tensor& other) {
  // Sometimes cuDNN returns a workspace size > 2^63; this could make the
  // allocation of the workspace fail with a 64-bit indexing error instead
  // of an OOM error. In such cases, we manually fail with OOM.
  TORCH_CHECK_WITH(
      OutOfMemoryError, size < 1_TiB, "Not enough memory for workspace!");
  return at::empty({static_cast<int64_t>(size)}, other.options().dtype(kByte));
}

// NOTE [ raw_cudnn_convolution_forward_out ]
//
// - raw_cudnn_convolution_forward_out (Tensor)
//   Function that handles tensors that are too large to use 32-bit indexing.
//   It just splits the tensor and dispatches to
//   `raw_cudnn_convolution_forward_out_32bit`.
//
// - raw_cudnn_convolution_forward_out_32bit (Tensor)
//   Low-level function which invokes cuDNN, and takes an output
//   tensor which is directly written to (thus _out).
//
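// For example (hypothetical shape): a contiguous float NCHW input of size
// (8, 512, 512, 1024) holds exactly 2^31 elements, which overflows 32-bit
// indexing and therefore takes the batch-splitting path below.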

// ---------------------------------------------------------------------
//
// Splitting to 32bit
//
// ---------------------------------------------------------------------

template <typename func_t>
static inline void split_batch_dim_to_32bit_out(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    int64_t max_worksize,
    func_t func_32bit) {
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t ni = input.numel();
  const int64_t no = output.numel();
  // Assume the shape of the tensor is (N, C, D1, D2, ...)
  // if N * C * D1 * D2 * ... <= int_max, then no need to split at all
  if (ni <= int_max && no <= int_max) {
    func_32bit(
        output,
        input,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
    return;
  }
  // else, if C * D1 * D2 * ... <= int_max, then we just need to split across
  // the N dimension
  //
  // Here we use a simple heuristic to determine the size of each split.
  // We don't max out the 2^31 address space because a split that large is
  // very likely to cause an OOM.
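  // For example (hypothetical numbers): with max_worksize = 2^28 elements and
  // max_inner_size = 2^26 elements per sample, each split covers
  // split_size = 4 samples.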
  int64_t n = output.size(0);
  int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
  int64_t split_size = std::max<int64_t>(max_worksize / max_inner_size, 1L);
  int64_t num_splits = (n + split_size - 1) / split_size;
  if (split_size * max_inner_size < int_max) {
    for (const auto i : c10::irange(num_splits)) {
      int64_t start = split_size * i;
      int64_t split_size_ = std::min<int64_t>(split_size, n - start);
      Tensor input_ = input.narrow(0, start, split_size_);
      Tensor output_ = output.narrow(0, start, split_size_);
      func_32bit(
          output_,
          input_,
          weight,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
    }
    return;
  }
  // If control flow reaches here, even splitting N is not enough, and things
  // start to become complicated. For example, for conv2d, the following
  // questions need to be considered:
  // - Is the memory layout NCHW or NHWC?
  // - If the conv is NCHW -> NC'H'W', then should we
  //   - split only NC?
  //   - split only N'C'?
  //   - split both?
  // - If the conv is NHWC, then we need to split across H; we need to be very
  //   careful about the boundary condition to make sure that the boundary is
  //   handled correctly.
  // - If we decide to make these splits, is the memory contiguous? Do we need
  //   to copy the memory?
  // Considering the complexity of this issue, it is better not to use cuDNN
  // for this case.
  TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
}
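
// Sanity check used below: for FP32 convolutions with TF32 disabled, the
// math type chosen by cudnnFind/cudnnGet must be CUDNN_FMA_MATH.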
#define ASSERT_CORRECT_PRECISION(math_type)                     \
  if (args.params.dataType == CUDNN_DATA_FLOAT) {               \
    TORCH_INTERNAL_ASSERT(                                      \
        args.params.allow_tf32 || math_type == CUDNN_FMA_MATH); \
  }

// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_forward_out_32bit(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{input, output, weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(input, weight);
  setConvolutionParams(
      &args.params,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(output, memory_format);
  args.cdesc.set(
      dataType,
      input.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  // TODO: when we do legacy group convolution support, we'll repeatedly
  // reinitialize the workspace for each convolution we do. This is
  // wasteful; we'd rather reuse the workspace. OTOH, legacy group
  // convolution support is already pretty slow, so this might not
  // matter. (This applies to raw_cudnn_convolution_backward_input as well.)
  AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
      .try_all([&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
        Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);

        // Update convDesc mathType since cudnn 7.4+ now requires both algo +
        // mathType to figure out whether to use Tensor core kernels or not.
        // See Note [behavior of cudnnFind and cudnnGet]
        ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnSetConvolutionMathType(
                args.cdesc.mut_desc(), fwdAlgPerf.mathType),
            args);

        Constant one(dataType, 1);
        Constant zero(dataType, 0);

        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnConvolutionForward(
                args.handle,
                &one,
                args.idesc.desc(),
                input.const_data_ptr(),
                args.wdesc.desc(),
                weight.const_data_ptr(),
                args.cdesc.desc(),
                fwdAlgPerf.algo,
                workspace.data_ptr(),
                fwdAlgPerf.memory,
                &zero,
                args.odesc.desc(),
                output.data_ptr()),
            args,
            "Forward algorithm: ",
            static_cast<int>(fwdAlgPerf.algo),
            "\n");
      });
}

void raw_cudnn_convolution_forward_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  split_batch_dim_to_32bit_out(
      output,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32,
      1024 * 1024 * 256,
      raw_cudnn_convolution_forward_out_32bit);
}

// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_backward_input_out_32bit(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(grad_output);

  ConvolutionArgs args{grad_input, grad_output, weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(grad_input, weight);
  setConvolutionParams(
      &args.params,
      grad_input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(grad_input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(grad_output, memory_format);
  args.cdesc.set(
      dataType,
      grad_output.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  AlgoIterator<cudnnConvolutionBwdDataAlgoPerf_t>(args, benchmark)
      .try_all([&](const cudnnConvolutionBwdDataAlgoPerf_t& bwdDataAlgPerf) {
        Tensor workspace =
            allocate_workspace(bwdDataAlgPerf.memory, grad_output);

        // Update convDesc mathType since cudnn 7.4+ now requires both algo +
        // mathType to figure out whether to use Tensor core kernels or not.
        // See Note [behavior of cudnnFind and cudnnGet]
        ASSERT_CORRECT_PRECISION(bwdDataAlgPerf.mathType);
        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnSetConvolutionMathType(
                args.cdesc.mut_desc(), bwdDataAlgPerf.mathType),
            args);

        Constant one(dataType, 1);
        Constant zero(dataType, 0);

        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnConvolutionBackwardData(
                args.handle,
                &one,
                args.wdesc.desc(),
                weight.const_data_ptr(),
                args.odesc.desc(),
                grad_output.const_data_ptr(),
                args.cdesc.desc(),
                bwdDataAlgPerf.algo,
                workspace.data_ptr(),
                bwdDataAlgPerf.memory,
                &zero,
                args.idesc.desc(),
                grad_input.mutable_data_ptr()),
            args,
            "Additional pointer addresses: \n",
            "    grad_output: ",
            grad_output.const_data_ptr(),
            "\n",
            "    grad_input: ",
            grad_input.mutable_data_ptr(),
            "\n",
            "Backward data algorithm: ",
            static_cast<int>(bwdDataAlgPerf.algo),
            "\n");
      });
}

void raw_cudnn_convolution_backward_input_out_v7(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  split_batch_dim_to_32bit_out(
      grad_input,
      grad_output,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32,
      1024 * 1024 * 128,
      raw_cudnn_convolution_backward_input_out_32bit);
}

// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_backward_weight_out_32bit(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{input, grad_output, grad_weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(input, grad_weight);
  setConvolutionParams(
      &args.params,
      input,
      grad_weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(grad_weight, memory_format, 0);
  args.odesc.set(grad_output, memory_format);
  args.cdesc.set(
      dataType,
      input.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  AlgoIterator<cudnnConvolutionBwdFilterAlgoPerf_t>(args, benchmark)
      .try_all(
          [&](const cudnnConvolutionBwdFilterAlgoPerf_t& bwdFilterAlgPerf) {
            Tensor workspace =
                allocate_workspace(bwdFilterAlgPerf.memory, input);

            // Update convDesc mathType since cudnn 7.4+ now requires both
            // algo + mathType to figure out whether to use Tensor core
            // kernels or not. See Note [behavior of cudnnFind and cudnnGet]
            ASSERT_CORRECT_PRECISION(bwdFilterAlgPerf.mathType);
            AT_CUDNN_CHECK_WITH_SHAPES(
                cudnnSetConvolutionMathType(
                    args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType),
                args);

            Constant one(dataType, 1);
            Constant zero(dataType, 0);

            AT_CUDNN_CHECK_WITH_SHAPES(
                cudnnConvolutionBackwardFilter(
                    args.handle,
                    &one,
                    args.idesc.desc(),
                    input.const_data_ptr(),
                    args.odesc.desc(),
                    grad_output.const_data_ptr(),
                    args.cdesc.desc(),
                    bwdFilterAlgPerf.algo,
                    workspace.data_ptr(),
                    bwdFilterAlgPerf.memory,
                    &zero,
                    args.wdesc.desc(),
                    grad_weight.data_ptr()),
                args,
                "Additional pointer addresses: \n",
                "    grad_output: ",
                grad_output.const_data_ptr(),
                "\n",
                "    grad_weight: ",
                grad_weight.data_ptr(),
                "\n",
                "Backward filter algorithm: ",
                static_cast<int>(bwdFilterAlgPerf.algo),
                "\n");
          });
}

void raw_cudnn_convolution_backward_weight_out_v7(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  constexpr int64_t int_max = std::numeric_limits<int>::max();
  const int64_t ni = input.numel();
  const int64_t no = grad_output.numel();
  // Assume the shape of the tensor is (N, C, D1, D2, ...)
  // if N * C * D1 * D2 * ... <= int_max, then no need to split at all
  if (ni <= int_max && no <= int_max) {
    raw_cudnn_convolution_backward_weight_out_32bit(
        grad_weight,
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
    return;
  }
  // else, if C * D1 * D2 * ... <= int_max, then we just need to split across
  // the N dimension
  //
  // Here we use a simple heuristic to determine the size of each split.
  // We don't max out the 2^31 address space because a split that large is
  // very likely to cause an OOM.
  int64_t n = grad_output.size(0);
  int64_t max_inner_size = std::max<int64_t>(ni, no) / n;
  int64_t split_size =
      std::max<int64_t>(1024 * 1024 * 512 / max_inner_size, 1L);
  int64_t num_splits = (n + split_size - 1) / split_size;
  if (split_size * max_inner_size < int_max) {
    const auto kAccType = (grad_weight.scalar_type() == kHalf ||
                           grad_weight.scalar_type() == kBFloat16)
        ? kFloat
        : grad_weight.scalar_type();
    Tensor grad_weight_accumulator =
        at::zeros(grad_weight.sizes(), grad_weight.options().dtype(kAccType));
    for (const auto i : c10::irange(num_splits)) {
      int64_t start = split_size * i;
      int64_t split_size_ = std::min<int64_t>(split_size, n - start);
      Tensor input_ = input.narrow(0, start, split_size_);
      Tensor grad_output_ = grad_output.narrow(0, start, split_size_);
      Tensor grad_weight_ = at::empty_like(grad_weight);
      raw_cudnn_convolution_backward_weight_out_32bit(
          grad_weight_,
          grad_output_,
          input_,
          padding,
          stride,
          dilation,
          groups,
          benchmark,
          deterministic,
          allow_tf32);
      grad_weight_accumulator.add_(grad_weight_);
    }
    grad_weight.copy_(grad_weight_accumulator);
    return;
  }
  // If control flow reaches here, even splitting N is not enough, and things
  // start to become complicated. For example, for conv2d, the following
  // questions need to be considered:
  // - Is the memory layout NCHW or NHWC?
  // - If the conv is NCHW -> NC'H'W', then should we
  //   - split only NC?
  //   - split only N'C'?
  //   - split both?
  // - If the conv is NHWC, then we need to split across H; we need to be very
  //   careful about the boundary condition to make sure that the boundary is
  //   handled correctly.
  // - If we decide to make these splits, is the memory contiguous? Do we need
  //   to copy the memory?
  // Considering the complexity of this issue, it is better not to use cuDNN
  // for this case.
  TORCH_INTERNAL_ASSERT(false, "This case should not be dispatched to cuDNN.");
}

void raw_cudnn_convolution_add_relu_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  auto dataType = getCudnnDataType(input);
  ConvolutionArgs args{input, output, weight};
  args.handle = getCudnnHandle();
  at::MemoryFormat memory_format =
      cudnn_conv_suggest_memory_format(input, weight);
  setConvolutionParams(
      &args.params,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32,
      memory_format);
  args.idesc.set(input, memory_format);
  args.wdesc.set(weight, memory_format, 0);
  args.odesc.set(output, memory_format);
  args.cdesc.set(
      dataType,
      input.dim() - 2,
      args.params.padding,
      args.params.stride,
      args.params.dilation,
      args.params.groups,
      args.params.allow_tf32);

  TensorDescriptor zdesc;
  zdesc.set(z, memory_format);

  TensorDescriptor bdesc;
  bdesc.set(bias.expand({1, bias.size(0)}), memory_format, output.dim());

  ActivationDescriptor adesc;
  adesc.set(CUDNN_ACTIVATION_RELU);

  AlgoIterator<cudnnConvolutionFwdAlgoPerf_t>(args, benchmark)
      .try_all([&](const cudnnConvolutionFwdAlgoPerf_t& fwdAlgPerf) {
        Tensor workspace = allocate_workspace(fwdAlgPerf.memory, input);

        // Update convDesc mathType since cudnn 7.4+ now requires both algo +
        // mathType to figure out whether to use Tensor core kernels or not.
        // See Note [behavior of cudnnFind and cudnnGet]
        ASSERT_CORRECT_PRECISION(fwdAlgPerf.mathType);
        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnSetConvolutionMathType(
                args.cdesc.mut_desc(), fwdAlgPerf.mathType),
            args);

        Constant one(dataType, 1);
        Constant alpha_(dataType, alpha);

        AT_CUDNN_CHECK_WITH_SHAPES(
            cudnnConvolutionBiasActivationForward(
                args.handle,
                &one,
                args.idesc.desc(),
                input.const_data_ptr(),
                args.wdesc.desc(),
                weight.const_data_ptr(),
                args.cdesc.desc(),
                fwdAlgPerf.algo,
                workspace.data_ptr(),
                fwdAlgPerf.memory,
                &alpha_,
                zdesc.desc(),
                z.const_data_ptr(),
                bdesc.desc(),
                bias.const_data_ptr(),
                adesc.desc(),
                args.odesc.desc(),
                output.data_ptr()),
            args,
            "zdesc: ",
            zdesc,
            "bdesc: ",
            bdesc,
            "cudnnConvolutionBiasActivationForward: ",
            static_cast<int>(fwdAlgPerf.algo),
            "\n");
      });
}

void raw_cudnn_convolution_add_relu_fallback_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  // cuDNN Conv-Bias-Activation:
  //   y = act ( alpha1 * conv(x) + alpha2 * z + bias )
  // In the PyTorch function `raw_cudnn_convolution_add_relu_out`, alpha1 is 1
  // and alpha2 is `float alpha`.

  raw_cudnn_convolution_forward_out(
      output,
      input,
      weight,
      padding,
      stride,
      dilation,
      groups,
      benchmark,
      deterministic,
      allow_tf32);
  at::Tensor alpha_mul_z_add_bias =
      at::native::reshape_bias(input.dim(), bias).add(z, alpha);
  output.add_(alpha_mul_z_add_bias);
  output.relu_();
}

} // namespace native
} // namespace at

#endif