#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/cudnn/cudnn-wrapper.h>

#include <c10/macros/Macros.h>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override")
#include <cudnn_frontend.h>
C10_DIAGNOSTIC_POP()

#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/native/utils/ParamsHash.h>
#include <cudnn_frontend_find_plan.h>
#include <cudnn_frontend_get_plan.h>

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/env.h>

#include <list>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#ifdef __linux__
#include <dlfcn.h>
#endif

namespace at {
namespace native {

namespace {

// TODO: remove duplicate code in Conv_v7.cpp
constexpr int64_t operator"" _TiB(unsigned long long n) {
  return size_t(n) << 40;
}

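// Returns the largest power-of-two alignment (in bytes, capped at 32) that the
// tensor's data pointer satisfies. The cuDNN frontend takes an alignment for
// each tensor descriptor (setAlignment below), and the value is also folded
// into the execution-plan cache keys.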
uint8_t getAlignment(const Tensor& t) {
  // alignment is in bytes
  uint8_t alignment = 1;
  uintptr_t address = reinterpret_cast<uintptr_t>(t.const_data_ptr());
  for (; alignment < 32; alignment *= 2) {
    if (address % (alignment * 2)) {
      return alignment;
    }
  }
  return alignment;
}

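// Builds a cudnn_frontend tensor descriptor from an ATen tensor. The id is the
// UID later matched against the variant pack ('x', 'y', 'w', ...), and
// _virtual marks intermediate tensors of a fused graph that are never
// materialized in memory.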
cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual(
    const Tensor& t,
    const int64_t id,
    const uint8_t alignment,
    const cudnnDataType_t dataType,
    const at::MemoryFormat memory_format,
    const bool _virtual) {
#if defined(__linux__) && !defined(FBCODE_CAFFE2) && CUDNN_MAJOR == 8 && \
    CUDNN_MINOR > 5
  // Workaround for a cudnn error-handling deficiency that results in a crash
  // on Ubuntu 22+ if `libnvrtc.so` is not found on the system, even though it
  // is not strictly necessary for the use cases below. See
  // https://github.com/pytorch/pytorch/issues/97041
  static C10_UNUSED auto cudnn_cnn_infer_handler = [] {
    void* handle = dlopen("libcudnn_cnn_infer.so.8", RTLD_LAZY);
    char* err = dlerror();
    if (!handle) {
      TORCH_WARN(
          "Attempt to open cnn_infer failed: handle=", handle, " error: ", err);
    } else if (err) {
      TORCH_WARN("Applied workaround for CuDNN issue, install nvrtc.so");
    }
    return handle;
  }();
#endif
  auto sizes = t.sizes();
  auto strides = t.strides();
  bool channels_last = memory_format == at::MemoryFormat::ChannelsLast ||
      memory_format == at::MemoryFormat::ChannelsLast3d;

  std::vector<int64_t> strides_copy(std::begin(strides), std::end(strides));
  fixSizeOneDimStride<int64_t>(
      sizes.size(), &sizes[0], (int64_t*)&strides_copy[0], channels_last);
  auto r = cudnn_frontend::TensorBuilder()
               .setDim(sizes.size(), sizes.data())
               .setStrides(strides_copy.size(), strides_copy.data())
               .setId(id)
               .setAlignment(alignment)
               .setDataType(dataType)
               .setVirtual(_virtual)
               .build();
  return r;
}

cudnn_frontend::Tensor getTensorDescriptor(
    const Tensor& t,
    const int64_t id,
    const uint8_t alignment,
    const at::MemoryFormat memory_format) {
  return getTensorDescriptorWithTypeVirtual(
      t, id, alignment, getCudnnDataType(t), memory_format, false);
}

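// Builds the convolution descriptor. Note that the compute type is promoted to
// FP32 when the input scalar type is half or bfloat16, following the usual
// mixed-precision convention of accumulating in FP32.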
cudnn_frontend::ConvDesc_v8 getConvDescriptor(
    cudnnDataType_t dataType,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    const at::ScalarType scalar_type) {
  uint64_t convDim = stride.size();
  if (scalar_type == kBFloat16 || scalar_type == kHalf) {
    dataType = CUDNN_DATA_FLOAT;
  }
  return cudnn_frontend::ConvDescBuilder()
      .setDataType(dataType)
      .setMathMode(CUDNN_CROSS_CORRELATION)
      .setNDims(convDim)
      .setStrides(convDim, stride.data())
      .setPrePadding(convDim, padding.data())
      .setPostPadding(convDim, padding.data())
      .setDilation(convDim, dilation.data())
      .build();
}

void filterEngineConfigs(
    cudnn_frontend::EngineConfigList& from,
    cudnn_frontend::EngineConfigList& to,
    bool deterministic,
    bool allow_tf32,
    c10::ScalarType scalar_type) {
  auto filter = [=](cudnnBackendDescriptor_t c) {
    if (deterministic) {
      if (cudnn_frontend::hasNumericalNote<
              CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) {
        return true;
      }
    }
    if (cudnn_frontend::hasNumericalNote<
            CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) {
      return true;
    }
    if (scalar_type == kFloat) {
      // TODO: check under which conditions this is OK
      if (!allow_tf32 &&
          cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(
              c)) {
        return true;
      }
    }
    return false;
  };
  cudnn_frontend::filter(from, to, filter);
}

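// Cache keys for the execution-plan caches below. They are wrapped in
// ParamsWrapper and hashed/compared bytewise via ParamsWrapperHash, so both
// structs are kept as plain PODs: the convolution parameters plus the pointer
// alignments that influence engine selection.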
struct CacheKey {
  ConvolutionParams params;
  cudnnBackendDescriptorType_t operation;
  uint8_t x_alignment;
  uint8_t w_alignment;
  uint8_t y_alignment;
};

struct CacheKeyFused {
  ConvolutionParams params;
  // No operation field here because it is assumed to be a forward conv op
  uint8_t x_alignment;
  uint8_t w_alignment;
  uint8_t y_alignment;
  uint8_t z_alignment;
  uint8_t b_alignment;
  // TODO: does it make sense to have this in the key? but alpha is a
  // graph-level param...
  float alpha;
};

struct CacheKeyWrapper : ParamsWrapper<CacheKey> {
  CacheKeyWrapper(
      const cudnnBackendDescriptorType_t operation,
      const Tensor& y,
      const Tensor& x,
      const Tensor& w,
      const IntArrayRef padding,
      const IntArrayRef stride,
      const IntArrayRef dilation,
      int64_t groups,
      bool deterministic,
      bool allow_tf32) {
    at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(x, w);
    setConvolutionParams(
        &(this->pod.params),
        x,
        w,
        padding,
        stride,
        dilation,
        groups,
        deterministic,
        allow_tf32,
        memory_format);
    this->pod.operation = operation;
    this->pod.x_alignment = getAlignment(x);
    this->pod.y_alignment = getAlignment(y);
    this->pod.w_alignment = getAlignment(w);
  }
};

struct CacheKeyFusedWrapper : ParamsWrapper<CacheKeyFused> {
  CacheKeyFusedWrapper(
      const Tensor& y,
      const Tensor& x,
      const Tensor& w,
      const Tensor& z,
      const Tensor& b,
      const float alpha,
      const IntArrayRef padding,
      const IntArrayRef stride,
      const IntArrayRef dilation,
      int64_t groups,
      bool deterministic,
      bool allow_tf32) {
    at::MemoryFormat memory_format = cudnn_conv_suggest_memory_format(x, w);
    setConvolutionParams(
        &(this->pod).params,
        x,
        w,
        padding,
        stride,
        dilation,
        groups,
        deterministic,
        allow_tf32,
        memory_format);
    this->pod.x_alignment = getAlignment(x);
    this->pod.y_alignment = getAlignment(y);
    this->pod.w_alignment = getAlignment(w);
    this->pod.z_alignment = getAlignment(z);
    this->pod.b_alignment = getAlignment(b);
    this->pod.alpha = alpha;
  }
};

static int getLRUCacheLimit() {
  constexpr int DEFAULT_LIMIT =
      10000; // roughly corresponds to 2GiB assuming 200KiB per ExecutionPlan
  // 0 is used to indicate no limit
  // negative values are used to indicate no caching
  static int limit = [&] {
    const char* val = getenv("TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT");
    if (!val) {
      return DEFAULT_LIMIT;
    }
    try {
      return std::stoi(val);
    } catch (std::invalid_argument const& e) {
      TORCH_WARN(
          "invalid TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT,",
          " using default LRU cache limit of ",
          DEFAULT_LIMIT,
          " entries.");
    } catch (std::out_of_range const& e) {
      TORCH_WARN(
          "invalid TORCH_CUDNN_V8_API_LRU_CACHE_LIMIT,",
          " using default LRU cache limit of ",
          DEFAULT_LIMIT,
          " entries.");
    }
    return DEFAULT_LIMIT;
  }();
  return limit;
}

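// A small LRU cache from convolution cache keys to cudnn_frontend execution
// plans. engine_cache_order keeps keys in most-recently-used-first order, and
// the map stores each plan together with an iterator into that list so a hit
// can be promoted in O(1). The limit comes from getLRUCacheLimit(): 0 disables
// eviction, negative values disable caching entirely.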
template <typename T, typename KeyType>
struct BenchmarkCache {
  std::list<KeyType> engine_cache_order;
  std::unordered_map<
      KeyType,
      std::pair<
          cudnn_frontend::ExecutionPlan,
          typename std::list<KeyType>::iterator>,
      ParamsWrapperHash<KeyType>>
      engine_cache;

  // No mutexes here, as the caches are now thread-local for v8; we can also
  // return a pointer to the ExecutionPlan because we know it will not be
  // invalidated by another thread.
  cudnn_frontend::ExecutionPlan* find(const KeyType& key) {
    const int lru_cache_limit = getLRUCacheLimit();
    if (lru_cache_limit < 0) {
      return nullptr;
    }
    auto it = engine_cache.find(key);
    if (it == engine_cache.end()) {
      return nullptr;
    }
    if (lru_cache_limit) {
      // update most recently accessed
      engine_cache_order.splice(
          engine_cache_order.begin(), engine_cache_order, it->second.second);
    }
    return &(it->second.first);
  }

  void update(const KeyType& key, T& results) {
    int lru_cache_limit = getLRUCacheLimit();
    if (lru_cache_limit < 0) {
      return;
    } else if (lru_cache_limit) {
      auto it = engine_cache.find(key);
      if (it == engine_cache.end()) {
        if ((long)engine_cache.size() >= lru_cache_limit) {
          auto erase_count = engine_cache.erase(engine_cache_order.back());
          TORCH_INTERNAL_ASSERT(
              erase_count == 1,
              "CUDNN V8 LRU Cache Corrupted (eviction key not in map). Please report a bug to PyTorch.");
          engine_cache_order.pop_back();
        }
        engine_cache_order.emplace_front(key);
        engine_cache.emplace(
            key, std::make_pair(results, engine_cache_order.begin()));
      } else {
        it->second.first = results;
        // update most recently accessed
        engine_cache_order.splice(
            engine_cache_order.begin(), engine_cache_order, it->second.second);
      }
    } else {
      engine_cache.erase(key);
      engine_cache.emplace(
          key,
          std::make_pair(results, engine_cache_order.end())); // dummy iterator
    }
  }
};

// @eqy: use thread-local caches, as cuDNN Execution Plans are not guaranteed
// to be thread-safe across all engines; see "Limitations" in
// https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html
thread_local BenchmarkCache<cudnn_frontend::ExecutionPlan, CacheKeyWrapper>
    benchmark_cache;
thread_local BenchmarkCache<cudnn_frontend::ExecutionPlan, CacheKeyFusedWrapper>
    benchmark_cache_fused;

} // namespace

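// Executes a previously built ExecutionPlan. The variant pack binds the three
// data pointers to the UIDs 'x', 'y', and 'w' used when the tensor descriptors
// were created; which of the three is the mutable output depends on whether
// this is a forward, backward-data, or backward-filter convolution.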
void run_conv_plan(
    cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const cudnn_frontend::ExecutionPlan& plan,
    const cudnnBackendDescriptorType_t operation) {
  c10::DeviceGuard g(x.options().device());
  auto workspace_size = plan.getWorkspaceSize();
  auto workspace_ptr =
      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
  void* data_ptrs[3];

  if (operation == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
    data_ptrs[0] = const_cast<void*>(x.const_data_ptr());
    data_ptrs[1] = y.data_ptr();
    data_ptrs[2] = const_cast<void*>(w.const_data_ptr());
  } else if (
      operation ==
      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
    data_ptrs[0] = x.data_ptr();
    data_ptrs[1] = const_cast<void*>(y.const_data_ptr());
    data_ptrs[2] = const_cast<void*>(w.const_data_ptr());
  } else if (
      operation ==
      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
    data_ptrs[0] = const_cast<void*>(x.const_data_ptr());
    data_ptrs[1] = const_cast<void*>(y.const_data_ptr());
    data_ptrs[2] = w.data_ptr();
  } else {
    data_ptrs[0] = x.data_ptr();
    data_ptrs[1] = y.data_ptr();
    data_ptrs[2] = w.data_ptr();
  }

  int64_t uids[] = {'x', 'y', 'w'};
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setWorkspacePointer(workspace_size ? workspace_ptr.get() : nullptr)
          .setDataPointers(3, data_ptrs)
          .setUids(3, uids)
          .build();
  AT_CUDNN_CHECK(cudnnBackendExecute(
      handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
}

void run_conv_plan_fused(
    cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const cudnn_frontend::ExecutionPlan& plan) {
  c10::DeviceGuard g(x.options().device());
  auto workspace_size = plan.getWorkspaceSize();
  auto workspace_ptr =
      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
  void* data_ptrs[] = {
      x.data_ptr(), y.data_ptr(), w.data_ptr(), z.data_ptr(), b.data_ptr()};
  int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setWorkspacePointer(workspace_size ? workspace_ptr.get() : nullptr)
          .setDataPointers(5, data_ptrs)
          .setUids(5, uids)
          .build();
  AT_CUDNN_CHECK(cudnnBackendExecute(
      handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
}

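// Builds a single-operation graph containing just the convolution node; used
// by the non-fused forward, backward-data, and backward-filter paths.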
auto build_opgraph(
    const cudnnHandle_t handle,
    const cudnnBackendDescriptorType_t desc,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const CacheKeyWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation) {
  auto op = cudnn_frontend::OperationBuilder(desc)
                .setxDesc(getTensorDescriptor(
                    x, 'x', key.pod.x_alignment, key.pod.params.memory_format))
                .setyDesc(getTensorDescriptor(
                    y, 'y', key.pod.y_alignment, key.pod.params.memory_format))
                .setwDesc(getTensorDescriptor(
                    w, 'w', key.pod.w_alignment, key.pod.params.memory_format))
                .setcDesc(getConvDescriptor(
                    key.pod.params.dataType,
                    padding,
                    stride,
                    dilation,
                    x.scalar_type()))
                .build();
  std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
  auto opGraph = cudnn_frontend::OperationGraphBuilder()
                     .setHandle(handle)
                     .setOperationGraph(ops.size(), ops.data())
                     .build();
  return opGraph;
}

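// Builds the fused graph conv -> scaled add of z -> bias add -> ReLU used by
// raw_cudnn_convolution_add_relu_out. The intermediate results are declared as
// virtual FP32 tensors, so only the final activation output 'y' is written out
// in the original datatype.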
auto build_opgraph_fused(
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const float alpha,
    const CacheKeyFusedWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation) {
  // need computation to be done in FLOAT type regardless of reduced precision
  // input
  const auto precision = CUDNN_DATA_FLOAT;
  auto addDesc = cudnn_frontend::PointWiseDescBuilder()
                     .setMode(CUDNN_POINTWISE_ADD)
                     .setMathPrecision(precision)
                     .build();
  auto addBiasDesc = cudnn_frontend::PointWiseDescBuilder()
                         .setMode(CUDNN_POINTWISE_ADD)
                         .setMathPrecision(precision)
                         .build();
  auto actDesc = cudnn_frontend::PointWiseDescBuilder()
                     .setMode(CUDNN_POINTWISE_RELU_FWD)
                     .setMathPrecision(precision)
                     .build();
  auto convDesc = getConvDescriptor(
      key.pod.params.dataType, padding, stride, dilation, x.scalar_type());
  const float alpha1 = 1.0;
  const float alpha2 = alpha;
  auto conv_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
          .setxDesc(getTensorDescriptor(
              x, 'x', key.pod.x_alignment, key.pod.params.memory_format))
          // virtual output of conv
          .setyDesc(getTensorDescriptorWithTypeVirtual(
              y,
              'C',
              key.pod.y_alignment,
              precision,
              key.pod.params.memory_format,
              true))
          .setwDesc(getTensorDescriptor(
              w, 'w', key.pod.w_alignment, key.pod.params.memory_format))
          .setAlpha(alpha1)
          .setcDesc(convDesc)
          .build();
  auto add_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(conv_op.getOutputTensor())
          .setbDesc(getTensorDescriptor(
              z, 'z', key.pod.z_alignment, key.pod.params.memory_format))
          // another virtual output (of add)
          .setyDesc(getTensorDescriptorWithTypeVirtual(
              y,
              'A',
              key.pod.y_alignment,
              precision,
              key.pod.params.memory_format,
              true))
          .setpwDesc(addDesc)
          .setAlpha(alpha1)
          .setAlpha2(alpha2)
          .build();
  auto add_bias_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(add_op.getOutputTensor())
          .setbDesc(getTensorDescriptor(
              b, 'b', key.pod.b_alignment, key.pod.params.memory_format))
          // another virtual output (of add bias)
          .setyDesc(getTensorDescriptorWithTypeVirtual(
              y,
              'B',
              key.pod.y_alignment,
              precision,
              key.pod.params.memory_format,
              true))
          .setpwDesc(addBiasDesc)
          .build();
  auto act_op =
      cudnn_frontend::OperationBuilder(
          CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
          .setxDesc(add_bias_op.getOutputTensor())
          // final output is in original datatype
          .setyDesc(getTensorDescriptor(
              y, 'y', key.pod.y_alignment, key.pod.params.memory_format))
          .setpwDesc(actDesc)
          .build();
  std::array<cudnn_frontend::Operation const*, 4> ops = {
      &conv_op, &add_op, &add_bias_op, &act_op};
  auto opGraph = cudnn_frontend::OperationGraphBuilder()
                     .setHandle(handle)
                     .setOperationGraph(ops.size(), ops.data())
                     .build();
  return opGraph;
}

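// Returns the engine-config generator sources handed to
// cudnn_frontend::EngineConfigGenerator: a heuristics-based generator, a
// fallback-list generator, or both, each filtered through
// filterEngineConfigs() for determinism/TF32 constraints.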
auto get_generator_sources(
    const cudnnBackendDescriptorType_t& desc,
    const Tensor& x,
    const bool deterministic,
    const bool allow_tf32,
    const cudnnBackendHeurMode_t heur_mode,
    const bool heuristic,
    const bool fallback) {
  // Method for engine config generator based on heuristics
  const auto heurgen_method =
      [/*&desc,*/ &x, deterministic, allow_tf32, heur_mode](
          cudnn_frontend::OperationGraph& opGraph)
      -> cudnn_frontend::EngineConfigList {
    auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
                          .setOperationGraph(opGraph)
                          .setHeurMode(heur_mode)
                          .build();
    auto& engine_configs =
        heuristics.getEngineConfig(heuristics.getEngineConfigCount());
    cudnn_frontend::EngineConfigList filtered_configs;
    filterEngineConfigs(
        engine_configs,
        filtered_configs,
        deterministic,
        allow_tf32,
        x.scalar_type());
    return filtered_configs;
  };
  // Method for engine config generator based on fallback list
  const auto fallback_method = [&desc, &x, deterministic, allow_tf32](
                                   cudnn_frontend::OperationGraph& opGraph)
      -> cudnn_frontend::EngineConfigList {
    auto fallback = cudnn_frontend::EngineFallbackListBuilder()
                        .setOperationGraph(opGraph)
                        .setOperation(desc)
                        .build();
    auto& fallback_list = fallback.getFallbackList();
    cudnn_frontend::EngineConfigList filtered_configs;
    filterEngineConfigs(
        fallback_list,
        filtered_configs,
        deterministic,
        allow_tf32,
        x.scalar_type());
    return filtered_configs;
  };
  if (heuristic && fallback) {
    std::vector<cudnn_frontend::GeneratorSource> sources = {
        heurgen_method, fallback_method};
    return sources;
  } else if (heuristic) {
    std::vector<cudnn_frontend::GeneratorSource> sources = {heurgen_method};
    return sources;
  } else {
    std::vector<cudnn_frontend::GeneratorSource> sources = {fallback_method};
    return sources;
  }
}

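// Queries the CUDA caching allocator for the largest cached block on the
// current device; this is used below as an upper bound on the workspace size a
// candidate plan may request.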
int64_t get_available_workspace() {
  c10::DeviceIndex device = 0;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  size_t max_block_size = 0;
  c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
  return static_cast<int64_t>(max_block_size);
}

static nlohmann::json errata_json_handle;

bool plan_errata_exception(
    const cudnnHandle_t handle,
    const std::string& executionPlanTag) {
  static bool has_json =
      cudnn_frontend::load_from_config(errata_json_handle, "");
  if (!has_json) {
    return false;
  } else {
    return cudnn_frontend::check_errata(
        errata_json_handle, executionPlanTag, handle, []() { return true; });
  }
}

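// Builds candidate execution plans from the generator (dropping those blocked
// by the errata filter), keeps only plans whose workspace fits in the largest
// available block, then tries to allocate the largest remaining workspace,
// halving the request on OOM and pruning plans that no longer fit.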
void generate_and_filter_plans(
    const cudnnHandle_t handle,
    cudnn_frontend::OperationGraph& opGraph,
    cudnn_frontend::EngineConfigGenerator& generator,
    const Tensor& x,
    cudnn_frontend::executionPlans_t& valid_plans,
    at::DataPtr& workspace_ptr) {
  auto initial_predicate_function =
      [&](cudnn_frontend::ExecutionPlan const& plan) -> bool {
    return plan_errata_exception(handle, plan.getTag());
  };
  auto plans =
      generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
  int64_t max_block_size = get_available_workspace();
  int64_t max_workspace_size = 0;
  std::for_each(
      plans.begin(), plans.end(), [&](cudnn_frontend::ExecutionPlan& plan) {
        int64_t curr_workspace_size = plan.getWorkspaceSize();
        if (curr_workspace_size <= max_block_size) {
          if (curr_workspace_size > max_workspace_size) {
            max_workspace_size = plan.getWorkspaceSize();
          }
          valid_plans.emplace_back(std::move(plan));
        }
      });
  TORCH_CHECK_WITH(
      OutOfMemoryError,
      max_workspace_size < 1_TiB,
      "Not enough memory for workspace!");
  bool remove_invalid = false;
  while (max_workspace_size) {
    try {
      workspace_ptr =
          c10::cuda::CUDACachingAllocator::get()->allocate(max_workspace_size);
      break;
    } catch (c10::OutOfMemoryError& e) {
      max_workspace_size /= 2;
      (void)cudaGetLastError(); // clear CUDA error
      remove_invalid = true;
    }
  }
  if (remove_invalid) {
    cudnn_frontend::executionPlans_t new_valid_plans;
    for (auto& plan : valid_plans) {
      if (plan.getWorkspaceSize() <= max_workspace_size) {
        new_valid_plans.emplace_back(std::move(plan));
      }
    }
    valid_plans = std::move(new_valid_plans);
  }
}

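// Benchmark ("FIND") path: build every valid candidate plan for this problem,
// time them with cudnn_frontend::time_sorted_plan, and return the plans sorted
// by measured runtime. benchmarkLimitCuDNN() caps the number of candidates
// considered (0 means the default of 10000).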
auto get_plans_from_find(
    const cudnnHandle_t handle,
    const cudnnBackendDescriptorType_t desc,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const CacheKeyWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32) {
  auto opGraph =
      build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation);
  void* data_ptrs[] = {x.data_ptr(), y.data_ptr(), w.data_ptr()};
  int64_t uids[] = {'x', 'y', 'w'};
  // We don't care about getting the best ordering of algos if we're going to
  // run all of them
  auto sources = get_generator_sources(
      desc, x, deterministic, allow_tf32, CUDNN_HEUR_MODE_INSTANT, true, true);
  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  cudnn_frontend::executionPlans_t valid_plans;
  c10::DeviceGuard g(x.options().device());
  at::DataPtr workspace_ptr;
  generate_and_filter_plans(
      handle, opGraph, generator, x, valid_plans, workspace_ptr);
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setDataPointers(3, data_ptrs)
          .setUids(3, uids)
          .setWorkspacePointer(workspace_ptr ? workspace_ptr.get() : nullptr)
          .build();

  auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
  benchmark_limit = benchmark_limit ? benchmark_limit : 10000;
  auto plans = cudnn_frontend::time_sorted_plan<
      cudnn_frontend::CudnnFindSamplingTechnique::CUDNN_FIND_SAMPLE_ONCE>(
      handle, std::move(valid_plans), variantPack, benchmark_limit);

  cudnn_frontend::executionPlans_t sorted_plans;
  for (auto& plan : plans) {
    sorted_plans.emplace_back(std::move(plan));
  }
  return sorted_plans;
}

auto get_plans_from_find_fused(
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const float alpha,
    const CacheKeyFusedWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32) {
  auto opGraph = build_opgraph_fused(
      handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
  void* data_ptrs[] = {
      x.data_ptr(), y.data_ptr(), w.data_ptr(), z.data_ptr(), b.data_ptr()};
  int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};

  auto sources = get_generator_sources(
      CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
      x,
      deterministic,
      allow_tf32,
      CUDNN_HEUR_MODE_INSTANT,
      true,
      true);
  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  cudnn_frontend::executionPlans_t valid_plans;
  c10::DeviceGuard g(x.options().device());
  at::DataPtr workspace_ptr;
  generate_and_filter_plans(
      handle, opGraph, generator, x, valid_plans, workspace_ptr);
  auto variantPack =
      cudnn_frontend::VariantPackBuilder()
          .setDataPointers(5, data_ptrs)
          .setUids(5, uids)
          .setWorkspacePointer(workspace_ptr ? workspace_ptr.get() : nullptr)
          .build();

  auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
  benchmark_limit = benchmark_limit ? benchmark_limit : 10000;
  auto plans = cudnn_frontend::time_sorted_plan<
      cudnn_frontend::CudnnFindSamplingTechnique::CUDNN_FIND_SAMPLE_ONCE>(
      handle, std::move(valid_plans), variantPack, benchmark_limit);

  cudnn_frontend::executionPlans_t sorted_plans;
  for (auto& plan : plans) {
    sorted_plans.emplace_back(std::move(plan));
  }
  return sorted_plans;
}

// We only get configs from this stage to avoid building unnecessary plans that
// are never executed
auto get_configs_from_heuristics(
    const cudnnHandle_t handle,
    const cudnnBackendDescriptorType_t desc,
    std::string& opgraph_tag,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const CacheKeyWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32,
    const bool fallback) {
  auto opGraph =
      build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation);
  opgraph_tag = opGraph.getTag();
  auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b()
      ? CUDNN_HEUR_MODE_B
      : CUDNN_HEUR_MODE_INSTANT;
  auto sources = get_generator_sources(
      desc, x, deterministic, allow_tf32, heuristic_mode, !fallback, fallback);

  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  auto configs = generator.generate_engine_config(opGraph);
  return configs;
}

auto get_configs_from_heuristics_fused(
    const cudnnHandle_t handle,
    std::string& opgraph_tag,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    const float alpha,
    const CacheKeyFusedWrapper& key,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const bool deterministic,
    const bool allow_tf32,
    const bool fallback) {
  auto opGraph = build_opgraph_fused(
      handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
  opgraph_tag = opGraph.getTag();
  auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b()
      ? CUDNN_HEUR_MODE_B
      : CUDNN_HEUR_MODE_INSTANT;
  auto sources = get_generator_sources(
      CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
      x,
      deterministic,
      allow_tf32,
      heuristic_mode,
      !fallback,
      fallback);

  cudnn_frontend::EngineConfigGenerator generator(
      sources.size(), sources.data());
  auto configs = generator.generate_engine_config(opGraph);
  return configs;
}

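// try_plans/try_configs execute candidates in order until one succeeds and
// cache the first winner; cudnn and OOM failures are swallowed so the next
// candidate can be tried. try_plans raises the "FIND was unable..." error if
// every plan fails, while try_configs simply reports failure so the caller can
// fall back to another config source.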
void try_plans(
    cudnn_frontend::executionPlans_t& plans,
    const CacheKeyWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const cudnnBackendDescriptorType_t operation) {
  for (auto& plan : plans) {
    try {
      run_conv_plan(handle, x, y, w, plan, operation);
      benchmark_cache.update(key, plan);
      return;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  TORCH_CHECK(
      false, "FIND was unable to find an engine to execute this computation");
}

void try_plans_fused(
    cudnn_frontend::executionPlans_t& plans,
    const CacheKeyFusedWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b) {
  for (auto& plan : plans) {
    try {
      run_conv_plan_fused(handle, x, y, w, z, b, plan);
      benchmark_cache_fused.update(key, plan);
      return;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  TORCH_CHECK(
      false, "FIND was unable to find an engine to execute this computation");
}

bool try_configs(
    cudnn_frontend::EngineConfigList& configs,
    const std::string& opgraph_tag,
    const CacheKeyWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const cudnnBackendDescriptorType_t operation) {
  for (auto& config : configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
                      .setHandle(handle)
                      .setEngineConfig(config, opgraph_tag)
                      .build();
      if (plan_errata_exception(handle, plan.getTag())) {
        continue;
      }
      run_conv_plan(handle, x, y, w, plan, operation);
      benchmark_cache.update(key, plan);
      return true;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  return false;
}

bool try_configs_fused(
    cudnn_frontend::EngineConfigList& configs,
    const std::string& opgraph_tag,
    const CacheKeyFusedWrapper& key,
    const cudnnHandle_t handle,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b) {
  for (auto& config : configs) {
    try {
      auto plan = cudnn_frontend::ExecutionPlanBuilder()
                      .setHandle(handle)
                      .setEngineConfig(config, opgraph_tag)
                      .build();
      if (plan_errata_exception(handle, plan.getTag())) {
        continue;
      }
      run_conv_plan_fused(handle, x, y, w, z, b, plan);
      benchmark_cache_fused.update(key, plan);
      return true;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (CuDNNError& e) {
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  return false;
}

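// Top-level v8 dispatch for a single (non-fused) convolution: first consult
// the thread-local plan cache, then (unless benchmarking is enabled) try
// heuristic configs followed by the fallback list ("GET" path), and otherwise
// benchmark all candidate plans ("FIND" path). run_fused_conv below follows
// the same structure for the conv + add + bias + ReLU graph.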
void run_single_conv(
    const cudnnBackendDescriptorType_t operation,
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  cudnnHandle_t handle = getCudnnHandle();
  CacheKeyWrapper key(
      operation,
      y,
      x,
      w,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32);
  // TODO: is this thread safe if cache is updated? is pointer stale?
  auto search = benchmark_cache.find(key);
  if (search) {
    try {
      run_conv_plan(handle, x, y, w, *search, operation);
      return;
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  if (!benchmark) {
    std::string opgraph_tag; // extra data needed for errata filter
    // heuristic configs
    cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics(
        handle,
        operation,
        opgraph_tag,
        x,
        y,
        w,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32,
        false);
    if (try_configs(configs, opgraph_tag, key, handle, x, y, w, operation)) {
      return;
    }
    // fallback configs
    configs = get_configs_from_heuristics(
        handle,
        operation,
        opgraph_tag,
        x,
        y,
        w,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32,
        true);
    if (try_configs(configs, opgraph_tag, key, handle, x, y, w, operation)) {
      return;
    }
    TORCH_CHECK(
        false, "GET was unable to find an engine to execute this computation");
  } else {
    cudnn_frontend::executionPlans_t plans = get_plans_from_find(
        handle,
        operation,
        x,
        y,
        w,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32);
    // Replicate v7 behavior: clear cached blocks as benchmark incurs
    // significant memory consumption that is not needed after this step
    if (at::native::_cudnn_get_conv_benchmark_empty_cache()) {
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    try_plans(plans, key, handle, x, y, w, operation);
  }
}

void run_fused_conv(
    const Tensor& x,
    const Tensor& y,
    const Tensor& w,
    const Tensor& z,
    const Tensor& b,
    float alpha,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  cudnnHandle_t handle = getCudnnHandle();

  CacheKeyFusedWrapper key(
      y,
      x,
      w,
      z,
      b,
      alpha,
      padding,
      stride,
      dilation,
      groups,
      deterministic,
      allow_tf32);
  auto search = benchmark_cache_fused.find(key);
  if (search) {
    try {
      run_conv_plan_fused(handle, x, y, w, z, b, *search);
      return;
    } catch (c10::OutOfMemoryError& e) {
      (void)cudaGetLastError(); // clear CUDA error
    }
  }
  if (!benchmark) {
    std::string opgraph_tag; // extra data needed for errata filter
    // heuristic configs
    cudnn_frontend::EngineConfigList configs =
        get_configs_from_heuristics_fused(
            handle,
            opgraph_tag,
            x,
            y,
            w,
            z,
            b,
            alpha,
            key,
            padding,
            stride,
            dilation,
            deterministic,
            allow_tf32,
            false);
    if (try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b)) {
      return;
    }
    // fallback configs
    configs = get_configs_from_heuristics_fused(
        handle,
        opgraph_tag,
        x,
        y,
        w,
        z,
        b,
        alpha,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32,
        true);
    if (try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b)) {
      return;
    }
    TORCH_CHECK(
        false, "GET was unable to find an engine to execute this computation");
  } else {
    cudnn_frontend::executionPlans_t plans = get_plans_from_find_fused(
        handle,
        x,
        y,
        w,
        z,
        b,
        alpha,
        key,
        padding,
        stride,
        dilation,
        deterministic,
        allow_tf32);
    try_plans_fused(plans, key, handle, x, y, w, z, b);
  }
}

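// Public entry points. Each dispatches to the v8 path above when the cuDNN v8
// frontend API is enabled (at::native::cudnnv8_enabled_check_debug()) and
// otherwise falls back to the corresponding raw_cudnn_convolution_*_out_v7
// implementation; empty outputs/gradients are a no-op.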
void raw_cudnn_convolution_forward_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  if (output.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(
        CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR,
        input,
        output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_forward_out_v7(
        output,
        input,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

void raw_cudnn_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  if (grad_input.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(
        CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR,
        grad_input,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_backward_input_out_v7(
        grad_input,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

void raw_cudnn_convolution_backward_weight_out(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool benchmark,
    const bool deterministic,
    const bool allow_tf32) {
  if (grad_weight.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    run_single_conv(
        CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR,
        input,
        grad_output,
        grad_weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_backward_weight_out_v7(
        grad_weight,
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

void raw_cudnn_convolution_add_relu_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {
  if (output.numel() == 0) {
    return;
  }
  if (at::native::cudnnv8_enabled_check_debug()) {
    auto bias_ = input.ndimension() == 4
        ? bias.view({1, bias.numel(), 1, 1})
        : bias.view({1, bias.numel(), 1, 1, 1});
    run_fused_conv(
        input,
        output,
        weight,
        z,
        bias_,
        alpha,
        stride,
        padding,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  } else {
    raw_cudnn_convolution_add_relu_out_v7(
        output,
        input,
        weight,
        z,
        alpha,
        bias,
        stride,
        padding,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
  }
}

} // namespace native
} // namespace at

#endif // AT_CUDNN_ENABLED