#pragma once

#include <string>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8907
#define USE_CUDNN_RNN_V8_API
#endif

namespace at::native {

std::string cudnnTypeToString(cudnnDataType_t dtype);

// TODO: Add constructors for all of the descriptors

inline int dataSize(cudnnDataType_t dataType)
{
  switch (dataType) {
    case CUDNN_DATA_BFLOAT16:
    case CUDNN_DATA_HALF: return 2;
    case CUDNN_DATA_FLOAT: return 4;
    default: return 8;
  }
}

// The stride for a size-1 dimension is not uniquely determined; in
// fact, it can be anything you want, because the fact that the
// tensor is size 1 at this dimension means that you will never actually
// try advancing your pointer by this stride.
//
// However, CuDNN has a much more stringent requirement on strides:
// if you are passing a contiguous input, it had better be the case
// that the stride for dim i is the product of the sizes of dims
// i+1 to the end.  This stride is indeed uniquely determined.  This
// function modifies 'stride' in place so that this invariant holds.
template <typename T>
static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);

  if (nhwc) {
    permutation[index++] = 1;
  }
  for (int d = dim - 1; d > 1; d--) {
    permutation[index++] = d;
  }
  if (!nhwc) {
    permutation[index++] = 1;
  }
  permutation[index++] = 0;
  for (int d : permutation) {
    if (size[d] == 1) {
      stride[d] = z;
    } else {
      z *= size[d];
    }
  }
}
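
// Illustrative example (values are hypothetical, not part of the original
// header): for a contiguous NCHW tensor whose size-1 channel dimension
// happens to report an arbitrary stride,
//
//   int size[4]   = {8, 1, 5, 5};
//   int stride[4] = {25, 1, 5, 1};   // stride of the size-1 dim is arbitrary
//   fixSizeOneDimStride<int>(4, size, stride, /*nhwc=*/false);
//   // stride is now {25, 25, 5, 1}: the size-1 dim receives the product
//   // of the sizes to its right, which is what cuDNN expects.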

template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      AT_CUDNN_CHECK(dtor(x));
    }
  }
};

// A generic class for wrapping cuDNN descriptor types.  All you need
// to do is give it the underlying type that the Descriptor_t points to
// (usually, if it's cudnnTensorDescriptor_t it points to cudnnTensorStruct),
// the constructor, and the destructor.  Subclasses are responsible
// for defining a set() function that actually sets the descriptor.
//
// Descriptors default-construct to a nullptr, and have a descriptor
// initialized the first time you call set() or any other initializing
// function.
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
class TORCH_CUDA_CPP_API Descriptor {
 public:
  // TODO: Figure out why const-correctness doesn't work here

  // Use desc() to access the underlying descriptor pointer in
  // a read-only fashion.  Most client code should use this.
  // If the descriptor was never initialized, this will return
  // nullptr.
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }

  // Use mut_desc() to access the underlying descriptor pointer
  // if you intend to modify what it points to (e.g., using
  // cudnnSetFooDescriptor).  This will ensure that the descriptor
  // is initialized.  Code in this file will use this function.
  T* mut_desc() { init(); return desc_.get(); }
protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc = nullptr;
      AT_CUDNN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }
private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
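
// Illustrative sketch (hypothetical subclass, not part of the original
// header): a concrete descriptor wires the template arguments to the
// matching cuDNN create/destroy pair and exposes a set() that fills the
// descriptor through mut_desc(), e.g.
//
//   class PoolingDescriptor : public Descriptor<
//                                 cudnnPoolingStruct,
//                                 &cudnnCreatePoolingDescriptor,
//                                 &cudnnDestroyPoolingDescriptor> {
//    public:
//     void set(cudnnPoolingMode_t mode, int dim, int* window, int* pad, int* stride) {
//       AT_CUDNN_CHECK(cudnnSetPoolingNdDescriptor(
//           mut_desc(), mode, CUDNN_NOT_PROPAGATE_NAN, dim, window, pad, stride));
//     }
//   };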

class TORCH_CUDA_CPP_API RNNDataDescriptor : public Descriptor<
                                       cudnnRNNDataStruct,
                                       &cudnnCreateRNNDataDescriptor,
                                       &cudnnDestroyRNNDataDescriptor> {
public:
  void set(const at::Tensor &t, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray);
private:
  void set(cudnnDataType_t dataType, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray) {
    AT_CUDNN_CHECK(cudnnSetRNNDataDescriptor(mut_desc(), dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, NULL));
  }
};

class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                               cudnnTensorStruct,
                                               &cudnnCreateTensorDescriptor,
                                               &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }

  // Note [CuDNN broadcast padding]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // pad specifies the minimum dimensionality of the tensor descriptor
  // we produce (it doesn't have anything to do with, e.g., convolution
  // padding).  If 't' is lower-dimensional than 'pad', the remaining
  // dimensions (on the right) are padded with ones.  This doesn't
  // affect the underlying data layout.  This is particularly useful for
  // dealing with a peculiarity of the CuDNN API, whose broadcasting is
  // done in two steps: first, the client code is expected to pad the
  // input tensors out to the same dimensionality as the broadcast
  // target, and then second, CuDNN takes care of actually broadcasting
  // the size-1 dimensions.
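  //
  // Illustrative example (hypothetical shapes): calling set(t, 4) on a
  // 2-D tensor of shape {8, 3} produces a 4-D descriptor describing
  // shape {8, 3, 1, 1}; the trailing size-1 dimensions exist only in the
  // descriptor and do not change how the data is laid out.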

  void set(const at::Tensor &t, size_t pad = 0);
  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

private:
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    std::vector<int> strides_copy(stride, stride + dim);
    fixSizeOneDimStride<int>(dim, size, strides_copy.data(), nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
  }
};
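
// Illustrative usage sketch (hypothetical tensor, not part of the original
// header):
//
//   at::Tensor input = at::empty({8, 3, 32, 32}, at::kCUDA);
//   TensorDescriptor idesc(input, /*pad=*/4);
//   cudnnTensorDescriptor_t raw = idesc.desc();  // read-only access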

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);

class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                               cudnnFilterStruct,
                                               &cudnnCreateFilterDescriptor,
                                               &cudnnDestroyFilterDescriptor> {
 public:
  void set(const at::Tensor &t, int64_t pad = 0) {
    set(t, at::MemoryFormat::Contiguous, pad);
  }

  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);

  void print();
private:
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
  }
};

std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);

struct TORCH_CUDA_CPP_API ConvolutionDescriptor
    : public Descriptor<
          cudnnConvolutionStruct,
          &cudnnCreateConvolutionDescriptor,
          &cudnnDestroyConvolutionDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int* upscale /* aka dilation */, int groups, bool allow_tf32) {
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                          CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
    // See Note [behavior of cudnnFind and cudnnGet]
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    if (dataType == CUDNN_DATA_HALF) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
    } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
    }
  }
};
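
// Illustrative usage sketch (hypothetical values, not part of the original
// header): dim is the number of spatial dimensions, e.g. 2 for a 2-D
// convolution.
//
//   int pad[2] = {1, 1}, stride[2] = {1, 1}, dilation[2] = {1, 1};
//   ConvolutionDescriptor cdesc;
//   cdesc.set(CUDNN_DATA_FLOAT, /*dim=*/2, pad, stride, dilation,
//             /*groups=*/1, /*allow_tf32=*/false);  // selects CUDNN_FMA_MATH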

struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
    : public Descriptor<
          cudnnSpatialTransformerStruct,
          &cudnnCreateSpatialTransformerDescriptor,
          &cudnnDestroySpatialTransformerDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
  }
};

struct TORCH_CUDA_CPP_API DropoutDescriptor
    : public Descriptor<
          cudnnDropoutStruct,
          &cudnnCreateDropoutDescriptor,
          &cudnnDestroyDropoutDescriptor> {
  at::Tensor state;

  // Initialize a dropout descriptor's RNG state.
  // WARNING: This function is very expensive; avoid calling it if possible!
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    size_t state_size = 0;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);
    state = at::empty({static_cast<int64_t>(state_size)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
  }

  // Restore a dropout descriptor given a dropout probability and existing RNG state.
  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    state = state_;
    void *state_ptr = state.data_ptr();
    size_t state_size = state.size(0);
    // NB: The seed doesn't actually matter, so we give a dummy value
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
  }

  // Restore a dropout descriptor corresponding to no dropout
  void set_no_dropout(cudnnHandle_t handle) {
    // NB: seed doesn't matter when dropout = 0, because no random number
    // initialization actually takes place when there is no dropout.
    // NB: Empirically, cudnnSetDropoutDescriptor is cheap when
    // dropout == 0
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
  }
};
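
// Illustrative usage sketch (hypothetical handle and values, not part of
// the original header): initialize_rng() materializes the RNG state once;
// afterwards the state tensor can be reused cheaply via set().
//
//   DropoutDescriptor dropout_desc;
//   dropout_desc.initialize_rng(
//       handle, /*dropout=*/0.5f, /*seed=*/42,
//       at::TensorOptions().device(at::kCUDA).dtype(at::kByte));
//   ...
//   DropoutDescriptor restored;
//   restored.set(handle, /*dropout=*/0.5f, dropout_desc.state);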

struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
                                             cudnnRNNStruct,
                                             &cudnnCreateRNNDescriptor,
                                             &cudnnDestroyRNNDescriptor> {
  DropoutDescriptor dropout_desc_;
  void set(cudnnHandle_t handle,
#ifdef USE_CUDNN_RNN_V8_API
       int input_size,
       bool packed,
#endif
       int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
       cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
       cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
    dropout_desc_ = std::move(dropout_desc);
#ifndef USE_CUDNN_RNN_V8_API
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
          handle,
          mut_desc(),
          hidden_size,
          num_layers,
          dropout_desc_.desc(),
          input_mode,
          bidirectional,
          mode,
          algo,
          datatype));
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
            handle,
            /*rnnDesc=*/mut_desc(),
            /*recProjSize=*/proj_size,
            /*outProjSize=*/0));
    }
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      }
      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
      }
      else {
        // Technically, as the default it's not necessary to explicitly
        // set this.
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
#else
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    auto math_type = CUDNN_DEFAULT_MATH;
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        math_type = CUDNN_TENSOR_OP_MATH;
      } else if (!allow_tf32) {
        math_type = CUDNN_FMA_MATH;
      }
    }
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v8(
          mut_desc(),
          algo,
          mode,
          CUDNN_RNN_DOUBLE_BIAS,
          bidirectional,
          input_mode,
          input_type,
          datatype,
          math_type,
          input_size,
          hidden_size,
          proj_size ? proj_size : hidden_size,
          num_layers,
          dropout_desc_.desc(),
          packed ? CUDNN_RNN_PADDED_IO_DISABLED : CUDNN_RNN_PADDED_IO_ENABLED));
#endif
  }
};

struct TORCH_CUDA_CPP_API CTCLossDescriptor
    : public Descriptor<
          cudnnCTCLossStruct,
          &cudnnCreateCTCLossDescriptor,
          &cudnnDestroyCTCLossDescriptor> {
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }
  void setEx(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode) {
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
  }
  void set_v8_v9(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode,
      int maxLabelLength) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000
    auto gradModev9 = CUDNN_CTC_ZERO_OOB_GRADIENTS;
    if (gradMode == cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN) {
      gradModev9 = CUDNN_CTC_SKIP_OOB_GRADIENTS;
    }
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptor_v9(mut_desc(), datatype, normMode, gradModev9, maxLabelLength));
#else
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptor_v8(mut_desc(), datatype, normMode, gradMode, maxLabelLength));
#endif
  }
};

struct TORCH_CUDA_CPP_API ActivationDescriptor
    : public Descriptor<
          cudnnActivationStruct,
          &cudnnCreateActivationDescriptor,
          &cudnnDestroyActivationDescriptor> {
  void set(cudnnActivationMode_t mode) {
    AT_ASSERT(
        mode == CUDNN_ACTIVATION_RELU,
        "TODO: support more cuDNN activation modes");
    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
        mut_desc(),
        mode,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        std::numeric_limits<double>::max()));
  }
};

union Constant
{
  float f;
  double d;
  Constant(cudnnDataType_t dataType, double value) {
    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};
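
// Illustrative usage sketch (hypothetical variables, not part of the
// original header): Constant packs the alpha/beta scaling factors that
// cuDNN expects as void pointers whose pointee type (float or double)
// depends on the data type.
//
//   Constant one(dataType, 1.0);
//   Constant zero(dataType, 0.0);
//   AT_CUDNN_CHECK(cudnnConvolutionForward(
//       handle, &one, idesc.desc(), input_ptr, wdesc.desc(), weight_ptr,
//       cdesc.desc(), algo, workspace_ptr, workspace_size,
//       &zero, odesc.desc(), output_ptr));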

} // namespace at::native