#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/cudnn/cudnn-wrapper.h>

#include <c10/macros/Macros.h>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
#include <cudnn_frontend.h>
C10_DIAGNOSTIC_POP()

#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/native/utils/ParamsHash.h>
#include <cudnn_frontend_find_plan.h>
#include <cudnn_frontend_get_plan.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/env.h>

#include <list>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#ifdef __linux__
#include <dlfcn.h>
#endif

namespace at {
namespace native {

namespace {

// TODO: remove duplicate code in Conv_v7.cpp
constexpr int64_t operator"" _TiB(unsigned long long n) {
  return size_t(n) << 40;
}

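// Returns the largest power-of-two alignment, in bytes (capped at 32), that
// divides the tensor's data address; the cuDNN frontend expects an alignment
// hint when building tensor descriptors.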
uint8_t getAlignment(const Tensor& t) {
  // alignment is in bytes
  uint8_t alignment = 1;
  uintptr_t address = reinterpret_cast<uintptr_t>(t.const_data_ptr());
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}

cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual(
    const Tensor& t,
    const int64_t id,
    const uint8_t alignment,
    const cudnnDataType_t dataType,
    const at::MemoryFormat memory_format,
    const bool _virtual) {
#if defined(__linux__) && !defined(FBCODE_CAFFE2) && CUDNN_MAJOR == 8 && \
    CUDNN_MINOR > 5
  // Workaround for a cuDNN error-handling deficiency that results in a crash
  // on Ubuntu-22+ if `libnvrtc.so` is not found on the system, even though it
  // is not strictly needed for the use cases below. See
  // https://github.com/pytorch/pytorch/issues/97041
  static C10_UNUSED auto cudnn_cnn_infer_handler = [] {
    void* handle = dlopen("libcudnn_cnn_infer.so.8", RTLD_LAZY);
    char* err = dlerror();
    if (!handle) {
      TORCH_WARN(
          "Attempt to open cnn_infer failed: handle=", handle, " error: ", err);
    } else if (err) {
      TORCH_WARN("Applied workaround for CuDNN issue, install nvrtc.so");
    }
    return handle;
  }();
#endif
  auto sizes = t.sizes();
  auto strides = t.strides();
  bool channels_last = memory_format == at::MemoryFormat::ChannelsLast ||
      memory_format == at::MemoryFormat::ChannelsLast3d;

  std::vector<int64_t> strides_copy(std::begin(strides), std::end(strides));
  fixSizeOneDimStride<int64_t>(
      sizes.size(), &sizes[0], (int64_t*)&strides_copy[0], channels_last);
  auto r = cudnn_frontend::TensorBuilder()
               .setDim(sizes.size(), sizes.data())
               .setStrides(strides_copy.size(), strides_copy.data())
               .setId(id)
               .setAlignment(alignment)
               .setDataType(dataType)
               .setVirtual(_virtual)
               .build();
  return r;
}

cudnn_frontend::Tensor getTensorDescriptor(
    const Tensor& t,
    const int64_t id,
    const uint8_t alignment,
    const at::MemoryFormat memory_format) {
  return getTensorDescriptorWithTypeVirtual(
      t, id, alignment, getCudnnDataType(t), memory_format, false);
}

cudnn_frontend::ConvDesc_v8 getConvDescriptor(
    cudnnDataType_t dataType,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    const at::ScalarType scalar_type) {
  uint64_t convDim = stride.size();
  if (scalar_type == kBFloat16 || scalar_type == kHalf) {
    dataType = CUDNN_DATA_FLOAT;
  }
  return cudnn_frontend::ConvDescBuilder()
      .setDataType(dataType)
      .setMathMode(CUDNN_CROSS_CORRELATION)
      .setNDims(convDim)
      .setStrides(convDim, stride.data())
      .setPrePadding(convDim, padding.data())
      .setPostPadding(convDim, padding.data())
      .setDilation(convDim, dilation.data())
      .build();
}

void filterEngineConfigs(
    cudnn_frontend::EngineConfigList& from,
    cudnn_frontend::EngineConfigList& to,
    bool deterministic,
    bool allow_tf32,
    c10::ScalarType scalar_type) {
  auto filter = [=](cudnnBackendDescriptor_t c) {
    if (deterministic) {
      if (cudnn_frontend::hasNumericalNote<
              CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) {
        return true;
      }
    }
    if (cudnn_frontend::hasNumericalNote<
            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
      return true;
    }
    if (scalar_type == kFloat) {
      // TODO: check under which conditions this is OK
      if (!allow_tf32 &&
          cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(
              c)) {
        return true;
      }
    }
    return false;
  };
  cudnn_frontend::filter(from, to, filter);
}

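// Cache keys combine the convolution parameters, the alignments of the
// participating tensors, and (for the non-fused case) the backend operation.
// They are stored as PODs via ParamsWrapper so they can be hashed and compared
// with ParamsWrapperHash.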
struct CacheKey {
  ConvolutionParams params;
  cudnnBackendDescriptorType_t operation;
  uint8_t x_alignment;
  uint8_t w_alignment;
  uint8_t y_alignment;
};

struct CacheKeyFused {
  ConvolutionParams params;
  // No op here because it is assumed to be a forward conv op
  uint8_t x_alignment;
  uint8_t w_alignment;
  uint8_t y_alignment;
  uint8_t z_alignment;
  uint8_t b_alignment;
  // TODO: does it make sense to have this in the key? but alpha is a
  // graph-level param...
  float alpha;
};

struct CacheKeyWrapper : ParamsWrapper<CacheKey> {
  CacheKeyWrapper(
      const cudnnBackendDescriptorType_t operation,
      const Tensor& y,
      const Tensor& x,
      const Tensor& w,
      const IntArrayRef padding,
      const IntArrayRef stride,
      const IntArrayRef dilation,
      int64_t groups,
      bool deterministic,
      bool allow_tf32) {
    at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(x, w);
    setConvolutionParams(
        &(this->pod.params),
        x,
        w,
        padding,
        stride,
        dilation,
        groups,
        deterministic,
        allow_tf32,
        memory_format);
    this->pod.operation = operation;
    this->pod.x_alignment = getAlignment(x);
    this->pod.y_alignment = getAlignment(y);
    this->pod.w_alignment = getAlignment(w);
  }
};

struct CacheKeyFusedWrapper : ParamsWrapper<CacheKeyFused> {
  CacheKeyFusedWrapper(
      const Tensor& y,
      const Tensor& x,
      const Tensor& w,
      const Tensor& z,
      const Tensor& b,
      const float alpha,
      const IntArrayRef padding,
      const IntArrayRef stride,
      const IntArrayRef dilation,
      int64_t groups,
      bool deterministic,
      bool allow_tf32) {
    at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(x, w);
    setConvolutionParams(
        &(this->pod).params,
        x,
        w,
        padding,
        stride,
        dilation,
        groups,
        deterministic,
        allow_tf32,
        memory_format);
    this->pod.x_alignment = getAlignment(x);
    this->pod.y_alignment = getAlignment(y);
    this->pod.w_alignment = getAlignment(w);
    this->pod.z_alignment = getAlignment(z);
    this->pod.b_alignment = getAlignment(b);
    this->pod.alpha = alpha;
  }
};

static int getLRUCacheLimit() {
  constexpr int DEFAULT_LIMIT =
      10000; // roughly corresponds to 2GiB assuming 200KiB per ExecutionPlan
  // 0 is used to indicate no limit
  // negative values are used to indicate no caching
  static int limit = [&] {
    const char* val = getenv("TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT");
    if (!val) {
      return DEFAULT_LIMIT;
    }
    try {
      return std::stoi(val);
    } catch (std::invalid_argument const& e) {
      TORCH_WARN(
          "invalid TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT,",
          " using default LRU cache limit of ",
          DEFAULT_LIMIT,
          " entries.");
    } catch (std::out_of_range const& e) {
      TORCH_WARN(
          "invalid TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT,",
          " using default LRU cache limit of ",
          DEFAULT_LIMIT,
          " entries.");
    }
    return DEFAULT_LIMIT;
  }();
  return limit;
}

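// Thread-local LRU cache mapping cache keys to cuDNN execution plans. The
// limit from getLRUCacheLimit() controls behavior: a positive value enables
// LRU eviction, 0 removes the size bound, and a negative value disables
// caching entirely.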
template <typename T, typename KeyType>
struct BenchmarkCache {
  std::list<KeyType> engine_cache_order;
  std::unordered_map<
      KeyType,
      std::pair<
          cudnn_frontend::ExecutionPlan,
          typename std::list<KeyType>::iterator>,
      ParamsWrapperHash<KeyType>>
      engine_cache;

  // No mutexes are needed here because the caches are thread-local for v8;
  // this also lets us return a pointer to the ExecutionPlan, since it cannot
  // be invalidated by another thread.
  cudnn_frontend::ExecutionPlan* find(const KeyType& key) {
    const int lru_cache_limit = getLRUCacheLimit();
    if (lru_cache_limit < 0) {
      return nullptr;
    }
    auto it = engine_cache.find(key);
    if (it == engine_cache.end()) {
      return nullptr;
    }
    if (lru_cache_limit) {
      // update most recently accessed
      engine_cache_order.splice(
          engine_cache_order.begin(), engine_cache_order, it->second.second);
    }
    return &(it->second.first);
  }

  void update(const KeyType& key, T& results) {
    int lru_cache_limit = getLRUCacheLimit();
    if (lru_cache_limit < 0) {
      return;
    } else if (lru_cache_limit) {
      auto it = engine_cache.find(key);
      if (it == engine_cache.end()) {
        if ((long)engine_cache.size() >= lru_cache_limit) {
          auto erase_count = engine_cache.erase(engine_cache_order.back());
          TORCH_INTERNAL_ASSERT(
              erase_count == 1,
              "CUDNN V8 LRU Cache Corrupted (eviction key not in map). Please report a bug to PyTorch.");
          engine_cache_order.pop_back();
        }
        engine_cache_order.emplace_front(key);
        engine_cache.emplace(
            key, std::make_pair(results, engine_cache_order.begin()));
      } else {
        it->second.first = results;
        // update most recently accessed
        engine_cache_order.splice(
            engine_cache_order.begin(), engine_cache_order, it->second.second);
      }
    } else {
      engine_cache.erase(key);
      engine_cache.emplace(
          key,
          std::make_pair(results, engine_cache_order.end())); // dummy iterator
    }
  }
};

// @eqy: use thread-local caches, as cuDNN Execution Plans are not guaranteed
// to be thread safe across all engines; see Limitations in
// https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html
thread_local BenchmarkCache<cudnn_frontend::ExecutionPlan, CacheKeyWrapper>
    benchmark_cache;
thread_local BenchmarkCache<cudnn_frontend::ExecutionPlan, CacheKeyFusedWrapper>
    benchmark_cache_fused;

} // namespace

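// Executes a prebuilt execution plan. The variant-pack data pointers must be
// supplied in the same order as the UIDs {'x', 'y', 'w'}; which tensor is the
// mutable output depends on whether this is a forward, backward-data, or
// backward-filter convolution.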
void run_conv_plan(
    cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const cudnn_frontend::ExecutionPlan& plan,
    const cudnnBackendDescriptorType_t operation) {
  c10::DeviceGuard g(x.options().device());
  auto workspace_size = plan.getWorkspaceSize();
  auto workspace_ptr =
      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
  void* data_ptrs[3];

  if (operation == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
    data_ptrs[0] = const_cast<void*>(x.const_data_ptr());
    data_ptrs[1] = y.data_ptr();
    data_ptrs[2] = const_cast<void*>(w.const_data_ptr());
  } else if (
      operation ==
      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
    data_ptrs[0] = x.data_ptr();
    data_ptrs[1] = const_cast<void*>(y.const_data_ptr());
    data_ptrs[2] = const_cast<void*>(w.const_data_ptr());
  } else if (
      operation ==
      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
    data_ptrs[0] = const_cast<void*>(x.const_data_ptr());
    data_ptrs[1] = const_cast<void*>(y.const_data_ptr());
    data_ptrs[2] = w.data_ptr();
  } else {
    data_ptrs[0] = x.data_ptr();
    data_ptrs[1] = y.data_ptr();
    data_ptrs[2] = w.data_ptr();
  }

  int64_t uids[] = {'x', 'y', 'w'};
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setWorkspacePointer(workspace_size ? workspace_ptr.get() : nullptr)
          .setDataPointers(3, data_ptrs)
          .setUids(3, uids)
          .build();
  AT_CUDNN_CHECK(cudnnBackendExecute(
      handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
}

void run_conv_plan_fused(
    cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const cudnn_frontend::ExecutionPlan& plan) {
  c10::DeviceGuard g(x.options().device());
  auto workspace_size = plan.getWorkspaceSize();
  auto workspace_ptr =
      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
  void* data_ptrs[] = {
      x.data_ptr(), y.data_ptr(), w.data_ptr(), z.data_ptr(), b.data_ptr()};
  int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setWorkspacePointer(workspace_size ? workspace_ptr.get() : nullptr)
          .setDataPointers(5, data_ptrs)
          .setUids(5, uids)
          .build();
  AT_CUDNN_CHECK(cudnnBackendExecute(
      handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
}

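// Builds a single-operation graph (forward, backward-data, or backward-filter
// convolution, selected by `desc`) over tensors identified by UIDs 'x', 'y',
// and 'w'.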
auto build_opgraph(
    const cudnnHandle_t handle,
    const cudnnBackendDescriptorType_t desc,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const CacheKeyWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation) {
  auto op = cudnn_frontend::OperationBuilder(desc)
                .setxDesc(getTensorDescriptor(
                    x, 'x', key.pod.x_alignment, key.pod.params.memory_format))
                .setyDesc(getTensorDescriptor(
                    y, 'y', key.pod.y_alignment, key.pod.params.memory_format))
                .setwDesc(getTensorDescriptor(
                    w, 'w', key.pod.w_alignment, key.pod.params.memory_format))
                .setcDesc(getConvDescriptor(
                    key.pod.params.dataType,
                    padding,
                    stride,
                    dilation,
                    x.scalar_type()))
                .build();
  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
  auto opGraph = cudnn_frontend::OperationGraphBuilder()
                     .setHandle(handle)
                     .setOperationGraph(ops.size(), ops.data())
                     .build();
  return opGraph;
}

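// Builds the fused graph conv -> add(z, scaled by alpha) -> add(bias) -> ReLU.
// The intermediate results ('C', 'A', 'B') are virtual float tensors; only the
// final activation output 'y' is written out in the original dtype.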
auto build_opgraph_fused(
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const float alpha,
    const CacheKeyFusedWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation) {
  // The computation needs to be done in FLOAT regardless of reduced-precision
  // inputs
  const auto precision = CUDNN_DATA_FLOAT;
  auto addDesc = cudnn_frontend::PointWiseDescBuilder()
                     .setMode(CUDNN_POINTWISE_ADD)
                     .setMathPrecision(precision)
                     .build();
  auto addBiasDesc = cudnn_frontend::PointWiseDescBuilder()
                         .setMode(CUDNN_POINTWISE_ADD)
                         .setMathPrecision(precision)
                         .build();
  auto actDesc = cudnn_frontend::PointWiseDescBuilder()
                     .setMode(CUDNN_POINTWISE_RELU_FWD)
                     .setMathPrecision(precision)
                     .build();
  auto convDesc = getConvDescriptor(
      key.pod.params.dataType, padding, stride, dilation, x.scalar_type());
  const float alpha1 = 1.0;
  const float alpha2 = alpha;
  auto conv_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
          .setxDesc(getTensorDescriptor(
              x, 'x', key.pod.x_alignment, key.pod.params.memory_format))
          // virtual output of conv
          .setyDesc(getTensorDescriptorWithTypeVirtual(
              y,
              'C',
              key.pod.y_alignment,
              precision,
              key.pod.params.memory_format,
              true))
          .setwDesc(getTensorDescriptor(
              w, 'w', key.pod.w_alignment, key.pod.params.memory_format))
          .setAlpha(alpha1)
          .setcDesc(convDesc)
          .build();
  auto add_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(conv_op.getOutputTensor())
          .setbDesc(getTensorDescriptor(
              z, 'z', key.pod.z_alignment, key.pod.params.memory_format))
          // another virtual output (of add)
          .setyDesc(getTensorDescriptorWithTypeVirtual(
              y,
              'A',
              key.pod.y_alignment,
              precision,
              key.pod.params.memory_format,
              true))
          .setpwDesc(addDesc)
          .setAlpha(alpha1)
          .setAlpha2(alpha2)
          .build();
  auto add_bias_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(add_op.getOutputTensor())
          .setbDesc(getTensorDescriptor(
              b, 'b', key.pod.b_alignment, key.pod.params.memory_format))
          // another virtual output (of add bias)
          .setyDesc(getTensorDescriptorWithTypeVirtual(
              y,
              'B',
              key.pod.y_alignment,
              precision,
              key.pod.params.memory_format,
              true))
          .setpwDesc(addBiasDesc)
          .build();
  auto act_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(add_bias_op.getOutputTensor())
          // final output is in the original datatype
          .setyDesc(getTensorDescriptor(
              y, 'y', key.pod.y_alignment, key.pod.params.memory_format))
          .setpwDesc(actDesc)
          .build();
  std::array<cudnn_frontend::Operation const*, 4> ops = {
      &conv_op, &add_op, &add_bias_op, &act_op};
  auto opGraph = cudnn_frontend::OperationGraphBuilder()
                     .setHandle(handle)
                     .setOperationGraph(ops.size(), ops.data())
                     .build();
  return opGraph;
}

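// Assembles the generator sources used by the cuDNN frontend: a
// heuristics-based source, a fallback-list source, or both. Each source runs
// its engine configs through filterEngineConfigs to honor the determinism and
// TF32 settings.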
auto get_generator_sources(
    const cudnnBackendDescriptorType_t& desc,
    const Tensor& x,
    const bool deterministic,
    const bool allow_tf32,
    const cudnnBackendHeurMode_t heur_mode,
    const bool heuristic,
    const bool fallback) {
  // Method for engine config generator based on heuristics
  const auto heurgen_method =
      [/*&desc,*/ &x, deterministic, allow_tf32, heur_mode](
          cudnn_frontend::OperationGraph& opGraph)
      -> cudnn_frontend::EngineConfigList {
    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
                          .setOperationGraph(opGraph)
                          .setHeurMode(heur_mode)
                          .build();
    auto& engine_configs =
        heuristics.getEngineConfig(heuristics.getEngineConfigCount());
    cudnn_frontend::EngineConfigList filtered_configs;
    filterEngineConfigs(
        engine_configs,
        filtered_configs,
        deterministic,
        allow_tf32,
        x.scalar_type());
    return filtered_configs;
  };
  // Method for engine config generator based on fallback list
  const auto fallback_method = [&desc, &x, deterministic, allow_tf32](
                                   cudnn_frontend::OperationGraph& opGraph)
      -> cudnn_frontend::EngineConfigList {
    auto fallback = cudnn_frontend::EngineFallbackListBuilder()
                        .setOperationGraph(opGraph)
                        .setOperation(desc)
                        .build();
    auto& fallback_list = fallback.getFallbackList();
    cudnn_frontend::EngineConfigList filtered_configs;
    filterEngineConfigs(
        fallback_list,
        filtered_configs,
        deterministic,
        allow_tf32,
        x.scalar_type());
    return filtered_configs;
  };
  if (heuristic && fallback) {
    std::vector<cudnn_frontend::GeneratorSource> sources = {
        heurgen_method, fallback_method};
    return sources;
  } else if (heuristic) {
    std::vector<cudnn_frontend::GeneratorSource> sources = {heurgen_method};
    return sources;
  } else {
    std::vector<cudnn_frontend::GeneratorSource> sources = {fallback_method};
    return sources;
  }
}

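// Upper bound for candidate workspaces: the largest block the CUDA caching
// allocator can currently provide on this device.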
int64_t get_available_workspace() {
  c10::DeviceIndex device = 0;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  size_t max_block_size = 0;
  c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
  return static_cast<int64_t>(max_block_size);
}

static nlohmann::json errata_json_handle;

bool plan_errata_exception(
    const cudnnHandle_t handle,
    const std::string& executionPlanTag) {
  static bool has_json =
      cudnn_frontend::load_from_config(errata_json_handle, "");
  if (!has_json) {
    return false;
  } else {
    return cudnn_frontend::check_errata(
        errata_json_handle, executionPlanTag, handle, []() { return true; });
  }
}

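// Builds candidate execution plans, dropping those blocked by the errata
// filter or whose workspace exceeds the available allocator block, then
// allocates the largest surviving workspace (halving it on OOM and discarding
// plans that no longer fit).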
void generate_and_filter_plans(
    const cudnnHandle_t handle,
    cudnn_frontend::OperationGraph& opGraph,
    cudnn_frontend::EngineConfigGenerator& generator,
    const Tensor& x,
    cudnn_frontend::executionPlans_t& valid_plans,
    at::DataPtr& workspace_ptr) {
  auto initial_predicate_function =
      [&](cudnn_frontend::ExecutionPlan const& plan) -> bool {
    return plan_errata_exception(handle, plan.getTag());
  };
  auto plans =
      generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
  int64_t max_block_size = get_available_workspace();
  int64_t max_workspace_size = 0;
  std::for_each(
      plans.begin(), plans.end(), [&](cudnn_frontend::ExecutionPlan& plan) {
        int64_t curr_workspace_size = plan.getWorkspaceSize();
        if (curr_workspace_size <= max_block_size) {
          if (curr_workspace_size > max_workspace_size) {
            max_workspace_size = plan.getWorkspaceSize();
          }
          valid_plans.emplace_back(std::move(plan));
        }
      });
  TORCH_CHECK_WITH(
      OutOfMemoryError,
      max_workspace_size < 1_TiB,
      "Not enough memory for workspace!");
  bool remove_invalid = false;
  while (max_workspace_size) {
    try {
      workspace_ptr =
          c10::cuda::CUDACachingAllocator::get()->allocate(max_workspace_size);
      break;
    } catch (c10::OutOfMemoryError& e) {
      max_workspace_size /= 2;
      (void)cudaGetLastError(); // clear CUDA error
      remove_invalid = true;
    }
  }
  if (remove_invalid) {
    cudnn_frontend::executionPlans_t new_valid_plans;
    for (auto& plan : valid_plans) {
      if (plan.getWorkspaceSize() <= max_workspace_size) {
        new_valid_plans.emplace_back(std::move(plan));
      }
    }
    valid_plans = std::move(new_valid_plans);
  }
}

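// Benchmark ("find") path: time the surviving plans with time_sorted_plan and
// return them ordered by measured runtime; the cuDNN benchmark limit (treated
// as 10000 when unset) bounds the search.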
auto get_plans_from_find(
    const cudnnHandle_t handle,
    const cudnnBackendDescriptorType_t desc,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const CacheKeyWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32) {
  auto opGraph =
      build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation);
  void* data_ptrs[] = {x.data_ptr(), y.data_ptr(), w.data_ptr()};
  int64_t uids[] = {'x', 'y', 'w'};
  // We don't care about getting the best ordering of algos if we're going to
  // run all of them
  auto sources = get_generator_sources(
      desc, x, deterministic, allow_tf32, CUDNN_HEUR_MODE_INSTANT, true, true);
  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  cudnn_frontend::executionPlans_t valid_plans;
  c10::DeviceGuard g(x.options().device());
  at::DataPtr workspace_ptr;
  generate_and_filter_plans(
      handle, opGraph, generator, x, valid_plans, workspace_ptr);
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setDataPointers(3, data_ptrs)
          .setUids(3, uids)
          .setWorkspacePointer(workspace_ptr ? workspace_ptr.get() : nullptr)
          .build();

  auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
  benchmark_limit = benchmark_limit ? benchmark_limit : 10000;
  auto plans = cudnn_frontend::time_sorted_plan<
      cudnn_frontend::CudnnFindSamplingTechnique::CUDNN_FIND_SAMPLE_ONCE>(
      handle, std::move(valid_plans), variantPack, benchmark_limit);

  cudnn_frontend::executionPlans_t sorted_plans;
  for (auto& plan : plans) {
    sorted_plans.emplace_back(std::move(plan));
  }
  return sorted_plans;
}

auto get_plans_from_find_fused(
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const float alpha,
    const CacheKeyFusedWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32) {
  auto opGraph = build_opgraph_fused(
      handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
  void* data_ptrs[] = {
      x.data_ptr(), y.data_ptr(), w.data_ptr(), z.data_ptr(), b.data_ptr()};
  int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};

  auto sources = get_generator_sources(
      CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
      x,
      deterministic,
      allow_tf32,
      CUDNN_HEUR_MODE_INSTANT,
      true,
      true);
  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  cudnn_frontend::executionPlans_t valid_plans;
  c10::DeviceGuard g(x.options().device());
  at::DataPtr workspace_ptr;
  generate_and_filter_plans(
      handle, opGraph, generator, x, valid_plans, workspace_ptr);
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setDataPointers(5, data_ptrs)
          .setUids(5, uids)
          .setWorkspacePointer(workspace_ptr ? workspace_ptr.get() : nullptr)
          .build();

  auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
  benchmark_limit = benchmark_limit ? benchmark_limit : 10000;
  auto plans = cudnn_frontend::time_sorted_plan<
      cudnn_frontend::CudnnFindSamplingTechnique::CUDNN_FIND_SAMPLE_ONCE>(
      handle, std::move(valid_plans), variantPack, benchmark_limit);

  cudnn_frontend::executionPlans_t sorted_plans;
  for (auto& plan : plans) {
    sorted_plans.emplace_back(std::move(plan));
  }
  return sorted_plans;
}

// We only get configs from this stage to avoid building unnecessary plans that
// are never executed
auto get_configs_from_heuristics(
    const cudnnHandle_t handle,
    const cudnnBackendDescriptorType_t desc,
    std::string& opgraph_tag,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const CacheKeyWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32,
    const bool fallback) {
  auto opGraph =
      build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation);
  opgraph_tag = opGraph.getTag();
  auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b()
      ? CUDNN_HEUR_MODE_B
      : CUDNN_HEUR_MODE_INSTANT;
  auto sources = get_generator_sources(
      desc, x, deterministic, allow_tf32, heuristic_mode, !fallback, fallback);

  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  auto configs = generator.generate_engine_config(opGraph);
  return configs;
}

auto get_configs_from_heuristics_fused(
    const cudnnHandle_t handle,
    std::string& opgraph_tag,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const float alpha,
    const CacheKeyFusedWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32,
    const bool fallback) {
  auto opGraph = build_opgraph_fused(
      handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
  opgraph_tag = opGraph.getTag();
  auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b()
      ? CUDNN_HEUR_MODE_B
      : CUDNN_HEUR_MODE_INSTANT;
  auto sources = get_generator_sources(
      CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
      x,
      deterministic,
      allow_tf32,
      heuristic_mode,
      !fallback,
      fallback);

  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  auto configs = generator.generate_engine_config(opGraph);
  return configs;
}

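// The try_* helpers below walk the candidate plans/configs in order, swallow
// cuDNN and out-of-memory failures, and cache the first plan that executes
// successfully in the thread-local benchmark cache.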
void try_plans(
    cudnn_frontend::executionPlans_t& plans,
    const CacheKeyWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const cudnnBackendDescriptorType_t operation) {
  for (auto& plan : plans) {
    try {
      run_conv_plan(handle, x, y, w, plan, operation);
      benchmark_cache.update(key, plan);
      return;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  TORCH_CHECK(
      false, "FIND was unable to find an engine to execute this computation");
}

void try_plans_fused(
    cudnn_frontend::executionPlans_t& plans,
    const CacheKeyFusedWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b) {
  for (auto& plan : plans) {
    try {
      run_conv_plan_fused(handle, x, y, w, z, b, plan);
      benchmark_cache_fused.update(key, plan);
      return;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  TORCH_CHECK(
      false, "FIND was unable to find an engine to execute this computation");
}

bool try_configs(
    cudnn_frontend::EngineConfigList& configs,
    const std::string& opgraph_tag,
    const CacheKeyWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const cudnnBackendDescriptorType_t operation) {
  for (auto& config : configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
                      .setHandle(handle)
                      .setEngineConfig(config, opgraph_tag)
                      .build();
      if (plan_errata_exception(handle, plan.getTag())) {
        continue;
      }
      run_conv_plan(handle, x, y, w, plan, operation);
      benchmark_cache.update(key, plan);
      return true;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  return false;
}

bool try_configs_fused(
    cudnn_frontend::EngineConfigList& configs,
    const std::string& opgraph_tag,
    const CacheKeyFusedWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b) {
  for (auto& config : configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
                      .setHandle(handle)
                      .setEngineConfig(config, opgraph_tag)
                      .build();
      if (plan_errata_exception(handle, plan.getTag())) {
        continue;
      }
      run_conv_plan_fused(handle, x, y, w, z, b, plan);
      benchmark_cache_fused.update(key, plan);
      return true;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  return false;
}

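// Top-level v8 dispatch for a single (non-fused) convolution: consult the
// thread-local plan cache first; otherwise use heuristic configs (with a
// fallback pass) when benchmarking is off, or exhaustively time plans via
// FIND when benchmarking is on.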
void run_single_conv(
    const cudnnBackendDescriptorType_t operation,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  cudnnHandle_t handle = getCudnnHandle();
  CacheKeyWrapper key(
      operation,
      y,
      x,
      w,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32);
  // TODO: is this thread safe if cache is updated? is pointer stale?
  auto search = benchmark_cache.find(key);
  if (search) {
    try {
      run_conv_plan(handle, x, y, w, *search, operation);
      return;
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  if (!benchmark) {
    std::string opgraph_tag; // extra data needed for errata filter
    // heuristic configs
    cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics(
        handle,
        operation,
        opgraph_tag,
        x,
        y,
        w,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32,
        false);
    if (try_configs(configs, opgraph_tag, key, handle, x, y, w, operation)) {
      return;
    }
    // fallback configs
    configs = get_configs_from_heuristics(
        handle,
        operation,
        opgraph_tag,
        x,
        y,
        w,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32,
        true);
    if (try_configs(configs, opgraph_tag, key, handle, x, y, w, operation)) {
      return;
    }
    TORCH_CHECK(
        false, "GET was unable to find an engine to execute this computation");
  } else {
    cudnn_frontend::executionPlans_t plans = get_plans_from_find(
        handle,
        operation,
        x,
        y,
        w,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32);
    // Replicate v7 behavior: clear cached blocks, as benchmarking incurs
    // significant memory consumption that is not needed after this step
    if (at::native::_cudnn_get_conv_benchmark_empty_cache()) {
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    try_plans(plans, key, handle, x, y, w, operation);
  }
}

void run_fused_conv(
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    float alpha,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  cudnnHandle_t handle = getCudnnHandle();

  CacheKeyFusedWrapper key(
      y,
      x,
      w,
      z,
      b,
      alpha,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32);
  auto search = benchmark_cache_fused.find(key);
  if (search) {
    try {
      run_conv_plan_fused(handle, x, y, w, z, b, *search);
      return;
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  if (!benchmark) {
    std::string opgraph_tag; // extra data needed for errata filter
    // heuristic configs
    cudnn_frontend::EngineConfigList configs =
        get_configs_from_heuristics_fused(
            handle,
            opgraph_tag,
            x,
            y,
            w,
            z,
            b,
            alpha,
            key,
            padding,
            stride,
            dilation,
            deterministic,
            allow_tf32,
            false);
    if (try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b)) {
      return;
    }
    // fallback configs
    configs = get_configs_from_heuristics_fused(
        handle,
        opgraph_tag,
        x,
        y,
        w,
        z,
        b,
        alpha,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32,
        true);
    if (try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b)) {
      return;
    }
    TORCH_CHECK(
        false, "GET was unable to find an engine to execute this computation");
  } else {
    cudnn_frontend::executionPlans_t plans = get_plans_from_find_fused(
        handle,
        x,
        y,
        w,
        z,
        b,
        alpha,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32);
    try_plans_fused(plans, key, handle, x, y, w, z, b);
  }
}

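// The raw_* entry points below choose between this v8 frontend path and the
// legacy v7 implementation based on cudnnv8_enabled_check_debug(); empty
// outputs are no-ops.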
void raw_cudnn_convolution_forward_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  if (output.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(
        CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
        input,
        output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_forward_out_v7(
        output,
        input,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

void raw_cudnn_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  if (grad_input.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(
        CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
        grad_input,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_backward_input_out_v7(
        grad_input,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

void raw_cudnn_convolution_backward_weight_out(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  if (grad_weight.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(
        CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
        input,
        grad_output,
        grad_weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_backward_weight_out_v7(
        grad_weight,
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

void raw_cudnn_convolution_add_relu_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  if (output.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    auto bias_ = input.ndimension() == 4
        ? bias.view({1, bias.numel(), 1, 1})
        : bias.view({1, bias.numel(), 1, 1, 1});
    run_fused_conv(
        input,
        output,
        weight,
        z,
        bias_,
        alpha,
        stride,
        padding,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_add_relu_out_v7(
        output,
        input,
        weight,
        z,
        alpha,
        bias,
        stride,
        padding,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

} // namespace native
} // namespace at

#endif // AT_CUDNN_ENABLED