#pragma once

#include <limits>
#include <memory>
#include <string>
#include <vector>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8907
#define USE_CUDNN_RNN_V8_API
#endif

namespace at::native {

std::string cudnnTypeToString(cudnnDataType_t dtype);

// TODO: Add constructors for all of the descriptors

inline int dataSize(cudnnDataType_t dataType)
{
  switch (dataType) {
    case CUDNN_DATA_BFLOAT16:
    case CUDNN_DATA_HALF: return 2;
    case CUDNN_DATA_FLOAT: return 4;
    default: return 8;
  }
}

// The stride for a size-1 dimension is not uniquely determined; in
// fact, it can be anything you want, because the fact that the
// tensor is size 1 at this dimension means that you will never actually
// try advancing your pointer by this stride.
//
// However, cuDNN has a much more stringent requirement on strides:
// if you are passing a contiguous input, it better be the case
// that the stride for dim i is the product of the sizes of dims
// i+1 to the end. This stride is indeed uniquely determined. This
// function modifies 'stride' in place so this invariant holds.
template <typename T>
static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);

  if (nhwc) {
    permutation[index++] = 1;
  }
  for (int d = dim-1; d > 1; d--) {
    permutation[index++] = d;
  }
  if (!nhwc) {
    permutation[index++] = 1;
  }
  permutation[index++] = 0;
  for (int d : permutation) {
    if (size[d] == 1) {
      stride[d] = z;
    } else {
      z *= size[d];
    }
  }
}
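
// A minimal worked example (illustrative only, not part of this header's API
// surface): for a contiguous NCHW tensor of sizes {2, 1, 3, 4}, the stride of
// the size-1 channel dimension is rewritten to 3 * 4 = 12, the product of the
// sizes of the dimensions to its right, which is what cuDNN expects:
//
//   int size[4]   = {2, 1, 3, 4};
//   int stride[4] = {12, 999, 4, 1};  // stride[1] is arbitrary for a size-1 dim
//   fixSizeOneDimStride<int>(4, size, stride, /*nhwc=*/false);
//   // stride is now {12, 12, 4, 1}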

template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      AT_CUDNN_CHECK(dtor(x));
    }
  }
};

// A generic class for wrapping cuDNN descriptor types. All you need
// to provide is the underlying type the Descriptor_t points to (usually,
// if it's cudnnTensorDescriptor_t it points to cudnnTensorStruct),
// the constructor, and the destructor. Subclasses are responsible
// for defining a set() function to actually set the descriptor.
//
// Descriptors default construct to a nullptr, and have a descriptor
// initialized the first time you call set() or any other initializing
// function.
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
class TORCH_CUDA_CPP_API Descriptor {
 public:
  // TODO: Figure out why const-correctness doesn't work here

  // Use desc() to access the underlying descriptor pointer in
  // a read-only fashion. Most client code should use this.
  // If the descriptor was never initialized, this will return
  // nullptr.
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }

  // Use mut_desc() to access the underlying descriptor pointer
  // if you intend to modify what it points to (e.g., using
  // cudnnSetFooDescriptor). This will ensure that the descriptor
  // is initialized. Code in this file will use this function.
  T* mut_desc() { init(); return desc_.get(); }

 protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc = nullptr;
      AT_CUDNN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }

 private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
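
// A minimal sketch of what a subclass is expected to look like (illustrative
// only; "Foo" and the cudnn*FooDescriptor names are hypothetical stand-ins,
// and the real subclasses follow below):
//
//   class FooDescriptor : public Descriptor<
//                             cudnnFooStruct,
//                             &cudnnCreateFooDescriptor,
//                             &cudnnDestroyFooDescriptor> {
//    public:
//     void set(...) {
//       AT_CUDNN_CHECK(cudnnSetFooDescriptor(mut_desc(), ...));
//     }
//   };
//
// Client code then calls set(...) once and passes desc() to cuDNN APIs that
// only read the descriptor.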

class TORCH_CUDA_CPP_API RNNDataDescriptor : public Descriptor<
                                                 cudnnRNNDataStruct,
                                                 &cudnnCreateRNNDataDescriptor,
                                                 &cudnnDestroyRNNDataDescriptor> {
 public:
  void set(const at::Tensor &t, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray);
 private:
  void set(cudnnDataType_t dataType, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray) {
    AT_CUDNN_CHECK(cudnnSetRNNDataDescriptor(mut_desc(), dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, nullptr));
  }
};

class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                                cudnnTensorStruct,
                                                &cudnnCreateTensorDescriptor,
                                                &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }

  // Note [CuDNN broadcast padding]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // pad specifies the minimum dimensionality of the tensor descriptor
  // we produce (it doesn't have anything to do with, e.g., convolution
  // padding). If 't' is lower-dimensional than 'pad', the remaining
  // dimensions (on the right) are padded with ones. This doesn't
  // affect the underlying data layout. This is particularly useful for
  // dealing with a peculiarity of the cuDNN API, in which broadcasting
  // is done in two steps: first, the client code is expected to pad out
  // (the dimensions of) input tensors to the same dimensionality as the
  // broadcast target, and then second, cuDNN takes care of actually
  // broadcasting the size-1 dimensions.
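  //
  // For example (an illustrative sketch, not a call site from this codebase):
  // a 1-D bias tensor of shape {C} set with pad = 4 yields a descriptor that
  // describes sizes {C, 1, 1, 1}, which can then be broadcast against a
  // 4-dimensional activation tensor:
  //
  //   TensorDescriptor bdesc;
  //   bdesc.set(bias_1d, /*pad=*/4);  // bias_1d: a hypothetical CUDA tensor of shape {C}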

  void set(const at::Tensor &t, size_t pad = 0);
  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    std::vector<int> strides_copy(stride, stride + dim);
    fixSizeOneDimStride<int>(dim, size, strides_copy.data(), nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, strides_copy.data()));
  }
};

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);

class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                                cudnnFilterStruct,
                                                &cudnnCreateFilterDescriptor,
                                                &cudnnDestroyFilterDescriptor> {
 public:
  void set(const at::Tensor &t, int64_t pad = 0) {
    set(t, at::MemoryFormat::Contiguous, pad);
  }

  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
  }
};

std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);

struct TORCH_CUDA_CPP_API ConvolutionDescriptor
    : public Descriptor<
          cudnnConvolutionStruct,
          &cudnnCreateConvolutionDescriptor,
          &cudnnDestroyConvolutionDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int* upscale /* aka dilation */, int groups, bool allow_tf32) {
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) {
      mathType = CUDNN_DATA_FLOAT;
    }
    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                                   CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
    // See Note [behavior of cudnnFind and cudnnGet]
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    if (dataType == CUDNN_DATA_HALF) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
    } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
    }
  }
};
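
// A hypothetical call sequence showing how the descriptor above is typically
// filled in (an illustrative sketch only; the real call sites live in the
// cuDNN convolution kernels, and the concrete values here are made up):
//
//   ConvolutionDescriptor conv;
//   int pad[2] = {1, 1}, stride[2] = {1, 1}, dilation[2] = {1, 1};
//   conv.set(CUDNN_DATA_FLOAT, /*dim=*/2, pad, stride, dilation,
//            /*groups=*/1, /*allow_tf32=*/true);
//   // conv.desc() can now be passed to cuDNN convolution calls.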

struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
    : public Descriptor<
          cudnnSpatialTransformerStruct,
          &cudnnCreateSpatialTransformerDescriptor,
          &cudnnDestroySpatialTransformerDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
  }
};

struct TORCH_CUDA_CPP_API DropoutDescriptor
    : public Descriptor<
          cudnnDropoutStruct,
          &cudnnCreateDropoutDescriptor,
          &cudnnDestroyDropoutDescriptor> {
  at::Tensor state;

  // Initialize a dropout descriptor's RNG state.
  // WARNING: This function is very expensive; avoid calling it more often
  // than necessary!
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    size_t state_size = 0;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);
    state = at::empty({static_cast<int64_t>(state_size)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
  }

  // Restore a dropout descriptor given a dropout probability and existing RNG state.
  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    state = state_;
    void* state_ptr = state.data_ptr();
    size_t state_size = state.size(0);
    // NB: The seed doesn't actually matter, so we give a dummy value
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
  }

  // Restore a dropout descriptor corresponding to no dropout
  void set_no_dropout(cudnnHandle_t handle) {
    // NB: seed doesn't matter when dropout = 0, because no random number
    // initialization actually takes place when there is no dropout.
    // NB: Empirically, cudnnSetDropoutDescriptor is cheap when
    // dropout == 0
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
  }
};
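
// A typical usage sketch (illustrative only; `handle`, `p`, and `seed` are
// hypothetical values supplied by the caller): build the RNG state once when
// dropout is enabled, and reuse it afterwards via set().
//
//   DropoutDescriptor dropout;
//   if (p > 0) {
//     dropout.initialize_rng(
//         handle, p, seed,
//         at::TensorOptions().device(at::kCUDA).dtype(at::kByte));
//   } else {
//     dropout.set_no_dropout(handle);
//   }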

struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
                                              cudnnRNNStruct,
                                              &cudnnCreateRNNDescriptor,
                                              &cudnnDestroyRNNDescriptor> {
  DropoutDescriptor dropout_desc_;
  void set(cudnnHandle_t handle,
#ifdef USE_CUDNN_RNN_V8_API
           int input_size,
           bool packed,
#endif
           int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
    dropout_desc_ = std::move(dropout_desc);
#ifndef USE_CUDNN_RNN_V8_API
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
        handle,
        mut_desc(),
        hidden_size,
        num_layers,
        dropout_desc_.desc(),
        input_mode,
        bidirectional,
        mode,
        algo,
        datatype));
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
          handle,
          /*rnnDesc=*/mut_desc(),
          /*recProjSize=*/proj_size,
          /*outProjSize=*/0));
    }
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      } else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
      } else {
        // Technically not necessary, since CUDNN_DEFAULT_MATH is already the
        // default, but we set it explicitly for clarity.
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
#else
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    auto math_type = CUDNN_DEFAULT_MATH;
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        math_type = CUDNN_TENSOR_OP_MATH;
      } else if (!allow_tf32) {
        math_type = CUDNN_FMA_MATH;
      }
    }
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v8(
        mut_desc(),
        algo,
        mode,
        CUDNN_RNN_DOUBLE_BIAS,
        bidirectional,
        input_mode,
        input_type,
        datatype,
        math_type,
        input_size,
        hidden_size,
        proj_size ? proj_size : hidden_size,
        num_layers,
        dropout_desc_.desc(),
        packed ? CUDNN_RNN_PADDED_IO_DISABLED : CUDNN_RNN_PADDED_IO_ENABLED));
#endif
  }
};

struct TORCH_CUDA_CPP_API CTCLossDescriptor
    : public Descriptor<
          cudnnCTCLossStruct,
          &cudnnCreateCTCLossDescriptor,
          &cudnnDestroyCTCLossDescriptor> {
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }
  void setEx(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode) {
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
  }
  void set_v8_v9(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode,
      int maxLabelLength) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000
    auto gradModev9 = CUDNN_CTC_ZERO_OOB_GRADIENTS;
    if (gradMode == cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN) {
      gradModev9 = CUDNN_CTC_SKIP_OOB_GRADIENTS;
    }
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptor_v9(mut_desc(), datatype, normMode, gradModev9, maxLabelLength));
#else
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptor_v8(mut_desc(), datatype, normMode, gradMode, maxLabelLength));
#endif
  }
};

struct TORCH_CUDA_CPP_API ActivationDescriptor
    : public Descriptor<
          cudnnActivationStruct,
          &cudnnCreateActivationDescriptor,
          &cudnnDestroyActivationDescriptor> {
  void set(cudnnActivationMode_t mode) {
    AT_ASSERT(
        mode == CUDNN_ACTIVATION_RELU,
        "TODO: support more cuDNN activation modes");
    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
        mut_desc(),
        mode,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        std::numeric_limits<double>::max()));
  }
};

union Constant
{
  float f;
  double d;
  Constant(cudnnDataType_t dataType, double value) {
    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};
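
// Illustrative use (a sketch, not a call site from this header): the alpha and
// beta scaling factors passed to cuDNN operations must match the compute type,
// so Constant lets callers build either a float or a double from one value:
//
//   Constant one(dataType, 1);
//   Constant zero(dataType, 0);
//   // ... pass &one / &zero as the alpha / beta pointers of a cuDNN call.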

} // namespace at::native