1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Config.h>
3 #include <ATen/MatrixRef.h>
4 #include <ATen/TensorUtils.h>
5 #include <ATen/core/Tensor.h>
6 #include <ATen/cuda/CUDAConfig.h>
7 #include <ATen/cuda/CUDAEvent.h>
8 #include <ATen/cuda/Exceptions.h>
9 #include <ATen/native/RNN.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/accumulate.h>
12 #include <c10/util/irange.h>
13 #include <ATen/cuda/CUDAGraphsUtils.cuh>
14
15 #ifndef AT_PER_OPERATOR_HEADERS
16 #include <ATen/Functions.h>
17 #include <ATen/NativeFunctions.h>
18 #else
19 #include <ATen/ops/_cudnn_init_dropout_state.h>
20 #include <ATen/ops/_cudnn_init_dropout_state_native.h>
21 #include <ATen/ops/_cudnn_rnn.h>
22 #include <ATen/ops/_cudnn_rnn_backward_native.h>
23 #include <ATen/ops/_cudnn_rnn_flatten_weight_native.h>
24 #include <ATen/ops/_cudnn_rnn_native.h>
25 #include <ATen/ops/empty.h>
26 #include <ATen/ops/zeros.h>
27 #include <ATen/ops/zeros_like.h>
28 #endif
29
30 #if !AT_CUDNN_ENABLED()
31
32 namespace at {
33 namespace native {
34
35 // See Note [ATen preprocessor philosophy]
36
37 Tensor _cudnn_rnn_flatten_weight(
38 TensorList weight_arr,
39 int64_t weight_stride0,
40 int64_t input_size,
41 int64_t fn_mode,
42 int64_t fn_hidden_size,
43 int64_t fn_proj_size,
44 int64_t fn_num_layers,
45 bool batch_first,
46 bool fn_bidirectional) {
47 AT_ERROR("_cudnn_rnn_flatten_weight: ATen not compiled with cuDNN support");
48 }
49
50 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
51 const Tensor& input_r,
52 TensorList weight,
53 int64_t weight_stride0,
54 const std::optional<Tensor>& weight_buf_r_opt,
55 const Tensor& hx,
56 const std::optional<Tensor>& cx_opt,
57 int64_t fn_mode,
58 int64_t fn_hidden_size,
59 int64_t fn_proj_size,
60 int64_t fn_num_layers,
61 bool batch_first,
62 double fn_dropout,
63 bool fn_train,
64 bool fn_bidirectional,
65 IntArrayRef fn_batch_sizes,
66 const std::optional<Tensor>& fn_dropout_state_opt) {
67 AT_ERROR("_cudnn_rnn: ATen not compiled with cuDNN support");
68 }
69
70 std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
71 const Tensor& input,
72 TensorList weight,
73 int64_t weight_stride0,
74 const Tensor& weight_buf,
75 const Tensor& hx,
76 const std::optional<Tensor>& cx_opt,
77 const Tensor& output,
78 const std::optional<Tensor>& grad_output_r_opt,
79 const std::optional<Tensor>& grad_hy_r_opt,
80 const std::optional<Tensor>& grad_cy_r_opt,
81 int64_t mode,
82 int64_t hidden_size,
83 int64_t proj_size,
84 int64_t num_layers,
85 bool batch_first,
86 double dropout,
87 bool train,
88 bool bidirectional,
89 IntArrayRef batch_sizes,
90 const std::optional<Tensor>& dropout_state_opt,
91 const Tensor& reserve,
92 std::array<bool, 4> output_mask) {
93 AT_ERROR("_cudnn_rnn_backward: ATen not compiled with cuDNN support");
94 }
95
96 Tensor _cudnn_init_dropout_state(
97 double dropout,
98 bool train,
99 int64_t dropout_seed,
100 std::optional<ScalarType> dtype,
101 std::optional<Layout> layout,
102 std::optional<Device> device,
103 std::optional<bool> pin_memory) {
104 // See [Note: hacky wrapper removal for TensorOptions]
105 TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
106 pin_memory);
107
108 AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support");
109 }
110
111 } // namespace native
112 } // namespace at
113
114 #else // AT_CUDNN_ENABLED()
115
116 #include <ATen/native/cudnn/RNNUtils.h>
117
118 namespace at {
119 namespace native {
120
121 namespace {
122 // DropoutDescriptor
123
124 struct DropoutDescriptorParams {
125 bool train;
126 double dropout;
127 Tensor dropout_state;
128 DropoutDescriptorParams() = default;
129 void set(bool train_, double dropout_, Tensor dropout_state_) {
130 train = train_;
131 dropout = dropout_;
132 dropout_state = dropout_state_;
133 }
134 DropoutDescriptor descriptor(cudnnHandle_t handle) const {
135 auto dropout_p = train ? dropout : 0;
136 DropoutDescriptor dropout_desc;
137 if (dropout_p == 0) {
138 dropout_desc.set_no_dropout(handle);
139 } else {
140 dropout_desc.set(handle, dropout_p, dropout_state);
141 }
142 return dropout_desc;
143 }
144 };
145
146 // RNNDescriptor
147
148 struct RNNDescriptorParams {
149 #ifdef USE_CUDNN_RNN_V8_API
150 int64_t input_size;
151 bool packed;
152 #endif
153 int64_t hidden_size;
154 int64_t proj_size;
155 int64_t num_layers;
156 cudnnDirectionMode_t bidirectional;
157 cudnnRNNMode_t mode;
158 cudnnDataType_t datatype;
159 cudnnDataType_t input_datatype;
160 cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
161 cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
162
163 int64_t num_directions() const {
164 return bidirectional ? 2 : 1;
165 }
166
167 void set_mode(int64_t fn_mode) {
168 switch (fn_mode) {
169 case CUDNN_RNN_RELU:
170 mode = CUDNN_RNN_RELU;
171 break;
172 case CUDNN_RNN_TANH:
173 mode = CUDNN_RNN_TANH;
174 break;
175 case CUDNN_LSTM:
176 mode = CUDNN_LSTM;
177 break;
178 case CUDNN_GRU:
179 mode = CUDNN_GRU;
180 break;
181 default: {
182 std::ostringstream oss;
183 oss << "unrecognized cuDNN RNN mode " << fn_mode;
184 AT_ERROR(oss.str());
185 }
186 }
187 }
188
189 void set_bidirectional(bool fn_bidirectional) {
190 bidirectional =
191 fn_bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
192 }
193
194 void set_algo(cudnnRNNAlgo_t algo) {
195 this->algo = algo;
196 }
197
198 #ifndef USE_CUDNN_RNN_V8_API
199 void set(
200 int64_t mode,
201 int64_t hidden_size,
202 int64_t proj_size,
203 int64_t num_layers,
204 bool bidirectional,
205 cudnnDataType_t datatype,
206 cudnnDataType_t input_datatype) {
207 #else
208 void set(
209 int64_t mode,
210 int64_t input_size,
211 bool packed,
212 int64_t hidden_size,
213 int64_t proj_size,
214 int64_t num_layers,
215 bool bidirectional,
216 cudnnDataType_t datatype,
217 cudnnDataType_t input_datatype) {
218 #endif
219 this->set_mode(mode);
220 #ifdef USE_CUDNN_RNN_V8_API
221 this->input_size = input_size;
222 this->packed = packed;
223 #endif
224 this->hidden_size = hidden_size;
225 this->proj_size = proj_size;
226 this->num_layers = num_layers;
227 this->set_bidirectional(bidirectional);
228 this->datatype = datatype;
229 this->input_datatype = input_datatype;
230 }
231
232 RNNDescriptor
233 descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
234 RNNDescriptor rnn_desc;
235 #ifndef USE_CUDNN_RNN_V8_API
236 rnn_desc.set(
237 handle,
238 hidden_size,
239 proj_size,
240 num_layers,
241 std::move(dropout_desc),
242 input_mode,
243 bidirectional,
244 mode,
245 datatype,
246 input_datatype,
247 algo,
248 at::globalContext().allowTF32CuDNN());
249 #else
250 rnn_desc.set(
251 handle,
252 input_size,
253 packed,
254 hidden_size,
255 proj_size,
256 num_layers,
257 std::move(dropout_desc),
258 input_mode,
259 bidirectional,
260 mode,
261 datatype,
262 input_datatype,
263 algo,
264 at::globalContext().allowTF32CuDNN());
265 #endif
266 return rnn_desc;
267 }
268
269 // In some cases, a use of RNNDescriptor does not rely on the
270 // DropoutDescriptor. In this case, we fake up a no-dropout
271 // descriptor to make the RNN descriptor initialization go through.
272 // This is used by _cudnn_rnn_flatten_weight, which needs an
273 // RNNDescriptor for get_parameters(), but does not actually need
274 // a fully initialized dropout descriptor. This lets us avoid
275 // having to pass the dropout state to flatten, which has no business
276 // knowing what the dropout state is.
277 RNNDescriptor descriptor(cudnnHandle_t handle) const {
278 DropoutDescriptor dropout_desc;
279 dropout_desc.set_no_dropout(handle);
280 return descriptor(handle, std::move(dropout_desc));
281 }
282 };
283
284 // TensorDescriptor list
285 #ifndef USE_CUDNN_RNN_V8_API
286 std::vector<TensorDescriptor> rnn_descriptor_sequence(
287 const Tensor& tensor,
288 IntArrayRef batch_sizes) {
289 std::vector<TensorDescriptor> descriptors(batch_sizes.size());
290 size_t i = 0;
291 // To be mutated in the loop
292 auto batch_tensor_size = tensor.sizes().vec();
293 for (auto batch_size : batch_sizes) {
294 batch_tensor_size[0] = batch_size;
295 // NB: cuDNN RNN API does not support 2d descriptors, so we
296 // must pad it out to 3d.
297 descriptors[i].set(
298 getCudnnDataType(tensor), batch_tensor_size, tensor.strides(), 3);
299 i++;
300 }
301 return descriptors;
302 }
303
304 std::vector<TensorDescriptor> rnn_descriptor(const Tensor& tensor, int64_t N) {
305 std::vector<TensorDescriptor> descriptors(N);
306 for (const auto i : c10::irange(N)) {
307 descriptors[i].set(tensor, 5);
308 }
309 return descriptors;
310 }
311 #else
312 auto rnn_descriptor_sequence(
313 const Tensor& tensor,
314 uint32_t batch_size,
315 const IntArrayRef batch_sizes,
316 uint32_t seq_len,
317 uint32_t vector_size) { // packed case
318 RNNDataDescriptor r;
319 std::vector<int> seqLengthArray(batch_size, 1);
320 // cuDNN wants the sequence lengths for a packed batch as if they
321 // were unpacked, e.g., for the
322 // Sequence 1: ABCD
323 // Sequence 2: EF
324 // Sequence 3: G
325 // case below, this would be [4, 2, 1] (has length == mini_batch)
326 // TODO(eqy): There's probably a smarter way to do this than O(SN)
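// For example, with batch_sizes = [3, 2, 1, 1] (the ABCD/EF/G case above),
// the loop below yields seqLengthArray = [4, 2, 1].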
327 for (auto it = batch_sizes.begin(); it != batch_sizes.end(); it++) {
328 // everyone starts at sequence length 1 so we skip an iteration
329 if (it == batch_sizes.begin()) {
330 continue;
331 }
332 for (const auto idx : c10::irange(*it)) {
333 seqLengthArray[idx]++;
334 }
335 }
336 r.set(
337 tensor,
338 CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
339 seq_len,
340 batch_size,
341 vector_size,
342 seqLengthArray.data());
343 return r;
344 }
345
346 auto rnn_descriptor(
347 const Tensor& tensor,
348 uint32_t batch_size,
349 uint32_t seq_len,
350 uint32_t vector_size) {
351 RNNDataDescriptor r;
352 // NB: Looks like even if batch_first is true here we always want
353 // SEQ_MAJOR_UNPACKED, because the input appears to be transposed if it is
354 // batch-major
355 const auto layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
356 std::vector<int32_t> seqLengthArray(batch_size, seq_len);
357 r.set(
358 tensor, layout, seq_len, batch_size, vector_size, seqLengthArray.data());
359 return r;
360 }
361 #endif
362
363 // The best way to understand the meaning of the values stored in
364 // this struct is to consider each of the possible ways our
365 // input can be structured.
366 //
367 // Suppose you want to run RNN on the following variable
368 // length inputs:
369 //
370 // Sequence 1: ABCD
371 // Sequence 2: EF
372 // Sequence 3: G
373 //
374 // (Let _ be padding when we have non-packed representations.)
375 //
376 // # Packed input (batch_sizes is non-empty)
377 //
378 // input_size
379 // +------+ +
380 // | A | |
381 // | E | mini_batch = |
382 // | G | batch_sizes[0] = 3 |
383 // +------+ |
384 // | B | | batch_sizes_sum = 7
385 // | F | batch_sizes[1] = 2 |
386 // +------+ |
387 // | C | batch_sizes[2] = 1 |
388 // +------+ |
389 // | D | batch_sizes[3] = 1 |
390 // +------+ +
391 //
392 // (seq_length = 4)
393 //
394 // input.size() = batch_sizes_sum x input_size
395 //
396 // # Unpacked input (batch_first = false)
397 //
398 // mini_batch = 3
399 // +-------+
400 // | A E G |
401 // | B F _ | seq_length = 4
402 // | C _ _ |
403 // | D _ _ |
404 // +-------+
405 // ... input_size
406 // +-------+
407 //
408 // input.size() = seq_length x mini_batch x input_size
409 //
410 // # Unpacked input (batch_first = true)
411 //
412 // seq_length = 4
413 // +---------+
414 // | A B C D |
415 // | E F _ _ | mini_batch = 3
416 // | G _ _ _ |
417 // +---------+
418 // ... input_size
419 // +---------+
420 //
421 // input.size() = mini_batch x seq_length x input_size
422 //
423 struct TensorDescriptorListParams {
424 IntArrayRef batch_sizes;
425 int64_t seq_length;
426 int64_t mini_batch;
427 // NB: this is not input.size(), which is an IntArrayRef; instead, this is
428 // the size of the inner-most dimension. In NLP applications, this is usually
429 // the size of the embedding. You can also think of this as the size
430 // of the "channel" dimension (at risk of confusing vision researchers :)
431 int64_t input_size;
432 // Only valid when is_input_packed() is true; bogus otherwise
433 int64_t batch_sizes_sum; // == sum(batch_sizes)
434
435 bool is_input_packed() const {
436 return batch_sizes.size() != 0;
437 }
438
439 void set(
440 IntArrayRef input_sizes,
441 IntArrayRef batch_sizes_,
442 bool batch_first) {
443 batch_sizes = batch_sizes_;
444 if (is_input_packed()) {
445 seq_length = batch_sizes.size();
446 mini_batch = batch_sizes[0];
447 // NB: When input is packed, the mini_batch size is NOT the size
448 // of the outer dimension
449 batch_sizes_sum = input_sizes[0];
450 input_size = input_sizes[1];
451 } else {
452 if (batch_first) {
453 seq_length = input_sizes[1];
454 mini_batch = input_sizes[0];
455 } else {
456 seq_length = input_sizes[0];
457 mini_batch = input_sizes[1];
458 }
459 input_size = input_sizes[2];
460 // TODO: Actually, would this make ASAN's job harder catching
461 // an uninitialized access?
462 batch_sizes_sum = -1; // something bogus in case we access it
463 }
464 }
465 #ifndef USE_CUDNN_RNN_V8_API
466 // TODO: check x for consistency with input_size?
467 std::vector<TensorDescriptor> descriptors(Tensor x) const {
468 auto is_input_packed = batch_sizes.size() != 0;
469 if (is_input_packed) {
470 return rnn_descriptor_sequence(x, batch_sizes);
471 } else {
472 return rnn_descriptor(x[0], seq_length);
473 }
474 }
475 #else
476 auto descriptors(Tensor x) const {
477 auto is_input_packed = batch_sizes.size() != 0;
478 if (is_input_packed) {
479 return rnn_descriptor_sequence(
480 x, mini_batch, batch_sizes, seq_length, x.size(-1));
481 } else {
482 return rnn_descriptor(x, mini_batch, seq_length, x.size(-1));
483 }
484 }
485 #endif
486 };
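// A minimal, never-called sketch of how the ABCD/EF/G example above populates
// TensorDescriptorListParams. The input_size of 512 is an arbitrary
// illustrative value and is not assumed anywhere else in this file.
[[maybe_unused]] void _tensor_descriptor_list_params_example() {
  // Packed case: input is batch_sizes_sum x input_size (7 x 512) and
  // batch_sizes = [3, 2, 1, 1].
  constexpr std::array<int64_t, 2> packed_input_sizes = {7, 512};
  constexpr std::array<int64_t, 4> example_batch_sizes = {3, 2, 1, 1};
  TensorDescriptorListParams packed;
  packed.set(packed_input_sizes, example_batch_sizes, /*batch_first=*/false);
  // Now: seq_length == 4, mini_batch == 3, batch_sizes_sum == 7,
  // input_size == 512.

  // Unpacked case (batch_first == false): input is
  // seq_length x mini_batch x input_size (4 x 3 x 512) and batch_sizes is
  // empty.
  constexpr std::array<int64_t, 3> unpacked_input_sizes = {4, 3, 512};
  TensorDescriptorListParams unpacked;
  unpacked.set(unpacked_input_sizes, {}, /*batch_first=*/false);
  // Now: seq_length == 4, mini_batch == 3, input_size == 512
  // (batch_sizes_sum is deliberately left bogus in this branch).
}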
487
488 // Everything together
489
490 struct RNNParams {
491 DropoutDescriptorParams dropout;
492 RNNDescriptorParams rnn;
493 TensorDescriptorListParams tensors;
494 };
495
496 // NB: Doesn't include the weight descriptor
497 struct RNNDescriptors {
498 RNNDescriptor rnn_desc;
499 // NB: this won't actually lay out the tensor descriptor pointers
500 // in the right way, so you'll have to preprocess them
501 #ifndef USE_CUDNN_RNN_V8_API
502 std::vector<TensorDescriptor> x_descs;
503 std::vector<TensorDescriptor> y_descs;
504 #else
505 RNNDataDescriptor x_descs;
506 RNNDataDescriptor y_descs;
507 #endif
508 TensorDescriptor hx_desc;
509 TensorDescriptor hy_desc;
510 TensorDescriptor cx_desc;
511 TensorDescriptor cy_desc;
512
513 RNNDescriptors(
514 const RNNParams& fn,
515 cudnnHandle_t handle,
516 Tensor x,
517 Tensor y,
518 Tensor hx,
519 Tensor cx) {
520 rnn_desc = fn.rnn.descriptor(handle, fn.dropout.descriptor(handle));
521 x_descs = fn.tensors.descriptors(x);
522 y_descs = fn.tensors.descriptors(y);
523 hx_desc.set(hx, 5);
524 hy_desc.set(hx, 5);
525 if (cx.defined()) {
526 cx_desc.set(cx, 5);
527 cy_desc.set(cx, 5);
528 }
529 }
530
531 // TODO: This is annoying, having to put the cudnnTensorDescriptor_t
532 // in a contiguous array...
533 std::vector<cudnnTensorDescriptor_t> get_descs(
534 const std::vector<TensorDescriptor>& descs) {
535 std::vector<cudnnTensorDescriptor_t> r;
536 r.reserve(descs.size());
537 for (auto& desc : descs) {
538 r.emplace_back(desc.desc());
539 }
540 return r;
541 }
542 #ifndef USE_CUDNN_RNN_V8_API
543 std::vector<cudnnTensorDescriptor_t> get_x_descs() {
544 return get_descs(x_descs);
545 }
546
547 std::vector<cudnnTensorDescriptor_t> get_y_descs() {
548 return get_descs(y_descs);
549 }
550 #endif
551 };
552
553 int64_t get_num_weights(
554 cudnnHandle_t handle,
555 const RNNDescriptor& rnn_desc,
556 #ifndef USE_CUDNN_RNN_V8_API
557 const TensorDescriptor& x_desc,
558 #endif
559 cudnnDataType_t datatype) {
560 size_t weight_size;
561 #ifndef USE_CUDNN_RNN_V8_API
562 AT_CUDNN_CHECK(cudnnGetRNNParamsSize(
563 handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype));
564 #else
565 AT_CUDNN_CHECK(
566 cudnnGetRNNWeightSpaceSize(handle, rnn_desc.desc(), &weight_size));
567 #endif
568 auto elem_size = dataSize(datatype);
569 TORCH_INTERNAL_ASSERT(
570 weight_size % elem_size == 0,
571 "cudnnGetRNNParamsSize returned nonsensical weight_size");
572 return weight_size / elem_size;
573 }
574
575 int64_t _num_linear_layers(cudnnRNNMode_t mode) {
576 switch (mode) {
577 case CUDNN_LSTM:
578 return 8;
579 case CUDNN_GRU:
580 return 6;
581 case CUDNN_RNN_RELU:
582 return 2;
583 case CUDNN_RNN_TANH:
584 return 2;
585 default:
586 AT_ERROR("unknown cuDNN RNN mode ", mode);
587 }
588 }
589
590 void add_projection_weights(
591 cudnnHandle_t handle,
592 const RNNDescriptor& rnn_desc,
593 #ifndef USE_CUDNN_RNN_V8_API
594 const TensorDescriptor& x_desc,
595 const FilterDescriptor& w_desc,
596 #endif
597 const Tensor& weight_buf,
598 int64_t layer,
599 std::vector<Tensor>& params) {
600 void* matrix_pointer = nullptr;
601 // assuming it's LSTM which has 8 "linear layers" (i.e. 4 weights and 4
602 // biases)
603 int64_t linear_id = 8;
604 #ifndef USE_CUDNN_RNN_V8_API
605 FilterDescriptor lin_layer_mat_desc;
606 AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams(
607 /*handle=*/handle,
608 /*rnnDesc=*/rnn_desc.desc(),
609 /*layer=*/layer,
610 /*xDesc=*/x_desc.desc(),
611 /*wDesc=*/w_desc.desc(),
612 /*w=*/weight_buf.data_ptr(),
613 /*linLayerID=*/linear_id,
614 /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
615 /*linLayerMat=*/&matrix_pointer));
616 #else
617 TensorDescriptor lin_layer_mat_desc;
618 AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
619 /*handle=*/handle,
620 /*rnnDesc=*/rnn_desc.desc(),
621 /*layer=*/layer,
622 /*wDesc=*/weight_buf.numel() * weight_buf.element_size(),
623 /*w=*/weight_buf.data_ptr(),
624 /*linLayerID=*/linear_id,
625 /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
626 /*linLayerMat=*/&matrix_pointer,
627 nullptr,
628 nullptr));
629 #endif
630
631 cudnnDataType_t data_type;
632 #ifndef USE_CUDNN_RNN_V8_API
633 cudnnTensorFormat_t format;
634 #else
635 int stride_dim_a[5];
636 #endif
637 int nb_dims;
638 constexpr int min_dim = 3;
639 int filter_dim_a[min_dim];
640 #ifndef USE_CUDNN_RNN_V8_API
641 AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor(
642 lin_layer_mat_desc.desc(),
643 min_dim,
644 &data_type,
645 &format,
646 &nb_dims,
647 filter_dim_a));
648 #else
649 AT_CUDNN_CHECK(cudnnGetTensorNdDescriptor(
650 lin_layer_mat_desc.desc(),
651 min_dim,
652 &data_type,
653 &nb_dims,
654 filter_dim_a,
655 stride_dim_a));
656 #endif
657
658 TORCH_INTERNAL_ASSERT(
659 nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim);
660 auto elem_size = dataSize(getCudnnDataType(weight_buf));
661 auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr();
662 TORCH_INTERNAL_ASSERT(
663 offset_bytes % elem_size == 0,
664 "offset_bytes = ",
665 offset_bytes,
666 "; elem_size = ",
667 elem_size);
668 size_t offset = offset_bytes / elem_size;
669
670 int mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
671 // Generate a new parameter tensor which is a view into the weight_buf.
672 std::initializer_list<int64_t> size = {mat_numel, 1};
673 Tensor param = at::empty({0}, weight_buf.options())
674 .set_(weight_buf.storage(), offset, size);
675 params.emplace_back(std::move(param));
676 }
677
678 /*
679 Returns weight and bias tensors for each layer of the RNN. These tensors
680 are views on the underlying weight buffer allocated by CuDNN.
681
682 Note: for LSTM and GRU, which have multiple parameters of each type (4 and 3,
683 respectively), these parameters are concatenated along the first dimension.
684 These parameters are returned in a consistent order by CuDNN:
685 (input, forget, cell, output) for LSTM
686 (reset, input, new) for GRU
687 Args:
688 fn: The RNN function object holding the RNN state
689 handle: a CuDNN handle
690 weight_buf: a 1D tensor containing the CuDNN-allocated weight (or
691 grad_weight) buffer
692 Returns: parameters: [(weight_ih, weight_hh, bias_ih, bias_hh)*], with length
693 equal to num_layers. This is represented as a pair of a vector and an
694 outer-dimension stride (NB: can't return MatrixRef because we need to allocate the underlying tensor)
695 */
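// For example, a 2-layer unidirectional LSTM with biases and no projections
// yields 8 view tensors and stride0 == 4 ([weight_ih, weight_hh, bias_ih,
// bias_hh] per layer); callers wrap the result in a MatrixRef<Tensor> with
// that stride (see copy_weights_to_flat_buf_views and _cudnn_rnn below).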
696 std::pair<std::vector<Tensor>, size_t> // stride0
697 get_parameters(
698 cudnnHandle_t handle,
699 const RNNDescriptorParams& rnn,
700 const RNNDescriptor& rnn_desc,
701 #ifndef USE_CUDNN_RNN_V8_API
702 const TensorDescriptor& x_desc,
703 const FilterDescriptor& w_desc,
704 #endif
705 const Tensor& weight_buf,
706 bool include_bias = true) {
707 #ifndef USE_CUDNN_RNN_V8_API
708 auto cudnn_methods = {
709 cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams};
710 #else
711 auto cudnn_methods = {true, false};
712 #endif
713 std::vector<Tensor> params;
714 int64_t num_linear_layers = _num_linear_layers(rnn.mode);
715 int64_t num_layers = rnn.num_directions() * rnn.num_layers;
716 size_t cur_offset = 0;
717 size_t global_layer_params_count = 0;
718 for (const auto layer : c10::irange(num_layers)) {
719 size_t layer_params_count = 0;
720 for (auto cudnn_method : cudnn_methods) {
721 for (const auto linear_id : c10::irange(num_linear_layers)) {
722 void* matrix_pointer;
723 #ifndef USE_CUDNN_RNN_V8_API
724 FilterDescriptor lin_layer_mat_desc;
725 AT_CUDNN_CHECK(cudnn_method(
726 handle,
727 rnn_desc.desc(),
728 layer,
729 x_desc.desc(),
730 w_desc.desc(),
731 weight_buf.data_ptr(),
732 linear_id,
733 lin_layer_mat_desc.mut_desc(),
734 &matrix_pointer));
735 #else
736 TensorDescriptor lin_layer_mat_desc;
737 for (int stateless = 0; stateless < 100; stateless++) {
738 if (cudnn_method) { // matrix
739 AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
740 handle,
741 rnn_desc.desc(),
742 layer,
743 weight_buf.numel() * weight_buf.element_size(),
744 weight_buf.data_ptr(),
745 linear_id,
746 lin_layer_mat_desc.mut_desc(),
747 &matrix_pointer,
748 nullptr,
749 nullptr));
750 } else { // bias
751 AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
752 handle,
753 rnn_desc.desc(),
754 layer,
755 weight_buf.numel() * weight_buf.element_size(),
756 weight_buf.data_ptr(),
757 linear_id,
758 nullptr,
759 nullptr,
760 lin_layer_mat_desc.mut_desc(),
761 &matrix_pointer));
762 }
763 }
764 #endif
765 cudnnDataType_t data_type;
766 #ifndef USE_CUDNN_RNN_V8_API
767 cudnnTensorFormat_t format;
768 #else
769 int stride_dim_a[5];
770 #endif
771 int nb_dims;
772 constexpr int min_dim = 3;
773 int filter_dim_a[min_dim];
774 #ifndef USE_CUDNN_RNN_V8_API
775 AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor(
776 lin_layer_mat_desc.desc(),
777 min_dim,
778 &data_type,
779 &format,
780 &nb_dims,
781 filter_dim_a));
782 #else
783 AT_CUDNN_CHECK(cudnnGetTensorNdDescriptor(
784 lin_layer_mat_desc.desc(),
785 min_dim,
786 &data_type,
787 &nb_dims,
788 filter_dim_a,
789 stride_dim_a));
790 #endif
791
792 TORCH_INTERNAL_ASSERT(
793 nb_dims <= min_dim,
794 "nb_dims = ",
795 nb_dims,
796 "; min_dim = ",
797 min_dim);
798 auto elem_size = dataSize(getCudnnDataType(weight_buf));
799 auto offset_bytes =
800 (char*)matrix_pointer - (char*)weight_buf.data_ptr();
801 TORCH_INTERNAL_ASSERT(
802 offset_bytes % elem_size == 0,
803 "offset_bytes = ",
804 offset_bytes,
805 "; elem_size = ",
806 elem_size);
807 size_t offset = offset_bytes / elem_size;
808 // for all the RNN types provided by CUDNN, all the ih weights
809 // are the same size and are allocated in a contiguous chunk
810 // (same for the hh weights, and the ih and hh biases).
811 // Since we're storing all the weights in a single tensor anyway,
812 // might as well merge the CUDNN ones into a single tensor as well
813 int mat_numel =
814 c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
815 if (linear_id == 0 || linear_id == num_linear_layers / 2) {
816 // We could also exclude bias params by restricting cudnn_methods to
817 // just { cudnnGetRNNLinLayerMatrixParams } at the very top. However,
818 // to do so would throw off the cur_offset accounting, which is currently
819 // a strict and informative check that all params are laid out the way
820 // we think they are. If include_bias is false, I'd rather keep full
821 // cur_offset checks rather than save some CPU overhead by skipping
822 // the cudnn_method = cudnnGetRNNLinLayerBiasParams iteration.
823 #ifndef USE_CUDNN_RNN_V8_API
824 if (include_bias || cudnn_method != cudnnGetRNNLinLayerBiasParams) {
825 #else
826 if (include_bias || cudnn_method) {
827 #endif
828 // Generate a new parameter tensor which is a view into the
829 // weight_buf.
830 std::initializer_list<int64_t> size = {
831 mat_numel * num_linear_layers / 2, 1};
832 Tensor param = at::empty({0}, weight_buf.options())
833 .set_(weight_buf.storage(), offset, size);
834 params.emplace_back(std::move(param));
835 layer_params_count++;
836 }
837 } else {
838 TORCH_INTERNAL_ASSERT(
839 cur_offset == offset,
840 "cur_offset = ",
841 cur_offset,
842 "; offset = ",
843 offset);
844 }
845 cur_offset = offset + mat_numel;
846 }
847 } // for cudnn_method
848 if (rnn.proj_size != 0) {
849 #ifndef USE_CUDNN_RNN_V8_API
850 add_projection_weights(
851 handle, rnn_desc, x_desc, w_desc, weight_buf, layer, params);
852 #else
853 add_projection_weights(handle, rnn_desc, weight_buf, layer, params);
854 #endif
855 layer_params_count++;
856 }
857
858 if (layer == 0) {
859 global_layer_params_count = layer_params_count;
860 } else {
861 TORCH_INTERNAL_ASSERT(
862 global_layer_params_count == layer_params_count,
863 "global_layer_params_count = ",
864 global_layer_params_count,
865 "; layer_params_count = ",
866 layer_params_count);
867 }
868 } // for layer
869 return std::make_pair(params, global_layer_params_count);
870 }
871
872 // This is a lightweight version of the method above used to quickly get the
873 // expected parameter offsets.
874 std::vector<void*> get_expected_data_ptrs(
875 const Tensor& weight_buf,
876 cudnnHandle_t handle,
877 const RNNDescriptorParams& rnn,
878 const RNNDescriptor& rnn_desc,
879 const TensorDescriptor& x_desc,
880 cudnnDataType_t datatype) {
881 #ifndef USE_CUDNN_RNN_V8_API
882 FilterDescriptor w_desc;
883 w_desc.set(weight_buf, 3);
884 #endif
885
886 int64_t num_linear_layers = _num_linear_layers(rnn.mode);
887 int64_t num_dir_layers = rnn.num_directions() * rnn.num_layers;
888 #ifndef USE_CUDNN_RNN_V8_API
889 const auto cudnn_methods = {
890 cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams};
891 #else
892 const auto cudnn_methods = {true, false};
893 #endif
894 std::vector<void*> data_ptrs;
895 if (rnn.proj_size != 0) {
896 data_ptrs.reserve(num_dir_layers * (2 * 2 + 1));
897 } else {
898 data_ptrs.reserve(num_dir_layers * 2 * 2);
899 }
900 for (const auto layer : c10::irange(num_dir_layers)) {
901 for (auto cudnn_method : cudnn_methods) {
902 // This API returns a separate pointer for weight of every gate,
903 // but we represent them as a single tensor, so we're only interested
904 // in a very limited subset of possible values.
905 const std::array<int64_t, 2> linear_offsets = {0, num_linear_layers / 2};
906 for (int64_t linear_id : linear_offsets) {
907 void* matrix_pointer;
908 #ifndef USE_CUDNN_RNN_V8_API
909 FilterDescriptor lin_layer_mat_desc;
910 AT_CUDNN_CHECK(cudnn_method(
911 handle,
912 rnn_desc.desc(),
913 layer,
914 x_desc.desc(),
915 w_desc.desc(),
916 weight_buf.data_ptr(),
917 linear_id,
918 lin_layer_mat_desc.mut_desc(),
919 &matrix_pointer));
920 #else
921 TensorDescriptor lin_layer_mat_desc;
922 if (cudnn_method) { // matrix
923 AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
924 handle,
925 rnn_desc.desc(),
926 layer,
927 weight_buf.numel() * weight_buf.element_size(),
928 weight_buf.data_ptr(),
929 linear_id,
930 lin_layer_mat_desc.mut_desc(),
931 &matrix_pointer,
932 nullptr,
933 nullptr));
934 } else { // bias
935 AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
936 handle,
937 rnn_desc.desc(),
938 layer,
939 weight_buf.numel() * weight_buf.element_size(),
940 weight_buf.data_ptr(),
941 linear_id,
942 nullptr,
943 nullptr,
944 lin_layer_mat_desc.mut_desc(),
945 &matrix_pointer));
946 }
947 #endif
948 data_ptrs.push_back(matrix_pointer);
949 }
950 }
951 if (rnn.proj_size != 0) {
952 // assuming it's LSTM which has 8 "linear layers" (i.e. 4 weights and 4
953 // biases)
954 int64_t linear_id = 8;
955 void* matrix_pointer;
956 #ifndef USE_CUDNN_RNN_V8_API
957 FilterDescriptor lin_layer_mat_desc;
958 AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams(
959 handle,
960 rnn_desc.desc(),
961 layer,
962 x_desc.desc(),
963 w_desc.desc(),
964 weight_buf.data_ptr(),
965 linear_id,
966 lin_layer_mat_desc.mut_desc(),
967 &matrix_pointer));
968 #else
969 TensorDescriptor lin_layer_mat_desc;
970
971 AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
972 handle,
973 rnn_desc.desc(),
974 layer,
975 weight_buf.numel() * weight_buf.element_size(),
976 weight_buf.data_ptr(),
977 linear_id,
978 lin_layer_mat_desc.mut_desc(),
979 &matrix_pointer,
980 nullptr,
981 nullptr));
982 #endif
983 data_ptrs.push_back(matrix_pointer);
984 }
985 }
986 return data_ptrs;
987 }
988
989 void _viewOrCopyOneParam(
990 const Tensor& param_from,
991 const Tensor& param_to,
992 bool copy,
993 bool allow_type_change = false) {
994 // if copying, allow_type_change may be true or false.
995 // if viewing, allow_type_change must be false.
996 TORCH_INTERNAL_ASSERT(
997 copy || !allow_type_change, "if viewing, type change is not allowed.");
998 TORCH_INTERNAL_ASSERT(
999 allow_type_change || (param_from.scalar_type() == param_to.scalar_type()),
1000 "parameter types mismatch");
1001 if (copy) {
1002 param_to.copy_(param_from.view_as(param_to));
1003 } else {
1004 param_from.resize_as_(param_to);
1005 }
1006 }
1007
1008 void _viewOrCopyParams(
1009 MatrixRef<Tensor> params_from,
1010 MatrixRef<Tensor> params_to,
1011 bool copy,
1012 bool allow_type_change = false) {
1013 TORCH_INTERNAL_ASSERT(
1014 params_from.size(0) == params_to.size(0), "number of layers mismatch");
1015 for (const auto i : c10::irange(params_from.size(0))) {
1016 auto layer_params_from = params_from[i];
1017 auto layer_params_to = params_to[i];
1018 // NOTE: these lists have all weights before all biases, so if the layer
1019 // doesn't use biases, iteration will terminate once layer_params_from ends
1020 // and ignore them.
1021
1022 // NOTE: there is an exception from the above statement. If LSTMs with
1023 // projections are used, weights layout will be w_ih, w_hh, b_ih, b_hh,
1024 // w_hr. So need to handle no-bias case specially, because will need to copy
1025 // 0->0, 1->1, 2->4. This case can be uniquely identified by checking if
1026 // number of defined parameters for each layer is 3.
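// For example, copying no-bias projection weights [w_ih, w_hh, w_hr]
// (3 params) into a full layout [w_ih, w_hh, b_ih, b_hh, w_hr] (5 params)
// maps source indices 0, 1, 2 to destination indices 0, 1, 4, which is
// exactly what the two branches below do.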
1027 if (layer_params_from.size() == 3 && layer_params_to.size() != 3) {
1028 _viewOrCopyOneParam(
1029 layer_params_from[0], layer_params_to[0], copy, allow_type_change);
1030 _viewOrCopyOneParam(
1031 layer_params_from[1], layer_params_to[1], copy, allow_type_change);
1032 _viewOrCopyOneParam(
1033 layer_params_from[2], layer_params_to[4], copy, allow_type_change);
1034 continue;
1035 }
1036 if (layer_params_to.size() == 3 && layer_params_from.size() != 3) {
1037 _viewOrCopyOneParam(
1038 layer_params_from[0], layer_params_to[0], copy, allow_type_change);
1039 _viewOrCopyOneParam(
1040 layer_params_from[1], layer_params_to[1], copy, allow_type_change);
1041 _viewOrCopyOneParam(
1042 layer_params_from[4], layer_params_to[2], copy, allow_type_change);
1043 continue;
1044 }
1045 for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
1046 a != layer_params_from.end() && b != layer_params_to.end();
1047 ++a, ++b) {
1048 _viewOrCopyOneParam(*a, *b, copy, allow_type_change);
1049 }
1050 }
1051 }
1052
1053 void _copyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
1054 _viewOrCopyParams(params_from, params_to, true);
1055 }
1056
1057 void _viewParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
1058 _viewOrCopyParams(params_from, params_to, false);
1059 }
1060
1061 std::vector<int64_t> _input_size(const TensorDescriptorListParams& tensors) {
1062 if (tensors.is_input_packed()) {
1063 return {tensors.batch_sizes_sum, tensors.input_size};
1064 } else {
1065 return {tensors.seq_length, tensors.mini_batch, tensors.input_size};
1066 }
1067 }
1068
1069 std::vector<int64_t> _hidden_size(
1070 const RNNDescriptorParams& rnn,
1071 const TensorDescriptorListParams& tensors) {
1072 if (rnn.proj_size != 0) {
1073 return {
1074 rnn.num_layers * rnn.num_directions(),
1075 tensors.mini_batch,
1076 rnn.proj_size};
1077 } else {
1078 return {
1079 rnn.num_layers * rnn.num_directions(),
1080 tensors.mini_batch,
1081 rnn.hidden_size};
1082 }
1083 }
1084
1085 std::vector<int64_t> _cell_size(
1086 const RNNDescriptorParams& rnn,
1087 const TensorDescriptorListParams& tensors) {
1088 return {
1089 rnn.num_layers * rnn.num_directions(),
1090 tensors.mini_batch,
1091 rnn.hidden_size};
1092 }
1093
1094 std::vector<int64_t> _output_size(
1095 const RNNDescriptorParams& rnn,
1096 const TensorDescriptorListParams& tensors) {
1097 auto out_size = rnn.hidden_size;
1098 if (rnn.proj_size != 0) {
1099 out_size = rnn.proj_size;
1100 }
1101 if (tensors.is_input_packed()) {
1102 return {tensors.batch_sizes_sum, out_size * rnn.num_directions()};
1103 } else {
1104 return {
1105 tensors.seq_length,
1106 tensors.mini_batch,
1107 out_size * rnn.num_directions()};
1108 }
1109 }
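// A minimal, never-called sketch of the three size helpers above for an
// assumed 2-layer bidirectional LSTM with hidden_size = 1024, proj_size = 0,
// mini_batch = 3 and seq_length = 4 (unpacked input); the concrete numbers
// are illustrative only.
[[maybe_unused]] void _rnn_size_helpers_example() {
  RNNDescriptorParams rnn;
  rnn.set_mode(CUDNN_LSTM);
  rnn.set_bidirectional(true);
  rnn.num_layers = 2;
  rnn.hidden_size = 1024;
  rnn.proj_size = 0;
  TensorDescriptorListParams tensors; // default batch_sizes is empty, i.e. unpacked
  tensors.seq_length = 4;
  tensors.mini_batch = 3;
  tensors.input_size = 512;
  // {num_layers * num_directions, mini_batch, hidden_size} == {4, 3, 1024}
  TORCH_INTERNAL_ASSERT(_hidden_size(rnn, tensors) == std::vector<int64_t>({4, 3, 1024}));
  // Same shape as the hidden state when proj_size == 0.
  TORCH_INTERNAL_ASSERT(_cell_size(rnn, tensors) == std::vector<int64_t>({4, 3, 1024}));
  // {seq_length, mini_batch, hidden_size * num_directions} == {4, 3, 2048}
  TORCH_INTERNAL_ASSERT(_output_size(rnn, tensors) == std::vector<int64_t>({4, 3, 2048}));
}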
1110
1111 inline bool use_persist_common_heuristics(
1112 const RNNDescriptorParams& rnn,
1113 const TensorDescriptorListParams& tensors) {
1114 return rnn.num_layers == 1 && rnn.hidden_size <= 1024 &&
1115 rnn.num_directions() == 1 && rnn.hidden_size % 128 == 0 &&
1116 tensors.input_size % 128 == 0;
1117 }
1118
1119 inline bool use_persist_device_heuristics(
1120 const RNNDescriptorParams& rnn,
1121 const TensorDescriptorListParams& tensors) {
1122 auto bsize = tensors.mini_batch;
1123 cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
1124 if (prop->major == 7) {
1125 if (prop->minor == 5) {
1126 // Excludes Turing from using persistent rnn.
1127 return false;
1128 } else {
1129 // technically, batch size should be a multiple of 8, but there are quite a
1130 // few multiple-of-8 batch sizes that give bad perf, so weed them out
1131 return ((bsize % 16 == 0 && bsize != 80 && bsize != 112) || bsize == 8) &&
1132 ((tensors.seq_length >= 40 && bsize <= 128) ||
1133 (tensors.seq_length >= 20 && bsize <= 96) ||
1134 (tensors.seq_length >= 10 && bsize <= 32));
1135 }
1136 } else if (prop->major >= 8 && prop->multiProcessorCount >= 98) {
1137 // SM count check excludes A30 (similar issue to A40)
1138 if (prop->minor == 6) {
1139 // Excludes sm_86 GPU devices from using persistent rnn.
1140 // This is because there are some edge cases that will throw exceptions
1141 // with cudnn 8.0.5 on Nvidia A40 GPU.
1142 return false;
1143 }
1144 // Based on tests by Vasily Volkov and xwang233. Vasily only tried bsize <=
1145 // 128, so conservatively enable persistence for bsize <= 128 only.
1146 // TODO: Run more tests for bsize > 128.
1147 if (rnn.mode == CUDNN_GRU) {
1148 // Persistent GRU performance is flakier than other RNN types. Exclude
1149 // them for now.
1150 // TODO: Write a more refined GRU heuristic.
1151 return false;
1152 } else if (rnn.mode == CUDNN_LSTM) {
1153 // Persistent LSTMs are comparable to or better than non-persistent for
1154 // bsize <= 128.
1155 return (bsize % 8 == 0) && (bsize <= 128);
1156 } else {
1157 // Persistent RNN_RELU and TANH show poor performance when bsize >= 96 AND
1158 // hidden size >= 896.
1159 return (bsize % 8 == 0) && (bsize <= 128) &&
1160 (bsize < 96 || rnn.hidden_size < 896);
1161 }
1162 } else {
1163 return false;
1164 }
1165 }
1166
1167 inline bool use_rnn_persist_small_h(
1168 const RNNDescriptorParams& rnn,
1169 const TensorDescriptorListParams& tensors,
1170 bool forward) {
1171 cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
1172 if (prop->major < 6)
1173 return false;
1174
1175 if (forward) {
1176 if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) {
1177 return rnn.hidden_size <= 384;
1178 }
1179 if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) {
1180 return rnn.hidden_size <= 192;
1181 }
1182 } else /* backward */ {
1183 if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) {
1184 return rnn.hidden_size <= 256;
1185 }
1186 if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) {
1187 return rnn.hidden_size <= 128;
1188 }
1189 }
1190
1191 return false;
1192 }
1193
1194 cudnnRNNAlgo_t get_algo(
1195 const RNNDescriptorParams& rnn,
1196 const TensorDescriptorListParams& tensors,
1197 const Tensor input,
1198 bool forward) {
1199 // LSTM with projections only works with standard algorithm
1200 if (rnn.proj_size != 0) {
1201 return CUDNN_RNN_ALGO_STANDARD;
1202 }
1203
1204 // Persistent algos typically don't work for packed inputs with sequence
1205 // lengths that vary across batch elements, and will return
1206 // CUDNN_STATUS_NOT_SUPPORTED if attempted. See
1207 // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#features-of-rnn-functions
1208 if (!tensors.is_input_packed()) {
1209 auto cudnnDataType = getCudnnDataType(input);
1210 if (cudnnDataType != CUDNN_DATA_DOUBLE) {
1211 if (use_rnn_persist_small_h(rnn, tensors, forward)) {
1212 return CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H;
1213 }
1214 }
1215 if (cudnnDataType == CUDNN_DATA_HALF) {
1216 if (use_persist_common_heuristics(rnn, tensors) &&
1217 use_persist_device_heuristics(rnn, tensors)) {
1218 return CUDNN_RNN_ALGO_PERSIST_STATIC;
1219 }
1220 }
1221 }
1222
1223 return CUDNN_RNN_ALGO_STANDARD;
1224 }
1225
1226 cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
1227 if (dtype == CUDNN_DATA_HALF) {
1228 return CUDNN_DATA_FLOAT;
1229 }
1230 return dtype;
1231 }
1232
1233 } // anonymous namespace
1234
1235 // Utilities exposed in RNNUtils.h
1236 namespace cudnn_rnn {
1237
1238 TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>>
1239 copy_weights_to_flat_buf_views(
1240 TensorList weight_arr,
1241 int64_t weight_stride0,
1242 int64_t input_size,
1243 int64_t mode,
1244 int64_t hidden_size,
1245 int64_t proj_size,
1246 int64_t num_layers,
1247 bool batch_first,
1248 bool bidirectional,
1249 const cudnnDataType_t flat_buf_datatype,
1250 const TensorOptions& flat_buf_options,
1251 bool set_orig_weights_to_flat_buf,
1252 bool allow_type_change /*=false*/,
1253 bool include_bias /*=true*/) {
1254 // flat_buf_datatype is accepted as a separate argument (rather than extracted
1255 // from flat_buf_options) because to extract flat_buf_datatype from
1256 // flat_buf_options, we'd need to say auto flat_buf_datatype =
1257 // getCudnnDataTypeFromScalarType(typeMetaToScalarType(options.dtype()));
1258 // typeMetaToScalarType is a surprisingly nontrivial function. We should
1259 // avoid it if we can.
1260 TORCH_CHECK(
1261 weight_arr.size() > 0,
1262 "copy_weights_to_flat_buf_views: cannot flatten empty weight list");
1263
1264 RNNDescriptorParams rnn;
1265 rnn.set(
1266 mode,
1267 #ifdef USE_CUDNN_RNN_V8_API
1268 input_size,
1269 false, // eqy: bogus as we do not know if the input is packed here
1270 // but it should not affect the weights, which are what we are
1271 // interested in)
1272 #endif
1273 hidden_size,
1274 proj_size,
1275 num_layers,
1276 bidirectional,
1277 promote_rnn_math_type(flat_buf_datatype),
1278 flat_buf_datatype);
1279
1280 auto handle = getCudnnHandle();
1281 RNNDescriptor rnn_desc = rnn.descriptor(handle);
1282
1283 TensorGeometry x_geom({1, input_size});
1284 TensorDescriptor x_desc;
1285 // Why do we pad to 5 dims here (and elsewhere)?
1286 // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNForwardTraining
1287 // expects descriptors padded to 3 dimensions.
1288 x_desc.set(flat_buf_datatype, x_geom.sizes(), x_geom.strides(), 5);
1289
1290 auto num_weights =
1291 #ifndef USE_CUDNN_RNN_V8_API
1292 get_num_weights(handle, rnn_desc, x_desc, flat_buf_datatype);
1293 #else
1294 get_num_weights(handle, rnn_desc, flat_buf_datatype);
1295 #endif
1296 auto weight_buf = at::zeros(num_weights, flat_buf_options);
1297
1298 #ifndef USE_CUDNN_RNN_V8_API
1299 FilterDescriptor w_desc;
1300 w_desc.set(weight_buf, 3);
1301 #endif
1302
1303 // Slice off views into weight_buf
1304 auto [params_arr, params_stride0] = get_parameters(
1305 #ifndef USE_CUDNN_RNN_V8_API
1306 handle, rnn, rnn_desc, x_desc, w_desc, weight_buf, include_bias);
1307 #else
1308 handle, rnn, rnn_desc, weight_buf, include_bias);
1309 #endif
1310 MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
1311 params{params_arr, params_stride0};
1312
1313 // Copy weights
1314 _viewOrCopyParams(weight, params, /*copy=*/true, allow_type_change);
1315 if (set_orig_weights_to_flat_buf) {
1316 // Update the storage
1317 for (const auto i : c10::irange(weight.size(0))) {
1318 // There is a special case for LSTM with projections and no bias,
1319 // where weight copy is done in 0->0, 1->1, 2->4 layout
1320 if (weight[i].size() == 3 && params[i].size() == 5) {
1321 weight[i][0].set_(params[i][0].view_as(weight[i][0]));
1322 weight[i][1].set_(params[i][1].view_as(weight[i][1]));
1323 weight[i][2].set_(params[i][4].view_as(weight[i][2]));
1324 } else {
1325 for (auto orig_param_it = weight[i].begin(),
1326 new_param_it = params[i].begin();
1327 orig_param_it != weight[i].end() &&
1328 new_param_it != params[i].end();
1329 orig_param_it++, new_param_it++) {
1330 auto orig_param = *orig_param_it, new_param = *new_param_it;
1331 orig_param.set_(new_param.view_as(orig_param));
1332 }
1333 }
1334 }
1335 }
1336
1337 return std::make_tuple(weight_buf, params_arr);
1338 }
1339
1340 } // namespace cudnn_rnn
1341
1342 using namespace cudnn_rnn;
1343
1344 // NB: does inplace update into TensorList
1345 // It would be a relatively simple matter to refactor this into multiple
1346 // functions, only one of which does an inplace update, but we leave this
1347 // for future work
1348 Tensor _cudnn_rnn_flatten_weight(
1349 TensorList weight_arr,
1350 int64_t weight_stride0,
1351 int64_t input_size,
1352 int64_t fn_mode,
1353 int64_t fn_hidden_size,
1354 int64_t fn_proj_size,
1355 int64_t fn_num_layers,
1356 bool batch_first,
1357 bool fn_bidirectional) {
1358 // returns flat weight_buf
1359 return std::get<0>(copy_weights_to_flat_buf_views(
1360 weight_arr,
1361 weight_stride0,
1362 input_size,
1363 fn_mode,
1364 fn_hidden_size,
1365 fn_proj_size,
1366 fn_num_layers,
1367 batch_first,
1368 fn_bidirectional,
1369 /*flat_buf_datatype=*/getCudnnDataType(weight_arr[0]),
1370 /*flat_buf_options=*/weight_arr[0].options(),
1371 /*set_orig_weights_to_flat_buf=*/true));
1372 }
1373
1374 const char* WEIGHT_FORMAT_WARN =
1375 "RNN module weights are not part of single contiguous "
1376 "chunk of memory. This means they need to be compacted "
1377 "at every call, possibly greatly increasing memory usage. "
1378 "To compact weights again call flatten_parameters().";
1379
1380 // NB: when fn_batch_sizes is empty, that means no batch sizes were specified
1381 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
1382 const Tensor& input_r,
1383 TensorList weight,
1384 int64_t weight_stride0,
1385 const std::optional<Tensor>& weight_buf_r_opt,
1386 const Tensor& hx,
1387 const std::optional<Tensor>& cx_opt,
1388 int64_t fn_mode,
1389 int64_t fn_hidden_size,
1390 int64_t fn_proj_size,
1391 int64_t fn_num_layers,
1392 bool batch_first,
1393 double fn_dropout,
1394 bool fn_train,
1395 bool fn_bidirectional,
1396 IntArrayRef fn_batch_sizes,
1397 const std::optional<Tensor>& fn_dropout_state_opt) {
1398 // See [Note: hacky wrapper removal for optional tensor]
1399 c10::MaybeOwned<Tensor> weight_buf_r_maybe_owned =
1400 at::borrow_from_optional_tensor(weight_buf_r_opt);
1401 const Tensor& weight_buf_r = *weight_buf_r_maybe_owned;
1402 const Tensor& cx = c10::value_or_else(cx_opt, [] { return Tensor(); });
1403 const Tensor& fn_dropout_state =
1404 c10::value_or_else(fn_dropout_state_opt, [] { return Tensor(); });
1405
1406 check_attributes(input_r, weight, {hx, cx}, /*check_dtype=*/true);
1407 auto input = input_r;
1408 auto weight_buf = weight_buf_r;
1409 if (!weight_buf.defined()) {
1410 TORCH_WARN(WEIGHT_FORMAT_WARN);
1411 }
1412 if (fn_dropout_state.defined()) {
1413 auto input_arg = TensorArg(input, "input", 1);
1414 auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15);
1415 checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
1416 }
1417 RNNParams fn;
1418 auto datatype = getCudnnDataType(input);
1419 #ifndef USE_CUDNN_RNN_V8_API
1420 fn.rnn.set(
1421 fn_mode,
1422 fn_hidden_size,
1423 fn_proj_size,
1424 fn_num_layers,
1425 fn_bidirectional,
1426 promote_rnn_math_type(datatype),
1427 datatype);
1428 #else
1429 auto input_size = input_r.size(-1);
1430 auto packed = fn_batch_sizes.size() != 0;
1431 fn.rnn.set(
1432 fn_mode,
1433 input_size,
1434 packed,
1435 fn_hidden_size,
1436 fn_proj_size,
1437 fn_num_layers,
1438 fn_bidirectional,
1439 promote_rnn_math_type(datatype),
1440 datatype);
1441 #endif
1442 fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
1443 fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
1444
1445 // TODO: Set device to input
1446
1447 if (fn.rnn.mode != CUDNN_LSTM) {
1448 TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
1449 }
1450
1451 // TODO: can batch_first be a wrapper around this function?
1452 auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
1453 if (batch_first && !is_input_packed) {
1454 input = input.transpose(0, 1);
1455 }
1456
1457 auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
1458 auto cell_size = _cell_size(fn.rnn, fn.tensors);
1459 auto output_size = _output_size(fn.rnn, fn.tensors);
1460
1461 TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
1462 TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");
1463
1464 auto x = input.contiguous();
1465 auto output = at::empty(output_size, input.options());
1466 auto hy = at::empty(hidden_size, hx.options());
1467 Tensor cy;
1468 if (cx.defined()) {
1469 cy = at::empty(cell_size, cx.options());
1470 } else {
1471 cy = at::empty(
1472 {0}, hx.options()); // NB: Not allowed to return undefined tensors
1473 }
1474 auto y = output;
1475
1476 auto handle = getCudnnHandle();
1477 cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, true);
1478 fn.rnn.set_algo(algo);
1479 RNNDescriptors descs(fn, handle, x, y, hx, cx);
1480
1481 #ifndef USE_CUDNN_RNN_V8_API
1482 FilterDescriptor w_desc;
1483 #endif
1484 if (!weight_buf.defined()) {
1485 #ifndef USE_CUDNN_RNN_V8_API
1486 auto num_weights =
1487 get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
1488 #else
1489 auto num_weights = get_num_weights(handle, descs.rnn_desc, datatype);
1490 #endif
1491 weight_buf = at::empty(num_weights, x.options());
1492 #ifndef USE_CUDNN_RNN_V8_API
1493 w_desc.set(weight_buf, 3);
1494 #endif
1495 weight_buf.zero_();
1496 #ifndef USE_CUDNN_RNN_V8_API
1497 auto [params, params_stride0] = get_parameters(
1498 handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf);
1499 #else
1500 auto [params, params_stride0] =
1501 get_parameters(handle, fn.rnn, descs.rnn_desc, weight_buf);
1502 #endif
1503 _copyParams(
1504 MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
1505 MatrixRef<Tensor>{params, params_stride0});
1506 } else {
1507 #ifndef USE_CUDNN_RNN_V8_API
1508 w_desc.set(weight_buf, 3);
1509 #endif
1510 }
1511
1512 TORCH_CHECK(
1513 !cx.defined() || cx.sizes().equals(cell_size),
1514 "Expected cell size ",
1515 IntArrayRef{cell_size},
1516 ", got ",
1517 cx.sizes());
1518 size_t workspace_size;
1519 #ifndef USE_CUDNN_RNN_V8_API
1520 auto x_descs_arr = descs.get_x_descs();
1521 auto y_descs_arr = descs.get_y_descs();
1522 #else
1523 auto& x_descs_arr = descs.x_descs;
1524 auto& y_descs_arr = descs.y_descs;
1525 #endif
1526 #ifndef USE_CUDNN_RNN_V8_API
1527 AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
1528 handle,
1529 descs.rnn_desc.desc(),
1530 fn.tensors.seq_length,
1531 x_descs_arr.data(),
1532 &workspace_size));
1533 #endif
1534 Tensor workspace;
1535 Tensor reserve;
1536 // NB: Previously, the test was for fn.requires_grad, but we don't have
1537 // this information. Use 'train' as a proxy.
1538 if (fn_train) {
1539 size_t reserve_size;
1540 #ifndef USE_CUDNN_RNN_V8_API
1541 AT_CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(
1542 handle,
1543 descs.rnn_desc.desc(),
1544 fn.tensors.seq_length,
1545 x_descs_arr.data(),
1546 &reserve_size));
1547 #else
1548 AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
1549 handle,
1550 descs.rnn_desc.desc(),
1551 CUDNN_FWD_MODE_TRAINING,
1552 x_descs_arr.desc(),
1553 &workspace_size,
1554 &reserve_size));
1555 #endif
1556 workspace = at::empty(workspace_size, input.options().dtype(kByte));
1557 reserve = at::empty(reserve_size, input.options().dtype(kByte));
1558 #ifndef USE_CUDNN_RNN_V8_API
1559 AT_CUDNN_CHECK(cudnnRNNForwardTraining(
1560 handle,
1561 descs.rnn_desc.desc(),
1562 fn.tensors.seq_length,
1563 x_descs_arr.data(),
1564 x.data_ptr(),
1565 descs.hx_desc.desc(),
1566 hx.data_ptr(),
1567 descs.cx_desc.desc(),
1568 cx.defined() ? cx.data_ptr() : nullptr,
1569 w_desc.desc(),
1570 weight_buf.data_ptr(),
1571 y_descs_arr.data(),
1572 y.data_ptr(),
1573 descs.hy_desc.desc(),
1574 hy.data_ptr(),
1575 descs.cy_desc.desc(),
1576 cy.defined() ? cy.data_ptr() : nullptr,
1577 workspace.data_ptr(),
1578 workspace.size(0),
1579 reserve.mutable_data_ptr(),
1580 reserve.size(0)));
1581 #else
1582 AT_CUDNN_CHECK(cudnnRNNForward(
1583 handle,
1584 descs.rnn_desc.desc(),
1585 CUDNN_FWD_MODE_TRAINING,
1586 nullptr,
1587 x_descs_arr.desc(),
1588 x.data_ptr(),
1589 y_descs_arr.desc(),
1590 y.data_ptr(),
1591 descs.hx_desc.desc(),
1592 hx.data_ptr(),
1593 hy.data_ptr(),
1594 descs.cx_desc.desc(),
1595 cx.defined() ? cx.data_ptr() : nullptr,
1596 cy.defined() ? cy.data_ptr() : nullptr,
1597 weight_buf.numel() * weight_buf.element_size(),
1598 weight_buf.data_ptr(),
1599 workspace.size(0),
1600 workspace.data_ptr(),
1601 reserve.size(0),
1602 reserve.mutable_data_ptr()));
1603 #endif
1604 } else { // inference
1605 #ifdef USE_CUDNN_RNN_V8_API
1606 AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
1607 handle,
1608 descs.rnn_desc.desc(),
1609 CUDNN_FWD_MODE_INFERENCE,
1610 x_descs_arr.desc(),
1611 &workspace_size,
1612 NULL));
1613 #endif
1614 workspace = at::empty(workspace_size, input.options().dtype(kByte));
1615 reserve = at::empty({0}, input.options().dtype(kByte));
1616 #ifndef USE_CUDNN_RNN_V8_API
1617 AT_CUDNN_CHECK(cudnnRNNForwardInference(
1618 handle,
1619 descs.rnn_desc.desc(),
1620 fn.tensors.seq_length,
1621 x_descs_arr.data(),
1622 x.data_ptr(),
1623 descs.hx_desc.desc(),
1624 hx.data_ptr(),
1625 descs.cx_desc.desc(),
1626 cx.defined() ? cx.data_ptr() : nullptr,
1627 w_desc.desc(),
1628 weight_buf.data_ptr(),
1629 y_descs_arr.data(),
1630 y.data_ptr(),
1631 descs.hy_desc.desc(),
1632 hy.data_ptr(),
1633 descs.cy_desc.desc(),
1634 cy.defined() ? cy.data_ptr() : nullptr,
1635 workspace.data_ptr(),
1636 workspace.size(0)));
1637 #else
1638 AT_CUDNN_CHECK(cudnnRNNForward(
1639 handle,
1640 descs.rnn_desc.desc(),
1641 CUDNN_FWD_MODE_INFERENCE,
1642 nullptr,
1643 x_descs_arr.desc(),
1644 x.data_ptr(),
1645 y_descs_arr.desc(),
1646 y.data_ptr(),
1647 descs.hx_desc.desc(),
1648 hx.data_ptr(),
1649 hy.data_ptr(),
1650 descs.cx_desc.desc(),
1651 cx.defined() ? cx.data_ptr() : nullptr,
1652 cy.defined() ? cy.data_ptr() : nullptr,
1653 weight_buf.numel() * weight_buf.element_size(),
1654 weight_buf.data_ptr(),
1655 workspace.size(0),
1656 workspace.data_ptr(),
1657 reserve.size(0),
1658 reserve.mutable_data_ptr()));
1659 #endif
1660 }
1661
1662 if (batch_first && !is_input_packed) {
1663 output.transpose_(0, 1);
1664 }
1665
1666 return std::make_tuple(output, hy, cy, reserve, weight_buf);
1667 }
1668
1669 std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
1670 const Tensor& input_r,
1671 const Tensor& weight_buf,
1672 const Tensor& hx,
1673 const Tensor& cx,
1674 const Tensor& output_r,
1675 const Tensor& grad_output_r,
1676 const Tensor& grad_hy,
1677 const Tensor& grad_cy,
1678 int64_t fn_mode,
1679 int64_t fn_hidden_size,
1680 int64_t fn_proj_size,
1681 int64_t fn_num_layers,
1682 bool batch_first,
1683 double fn_dropout,
1684 bool fn_train,
1685 bool fn_bidirectional,
1686 IntArrayRef fn_batch_sizes,
1687 const Tensor& fn_dropout_state,
1688 const Tensor& fn_reserve,
1689 std::array<bool, 3> output_mask) {
1690 auto input = input_r;
1691 auto grad_output = grad_output_r;
1692 auto output = output_r;
1693
1694 RNNParams fn;
1695 auto datatype = getCudnnDataType(input);
1696 #ifndef USE_CUDNN_RNN_V8_API
1697 fn.rnn.set(
1698 fn_mode,
1699 fn_hidden_size,
1700 fn_proj_size,
1701 fn_num_layers,
1702 fn_bidirectional,
1703 promote_rnn_math_type(datatype),
1704 datatype);
1705 #else
1706 auto cudnn_input_size = input_r.size(-1);
1707 auto packed = fn_batch_sizes.size() != 0;
1708 fn.rnn.set(
1709 fn_mode,
1710 cudnn_input_size,
1711 packed,
1712 fn_hidden_size,
1713 fn_proj_size,
1714 fn_num_layers,
1715 fn_bidirectional,
1716 promote_rnn_math_type(datatype),
1717 datatype);
1718 #endif
1719 fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
1720 fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
1721
1722 // TODO: Set device to input
1723 auto handle = getCudnnHandle();
1724
1725 if (fn.rnn.mode != CUDNN_LSTM) {
1726 TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
1727 }
1728
1729 auto is_input_packed = fn_batch_sizes.size() != 0;
1730 if (batch_first && !is_input_packed) {
1731 input = input.transpose(0, 1);
1732 grad_output = grad_output.transpose(0, 1);
1733 output = output.transpose(0, 1);
1734 }
1735
1736 auto input_size = _input_size(fn.tensors);
1737 auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
1738 auto cell_size = _cell_size(fn.rnn, fn.tensors);
1739 auto output_size = _output_size(fn.rnn, fn.tensors);
1740
1741 TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
1742 TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");
1743
1744 auto x = input.contiguous();
1745 auto dy = grad_output.contiguous();
1746 auto y = output;
1747 auto w = weight_buf;
1748 auto dx = at::empty(
1749 input.sizes(), input.options()); // TODO: more compact way of saying this
1750 auto dhy = grad_hy.contiguous().view(hidden_size);
1751 auto dcy =
1752 grad_cy.defined() ? grad_cy.contiguous().view(cell_size) : Tensor();
1753 auto dhx = at::empty(hidden_size, hx.options());
1754 TORCH_INTERNAL_ASSERT(
1755 cx.defined() || !output_mask[2],
1756 "illegally required grad of cx for non-LSTM RNN");
1757 auto dcx = cx.defined() ? at::empty(cell_size, cx.options()) : Tensor();
1758
1759 TORCH_CHECK(
1760 fn_train, "cudnn RNN backward can only be called in training mode");
1761
1762 TORCH_CHECK(
1763 input.sizes().equals(input_size),
1764 "Expected input size ",
1765 IntArrayRef{input_size},
1766 ", got ",
1767 input.sizes());
1768 TORCH_CHECK(
1769 output.sizes().equals(output_size),
1770 "Expected output size ",
1771 IntArrayRef{output_size},
1772 ", got ",
1773 output.sizes());
1774
1775 TORCH_CHECK(
1776 !hx.defined() || hx.sizes().equals(hidden_size),
1777 "Expected hidden size ",
1778 IntArrayRef{hidden_size},
1779 ", got ",
1780 hx.sizes());
1781 TORCH_CHECK(
1782 !cx.defined() || cx.sizes().equals(cell_size),
1783 "Expected cell size ",
1784 IntArrayRef{cell_size},
1785 ", got ",
1786 cx.sizes());
1787 TORCH_CHECK(
1788 !dhy.defined() || dhy.sizes().equals(hidden_size),
1789 "Expected d_hidden size ",
1790 IntArrayRef{hidden_size},
1791 ", got ",
1792 dhy.sizes());
1793 TORCH_CHECK(
1794 !dcy.defined() || dcy.sizes().equals(cell_size),
1795 "Expected d_cell size ",
1796 IntArrayRef{cell_size},
1797 ", got ",
1798 dcy.sizes());
1799
1800 TORCH_CHECK(
1801 dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
1802 "Gradients aren't CUDA tensors");
1803
1804 cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false);
1805 fn.rnn.set_algo(algo);
1806 RNNDescriptors descs(fn, handle, x, y, hx, cx);
1807
1808 #ifndef USE_CUDNN_RNN_V8_API
1809 FilterDescriptor w_desc;
1810 w_desc.set(weight_buf, 3);
1811 #endif
1812
1813 size_t workspace_size;
1814 #ifndef USE_CUDNN_RNN_V8_API
1815 auto x_descs_arr = descs.get_x_descs();
1816 auto y_descs_arr = descs.get_y_descs();
1817 AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
1818 handle,
1819 descs.rnn_desc.desc(),
1820 fn.tensors.seq_length,
1821 x_descs_arr.data(),
1822 &workspace_size));
1823 #else
1824 auto& x_descs_arr = descs.x_descs;
1825 auto& y_descs_arr = descs.y_descs;
1826 AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
1827 handle,
1828 descs.rnn_desc.desc(),
1829 CUDNN_FWD_MODE_TRAINING,
1830 x_descs_arr.desc(),
1831 &workspace_size,
1832 NULL));
1833 #endif
1834 // TODO: put this in the correct device???
1835 Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
1836 #ifndef USE_CUDNN_RNN_V8_API
1837 AT_CUDNN_CHECK(cudnnRNNBackwardData(
1838 handle,
1839 descs.rnn_desc.desc(),
1840 fn.tensors.seq_length,
1841 y_descs_arr.data(),
1842 y.data_ptr(),
1843 y_descs_arr.data(),
1844 dy.data_ptr(),
1845 descs.hy_desc.desc(),
1846 dhy.data_ptr(),
1847 descs.cy_desc.desc(),
1848 cx.defined() ? dcy.data_ptr() : nullptr,
1849 w_desc.desc(),
1850 w.data_ptr(),
1851 descs.hx_desc.desc(),
1852 hx.data_ptr(),
1853 descs.cx_desc.desc(),
1854 cx.defined() ? cx.data_ptr() : nullptr,
1855 x_descs_arr.data(),
1856 dx.data_ptr(),
1857 descs.hx_desc.desc(),
1858 dhx.data_ptr(),
1859 descs.cx_desc.desc(),
1860 cx.defined() ? dcx.data_ptr() : nullptr,
1861 workspace.data_ptr(),
1862 workspace.size(0),
1863 fn_reserve.data_ptr(),
1864 fn_reserve.size(0)));
1865 #else
1866 AT_CUDNN_CHECK(cudnnRNNBackwardData_v8(
1867 handle,
1868 descs.rnn_desc.desc(),
1869 nullptr,
1870 y_descs_arr.desc(),
1871 y.data_ptr(),
1872 dy.data_ptr(),
1873 x_descs_arr.desc(),
1874 dx.data_ptr(),
1875 descs.hx_desc.desc(),
1876 hx.data_ptr(),
1877 dhy.data_ptr(),
1878 dhx.data_ptr(),
1879 descs.cx_desc.desc(),
1880 cx.defined() ? cx.data_ptr() : nullptr,
1881 cx.defined() ? dcy.data_ptr() : nullptr,
1882 cx.defined() ? dcx.data_ptr() : nullptr,
1883 weight_buf.numel() * weight_buf.element_size(),
1884 weight_buf.data_ptr(),
1885 workspace.size(0),
1886 workspace.data_ptr(),
1887 fn_reserve.size(0),
1888 fn_reserve.data_ptr()));
1889 #endif
1890 if (batch_first && !is_input_packed) {
1891 dx = dx.transpose_(0, 1);
1892 }
1893
1894 return std::make_tuple(dx, dhx, dcx);
1895 }
1896
1897 // NB: This MUST BE CALLED AFTER _cudnn_rnn_backward_input.
1898 // We'll give a user-friendly combined function...
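// (cuDNN requires the BackwardData call to have run first: it leaves the
// reserve buffer in the state that BackwardWeights expects to consume.)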
1899 std::vector<Tensor> _cudnn_rnn_backward_weight(
1900 // TODO: I think tensor geometry sufficient for weight_buf/weight
1901 const Tensor& input_r,
1902 TensorList weight_arr,
1903 int64_t weight_stride0,
1904 const Tensor& weight_buf,
1905 const Tensor& hx,
1906 const Tensor& cx,
1907 const Tensor& output_r,
1908 int64_t fn_mode,
1909 int64_t fn_hidden_size,
1910 int64_t fn_proj_size,
1911 int64_t fn_num_layers,
1912 bool batch_first,
1913 double fn_dropout,
1914 bool fn_train,
1915 bool fn_bidirectional,
1916 IntArrayRef fn_batch_sizes,
1917 const Tensor& fn_dropout_state,
1918 const Tensor& fn_reserve) {
1919 MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)};
1920 auto input = input_r;
1921 auto output = output_r;
1922
1923 RNNParams fn;
1924 auto datatype = getCudnnDataType(input);
1925 #ifndef USE_CUDNN_RNN_V8_API
1926 fn.rnn.set(
1927 fn_mode,
1928 fn_hidden_size,
1929 fn_proj_size,
1930 fn_num_layers,
1931 fn_bidirectional,
1932 promote_rnn_math_type(datatype),
1933 datatype);
1934 #else
1935 auto cudnn_input_size = input_r.size(-1);
1936 auto packed = fn_batch_sizes.size() != 0;
1937 fn.rnn.set(
1938 fn_mode,
1939 cudnn_input_size,
1940 packed,
1941 fn_hidden_size,
1942 fn_proj_size,
1943 fn_num_layers,
1944 fn_bidirectional,
1945 promote_rnn_math_type(datatype),
1946 datatype);
1947 #endif
1948 fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
1949 fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
1950
1951 auto handle = getCudnnHandle();
1952
1953 if (fn.rnn.mode != CUDNN_LSTM) {
1954 TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
1955 }
1956
1957 auto is_input_packed = fn_batch_sizes.size() != 0;
1958 if (batch_first && !is_input_packed) {
1959 input = input.transpose(0, 1);
1960 output = output.transpose(0, 1);
1961 }
1962
1963 auto input_size = _input_size(fn.tensors);
1964 auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
1965
1966 TORCH_CHECK(
1967 fn_train, "cudnn RNN backward can only be called in training mode");
1968
1969 TORCH_CHECK(
1970 input.sizes().equals(input_size),
1971 "Expected input size ",
1972 IntArrayRef{input_size},
1973 ", got ",
1974 input.sizes());
1975 TORCH_CHECK(
1976 !hx.defined() || hx.sizes().equals(hidden_size),
1977 "Expected hidden size ",
1978 IntArrayRef{hidden_size},
1979 ", got ",
1980 hx.sizes());
1981
1982 // TODO: the above were the only checks in rnn.py, but it doesn't seem
1983 // like these checks are enough
1984
1985 TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
1986 TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");
1987
1988 auto x = input.contiguous();
1989 const auto& y = output;
1990 auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());
1991
1992 cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false);
1993 fn.rnn.set_algo(algo);
1994 RNNDescriptors descs(fn, handle, x, y, hx, cx);
1995
1996 #ifndef USE_CUDNN_RNN_V8_API
1997 FilterDescriptor w_desc;
1998 w_desc.set(weight_buf, 3);
1999 #endif
2000
2001 size_t workspace_size;
2002 #ifndef USE_CUDNN_RNN_V8_API
2003 auto x_descs_arr = descs.get_x_descs();
2004 auto y_descs_arr = descs.get_y_descs();
2005 AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
2006 handle,
2007 descs.rnn_desc.desc(),
2008 fn.tensors.seq_length,
2009 x_descs_arr.data(),
2010 &workspace_size));
2011 #else
2012 auto& x_descs_arr = descs.x_descs;
2013 auto& y_descs_arr = descs.y_descs;
2014 AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
2015 handle,
2016 descs.rnn_desc.desc(),
2017 CUDNN_FWD_MODE_TRAINING,
2018 x_descs_arr.desc(),
2019 &workspace_size,
2020 NULL));
2021 #endif
2022 Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
2023 #ifndef USE_CUDNN_RNN_V8_API
2024 AT_CUDNN_CHECK(cudnnRNNBackwardWeights(
2025 handle,
2026 descs.rnn_desc.desc(),
2027 fn.tensors.seq_length,
2028 x_descs_arr.data(),
2029 x.data_ptr(),
2030 descs.hx_desc.desc(),
2031 hx.data_ptr(),
2032 y_descs_arr.data(),
2033 y.data_ptr(),
2034 workspace.data_ptr(),
2035 workspace.size(0),
2036 w_desc.desc(),
2037 dw.data_ptr(),
2038 fn_reserve.data_ptr(),
2039 fn_reserve.size(0)));
2040 #else
2041 AT_CUDNN_CHECK(cudnnRNNBackwardWeights_v8(
2042 handle,
2043 descs.rnn_desc.desc(),
2044 CUDNN_WGRAD_MODE_ADD,
2045 nullptr,
2046 x_descs_arr.desc(),
2047 x.data_ptr(),
2048 descs.hx_desc.desc(),
2049 hx.data_ptr(),
2050 y_descs_arr.desc(),
2051 y.data_ptr(),
2052 weight_buf.numel() * weight_buf.element_size(),
2053 dw.data_ptr(),
2054 workspace.size(0),
2055 workspace.data_ptr(),
2056 fn_reserve.size(0),
2057 fn_reserve.data_ptr()));
2058 #endif
2059
2060 #ifndef USE_CUDNN_RNN_V8_API
2061 auto [grad_params_arr, grad_params_stride0] = get_parameters(
2062 handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, dw);
2063 #else
2064 auto [grad_params_arr, grad_params_stride0] =
2065 get_parameters(handle, fn.rnn, descs.rnn_desc, dw);
2066 #endif
2067 if (grad_params_stride0 == static_cast<size_t>(weight_stride0)) {
2068 _viewParams(
2069 MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
2070 MatrixRef<Tensor>{weight_arr, static_cast<size_t>(weight_stride0)});
2071 return grad_params_arr;
2072 } else {
2073 std::vector<Tensor> grad_weight_arr;
2074 grad_weight_arr.reserve(weight.numel());
2075 for (const auto& w : weight_arr) {
2076 grad_weight_arr.emplace_back(at::empty(w.sizes(), w.options()));
2077 }
2078 _copyParams(
2079 MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
2080 MatrixRef<Tensor>{
2081 grad_weight_arr, static_cast<size_t>(weight_stride0)});
2082 return grad_weight_arr;
2083 }
2084 }
2085
2086 // We need this dispatcher because _cudnn_rnn_backward_weight has a stringent
2087 // ordering requirement with _cudnn_rnn_backward_input
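// output_mask selects which gradients the caller needs, in the order
// {d_input, d_hx, d_cx, d_weights}; the first three are forwarded to
// _cudnn_rnn_backward_input and the last one gates _cudnn_rnn_backward_weight.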
2088 std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
2089 const Tensor& input,
2090 TensorList weight,
2091 int64_t weight_stride0,
2092 const Tensor& weight_buf,
2093 const Tensor& hx,
2094 const std::optional<Tensor>& cx_opt,
2095 const Tensor& output,
2096 const std::optional<Tensor>& grad_output_r_opt,
2097 const std::optional<Tensor>& grad_hy_r_opt,
2098 const std::optional<Tensor>& grad_cy_r_opt,
2099 int64_t mode,
2100 int64_t hidden_size,
2101 int64_t proj_size,
2102 int64_t num_layers,
2103 bool batch_first,
2104 double dropout,
2105 bool train,
2106 bool bidirectional,
2107 IntArrayRef batch_sizes,
2108 const std::optional<Tensor>& dropout_state_opt,
2109 const Tensor& reserve,
2110 std::array<bool, 4> output_mask) {
2111 // See [Note: hacky wrapper removal for optional tensor]
2112 c10::MaybeOwned<Tensor> cx_maybe_owned =
2113 at::borrow_from_optional_tensor(cx_opt);
2114 const Tensor& cx = *cx_maybe_owned;
2115 const Tensor& grad_output_r =
2116 c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
2117 const Tensor& grad_hy_r =
2118 c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
2119 const Tensor& grad_cy_r =
2120 c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
2121 const Tensor& dropout_state =
2122 c10::value_or_else(dropout_state_opt, [] { return Tensor(); });
2123
2124 if (!grad_output_r.defined() && !grad_hy_r.defined() &&
2125 !grad_cy_r.defined()) {
2126 return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>(
2127 Tensor(), Tensor(), Tensor(), std::vector<Tensor>(weight.size()));
2128 }
2129
2130 auto grad_output = grad_output_r.defined()
2131 ? grad_output_r
2132 : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
2133 auto grad_hy = grad_hy_r.defined()
2134 ? grad_hy_r
2135 : at::zeros_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
2136 auto grad_cy = cx.defined()
2137 ? (grad_cy_r.defined()
2138 ? grad_cy_r
2139 : at::zeros_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
2140 : grad_cy_r;
2141
2142 // NB: unconditionally compute this gradient, because it mutates reserve
2143 auto [dx, dhx, dcx] = at::native::_cudnn_rnn_backward_input(
2144 input,
2145 weight_buf,
2146 hx,
2147 cx,
2148 output,
2149 grad_output,
2150 grad_hy,
2151 grad_cy,
2152 mode,
2153 hidden_size,
2154 proj_size,
2155 num_layers,
2156 batch_first,
2157 dropout,
2158 train,
2159 bidirectional,
2160 batch_sizes,
2161 dropout_state,
2162 reserve,
2163 {output_mask[0], output_mask[1], output_mask[2]});
2164 std::vector<Tensor> dw;
2165 if (output_mask[3]) {
2166 dw = at::native::_cudnn_rnn_backward_weight(
2167 input,
2168 weight,
2169 weight_stride0,
2170 weight_buf,
2171 hx,
2172 cx,
2173 output,
2174 mode,
2175 hidden_size,
2176 proj_size,
2177 num_layers,
2178 batch_first,
2179 dropout,
2180 train,
2181 bidirectional,
2182 batch_sizes,
2183 dropout_state,
2184 reserve);
2185 }
2186 return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>{
2187 dx, dhx, dcx, dw};
2188 }
2189
2190 // TODO: I am not sure if we actually need the 'dropout' and 'train' parameters
2191 // to initialize just the state tensor
2192 //
2193 // NB: You can have any color you like, as long as it's a CUDA byte
2194 // tensor. Why does this function take a TensorOptions at all in that case?
2195 // This is a factory function: it produces tensors but takes no tensors
2196 // as input. The codegen currently assumes that ALL factory functions
2197 // take TensorOptions, so it's just a lot easier for this function to
2198 // be bound if it also does it.
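// The returned tensor is the opaque per-device RNG state blob that
// cudnnSetDropoutDescriptor initializes; get_dropout_state() below caches one
// of these per GPU.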
2199 Tensor _cudnn_init_dropout_state(
2200 double dropout,
2201 bool train,
2202 int64_t dropout_seed,
2203 std::optional<ScalarType> dtype,
2204 std::optional<Layout> layout,
2205 std::optional<Device> device,
2206 std::optional<bool> pin_memory) {
2207 // See [Note: hacky wrapper removal for TensorOptions]
2208 TensorOptions options =
2209 TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
2210 pin_memory);
2211
2212 auto handle = getCudnnHandle();
2213 DropoutDescriptor dropout_desc;
2214 auto dropout_p = train ? dropout : 0;
2215 dropout_desc.initialize_rng(handle, dropout_p, dropout_seed, options);
2216 return dropout_desc.state;
2217 }
2218
2219 ////////////////////////////////////////////////////////////////////////////////
2220 // CUDA dispatch for the generic RNN ops (at::lstm, at::gru, ...)
2221 ////////////////////////////////////////////////////////////////////////////////
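// For example (illustrative sketch, not part of this file): a CUDA call
// roughly like
//
//   auto [out, hy, cy] = at::lstm(
//       input, {h0, c0}, params,
//       /*has_biases=*/true, /*num_layers=*/1, /*dropout=*/0.0,
//       /*train=*/true, /*bidirectional=*/false, /*batch_first=*/false);
//
// is routed by the generic at::native::lstm through lstm_cudnn_stub to the
// lstm_cudnn entry point registered at the bottom of this namespace.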
2222
2223 namespace {
2224
2225 // Helpers for working with different hidden types.
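// LSTM hidden state is an (hx, cx) pair, while GRU/RNN-tanh/RNN-relu carry a
// single tensor; unpack_hidden/pack_hidden normalize both cases to and from a
// (hx, cx) pair so _cudnn_impl can be written once.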
2226 std::tuple<Tensor, Tensor> unpack_hidden(const Tensor& hidden) {
2227 return std::make_tuple(hidden, at::Tensor{});
2228 }
2229
2230 std::tuple<Tensor, Tensor> unpack_hidden(
2231 const std::tuple<Tensor, Tensor>& hidden) {
2232 return hidden;
2233 }
2234
2235 template <typename hidden_type>
2236 hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
2237 static_assert(
2238 false && sizeof(hidden_type),
2239 "pack_hidden not implemented for this type");
2240 }
2241
2242 template <>
2243 Tensor pack_hidden<Tensor>(const Tensor& hx, const Tensor& cx) {
2244 AT_ASSERT(cx.numel() == 0);
2245 return hx;
2246 }
2247
2248 template <>
2249 std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(
2250 const Tensor& hx,
2251 const Tensor& cx) {
2252 return std::make_tuple(hx, cx);
2253 }
2254
2255 /**
2256 * Note [DropoutState and CUDA graph capture]
2257 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2258 * (1) Telling a capturing stream to wait on an event recorded in a
2259 *     non-capturing stream is an error.
2260 * (2) Telling a non-capturing stream to wait on an event recorded during
2261 *     capture is also an error.
2262 *
2263 * So DropoutState's usage syncs could error if an RNN with dropout is
2264 * called in an uncaptured region and then in a captured region
2265 * (triggering 1), or called in a captured region and then in an
2266 * uncaptured region (triggering 2).
2267 *
2268 * To prevent 1 and 2, lock() only syncs on the last usage event if it
2269 * was recorded in the same capture state as the current state (which
2270 * also means the same graph, if capture is in progress).
2271 *
2272 * The solution should be safe as long as capture obeys the following
2273 * restrictions:
2274 *  - Only one capture may be underway at a time in a given process.
2275 *  - While a capture is underway, no calls to eager ops on
2276 *    noncapturing streams (on any thread) may interleave with the
2277 *    captured ops.
2278 *
2279 * TODO: As people experiment with capture, keep an eye out for use
2280 * cases that might need to relax those restrictions.
2281 *
2282 * See https://github.com/pytorch/pytorch/pull/56433 for more
2283 * discussion.
2284 *
2285 */
2286
2287 struct DropoutState {
2288 // Both buffer and event are lazily instantiated when a dropout state is
2289 // needed for the first time. Note that in this case needed != used, as we
2290 // don't need a buffer to e.g. run RNNs in test mode.
2291 at::Tensor buffer;
2292 std::optional<cuda::CUDAEvent> event;
2293 std::mutex mutex;
2294 #if !defined(USE_ROCM)
2295 // cudaStreamGetCaptureInfo will never give back a capture id of 0, so 0 can
2296 // serve as a sentinel value that capture was not underway.
2297 cuda::CaptureId_t capture_id_last_lock = 0;
2298 cuda::CaptureId_t capture_id_last_unlock = 0;
2299 #endif
2300
2301 // Every time we use a dropout state, we need to synchronize with its event,
2302 // to make sure all previous uses finish running before this one starts. Once
2303 // we're done, we record the event to allow others to synchronize with this
2304 // kernel. Those events are really needed only for inter-stream sync on a
2305 // single GPU. I doubt anyone will want to run cuDNN RNNs in parallel on a
2306 // single GPU, so they should end up being complete no-ops.
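  // Concretely, lock() only waits on the previous event when the previous
  // unlock() happened in the same capture state (same capture id, or both
  // outside capture); otherwise the wait would cross a capture boundary and
  // hit errors (1)/(2) from the note above.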
2307 void lock() {
2308 // NB: We can't ignore the lock even when event is undefined, because
2309 // someone could then define it before we get to unlock().
2310 mutex.lock();
2311 if (event) {
2312 #if !defined(USE_ROCM)
2313 // See Note [DropoutState and CUDA graph capture]
2314 cudaStreamCaptureStatus status;
2315 AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
2316 cuda::getCurrentCUDAStream(), &status, &capture_id_last_lock));
2317 if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
2318 capture_id_last_lock = 0;
2319 }
2320 if (capture_id_last_lock == capture_id_last_unlock) {
2321 event->block(cuda::getCurrentCUDAStream());
2322 }
2323 #else
2324 event->block(cuda::getCurrentCUDAStream());
2325 #endif
2326 }
2327 }
2328
2329 void unlock() {
2330 if (event) {
2331 event->record();
2332 #if !defined(USE_ROCM)
2333 // See Note [DropoutState and CUDA graph capture]
2334 cudaStreamCaptureStatus status;
2335 AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
2336 cuda::getCurrentCUDAStream(), &status, &capture_id_last_unlock));
2337 if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
2338 capture_id_last_unlock = 0;
2339 }
2340 TORCH_INTERNAL_ASSERT(capture_id_last_unlock == capture_id_last_lock);
2341 #endif
2342 }
2343 mutex.unlock();
2344 }
2345 };
2346
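// Returns the cached per-device DropoutState, lazily (re)initializing its
// buffer when dropout is actually needed (train && dropout_p > 0) or when the
// default CUDA generator asks for the RNN state to be reset (e.g. after
// re-seeding).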
2347 DropoutState& get_dropout_state(
2348 double dropout_p,
2349 bool train,
2350 TensorOptions options) {
2351 // Each state is slightly over 2MB and initialized lazily, so it's fine to
2352 // cache them.
2353 static std::vector<DropoutState> dropout_state_cache{
2354 static_cast<size_t>(cuda::getNumGPUs())};
2355 static std::mutex state_cache_mut;
2356
2357 AT_ASSERT(options.device().is_cuda());
2358 auto device = options.device().index();
2359
2360 std::unique_lock<std::mutex> lock{state_cache_mut};
2361 auto& state = dropout_state_cache.at(device);
2362 if (train && dropout_p > 0) {
2363 const auto& gen =
2364 at::detail::getCUDAHooks().getDefaultCUDAGenerator(device);
2365 auto gen_impl = gen.get<at::CUDAGeneratorImpl>();
2366 bool reset_rnn_state = gen_impl->reset_rnn_state();
2367 if (!state.buffer.defined() || reset_rnn_state) {
2368 std::unique_lock<std::mutex> lock{state.mutex};
2369 int64_t seed =
2370 at::empty({}, options.dtype(at::kLong)).random_(gen).item<int64_t>();
2371 state.buffer = at::_cudnn_init_dropout_state(
2372 dropout_p, train, seed, options.dtype(at::kByte));
2373 // NB: CUDA binds the event to a device at creation time, so we can
2374 // initialize it only now, when we know we're on the correct device.
2375 if (!state.event.has_value()) {
2376 state.event.emplace();
2377 }
2378 }
2379 }
2380 return state;
2381 }
2382
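// Checks whether `parameters` already live inside a single flat buffer laid
// out the way cuDNN expects (same storage, matching data pointers, contiguous
// tail); returns that buffer if so, and an undefined Tensor otherwise, in
// which case _cudnn_rnn will copy the weights into a fresh buffer.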
2383 Tensor try_get_weight_buf(
2384 const Tensor& input,
2385 TensorList parameters,
2386 bool has_biases,
2387 cudnnRNNMode_t mode,
2388 c10::SymInt hidden_size,
2389 c10::SymInt proj_size,
2390 int64_t num_layers,
2391 bool bidirectional) {
2392 // Prepare all relevant descriptors
2393 auto handle = getCudnnHandle();
2394 auto& any_param = parameters.at(0);
2395 auto datatype = getCudnnDataType(any_param);
2396
2397 // Something very naughty is happening here. try_get_weight_buf
2398 // is called from _cudnn_impl, which is a *composite*. In other words,
2399 // inside the composite function we need to query cudnn to figure out how big
2400 // the weight buf actually is going to be. This clearly cannot be done
2401 // symbolically. For now, we insert guards here; but once we have the black
2402 // box handling for dynamic shapes, we could also hypothetically infer out
2403 // the relationships
2404 RNNDescriptorParams rnn;
2405 #ifndef USE_CUDNN_RNN_V8_API
2406 rnn.set(
2407 mode,
2408 hidden_size.guard_int(__FILE__, __LINE__),
2409 proj_size.guard_int(__FILE__, __LINE__),
2410 num_layers,
2411 bidirectional,
2412 promote_rnn_math_type(datatype),
2413 datatype);
2414 #else
2415 auto cudnn_input_size = input.size(-1);
2416 auto packed = false; // eqy: bogus as we do not know if the input is packed
2417 // here again, it should also not affect the weights
2418 rnn.set(
2419 mode,
2420 cudnn_input_size,
2421 packed,
2422 hidden_size.guard_int(__FILE__, __LINE__),
2423 proj_size.guard_int(__FILE__, __LINE__),
2424 num_layers,
2425 bidirectional,
2426 promote_rnn_math_type(datatype),
2427 datatype);
2428 #endif
2429 RNNDescriptor rnn_desc = rnn.descriptor(handle);
2430
2431 TensorGeometry x_geom({1, input.sym_size(-1).guard_int(__FILE__, __LINE__)});
2432 TensorDescriptor x_desc;
2433 // datatype for x_desc comes from any_param, not input.
2434 // try_get_weight_buf's job is to check "is the weight buffer correctly laid
2435 // out for us to run it with input of the same datatype?"
2436 x_desc.set(datatype, x_geom.sizes(), x_geom.strides(), 5);
2437
2438 #ifndef USE_CUDNN_RNN_V8_API
2439 auto num_params = get_num_weights(handle, rnn_desc, x_desc, datatype);
2440 #else
2441 auto num_params = get_num_weights(handle, rnn_desc, datatype);
2442 #endif
2443
2444 // Try to get parameter storage
2445 auto param_storage = any_param.storage();
2446 auto weight_buf = at::empty({0}, any_param.options()).set_(param_storage);
2447 if (weight_buf.size(0) < num_params) {
2448 return {};
2449 } else if (weight_buf.size(0) > num_params) {
2450 weight_buf = weight_buf.narrow(0, 0, num_params);
2451 }
2452
2453 // Get and check data pointers
2454 auto expected_data_ptrs = get_expected_data_ptrs(
2455 weight_buf, handle, rnn, rnn_desc, x_desc, datatype);
2456
2457 int64_t num_parameters = parameters.size();
2458 int64_t num_ptrs = expected_data_ptrs.size();
2459 if (proj_size != 0) {
2460 AT_ASSERT(num_parameters % (has_biases ? 5 : 3) == 0);
2461 AT_ASSERT(num_ptrs % 5 == 0);
2462 if (has_biases) {
2463 AT_ASSERT(num_ptrs == num_parameters);
2464 for (const auto i : c10::irange(num_parameters)) {
2465 if (expected_data_ptrs[i] != parameters[i].data_ptr())
2466 return {};
2467 }
2468 } else {
2469 AT_ASSERT(num_parameters % 3 == 0);
2470 AT_ASSERT(num_ptrs == num_parameters * 5 / 3);
2471 for (int64_t param_i = 0, ptr_i = 0; ptr_i < num_ptrs;
2472 ptr_i += 5, param_i += 3) {
2473 if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr())
2474 return {};
2475 if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr())
2476 return {};
2477 if (expected_data_ptrs[ptr_i + 4] != parameters[param_i + 2].data_ptr())
2478 return {};
2479 }
2480 }
2481 } else {
2482 AT_ASSERT(num_ptrs == (num_parameters * (has_biases ? 1 : 2)));
2483 AT_ASSERT(num_parameters % (has_biases ? 4 : 2) == 0);
2484 for (int64_t param_i = 0, ptr_i = 0; ptr_i < num_ptrs;
2485 ptr_i += (has_biases ? 2 : 4), param_i += 2) {
2486 if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr())
2487 return {};
2488 if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr())
2489 return {};
2490 }
2491 }
2492 if (!parameters[num_parameters - 1].is_contiguous())
2493 return {};
2494 return weight_buf;
2495 }
2496
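// Packed-sequence variant: `_batch_sizes` comes from a packed input (see
// pack_padded_sequence), so the data is already time-major and batch_first is
// fixed to false in the _cudnn_rnn_symint call below.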
2497 template <typename hidden_type>
2498 std::pair<Tensor, hidden_type> _cudnn_impl(
2499 const Tensor& input,
2500 const Tensor& _batch_sizes,
2501 const hidden_type& hidden,
2502 TensorList params,
2503 bool has_biases,
2504 cudnnRNNMode_t mode,
2505 int64_t num_layers,
2506 double dropout_p,
2507 bool train,
2508 bool bidirectional) {
2509 auto [hx, cx] = unpack_hidden(hidden);
2510 auto hidden_size = hx.sym_size(2);
2511 SymInt proj_size = 0;
2512 // For LSTM models with projections hidden size could be different
2513 if (cx.defined() && cx.sym_size(2) != hx.sym_size(2)) {
2514 hidden_size = cx.sym_size(2);
2515 proj_size = hx.sym_size(2);
2516 }
2517
2518 // TODO: try_get_weight_buf returns a Tensor, but _cudnn_rnn below takes a
2519 // std::optional<Tensor> in weight_buf's slot. Do we want try_get_weight_buf
2520 // to return a std::optional<Tensor> instead of a defined or undefined Tensor?
2521 at::cuda::OptionalCUDAGuard guard(input.get_device());
2522 auto weight_buf = try_get_weight_buf(
2523 input,
2524 params,
2525 has_biases,
2526 mode,
2527 hidden_size,
2528 proj_size,
2529 num_layers,
2530 bidirectional);
2531
2532 TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D");
2533 IntArrayRef batch_sizes{
2534 _batch_sizes.data_ptr<int64_t>(),
2535 static_cast<size_t>(_batch_sizes.size(0))};
2536
2537 auto& dropout_state = get_dropout_state(dropout_p, train, input.options());
2538 std::unique_lock<DropoutState> lock{dropout_state};
2539 int64_t num_params = has_biases ? 4 : 2;
2540 if (proj_size != 0) {
2541 ++num_params;
2542 }
2543 auto sym_batch_sizes = c10::SymIntArrayRef(
2544 reinterpret_cast<const c10::SymInt*>(batch_sizes.data()),
2545 batch_sizes.size());
2546 // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
2547 auto cudnn_output = at::_cudnn_rnn_symint(
2548 input,
2549 params,
2550 num_params,
2551 weight_buf,
2552 hx,
2553 cx,
2554 static_cast<int>(mode),
2555 hidden_size,
2556 proj_size,
2557 num_layers,
2558 /*batch_first=*/false,
2559 dropout_p,
2560 train,
2561 bidirectional,
2562 sym_batch_sizes,
2563 dropout_state.buffer);
2564
2565 return {
2566 std::get<0>(cudnn_output),
2567 pack_hidden<hidden_type>(
2568 std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
2569 }
2570
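// Padded (non-packed) variant: batch_sizes is empty and batch_first is
// forwarded to _cudnn_rnn_symint unchanged.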
2571 template <typename hidden_type>
2572 std::pair<Tensor, hidden_type> _cudnn_impl(
2573 const Tensor& input,
2574 const hidden_type& hidden,
2575 TensorList params,
2576 bool has_biases,
2577 cudnnRNNMode_t mode,
2578 int64_t num_layers,
2579 double dropout_p,
2580 bool train,
2581 bool bidirectional,
2582 bool batch_first) {
2583 auto [hx, cx] = unpack_hidden(hidden);
2584 auto hidden_size = hx.sym_size(2);
2585 c10::SymInt proj_size = 0;
2586 // For LSTM models with projections hidden size could be different
2587 if (cx.defined() && cx.sym_size(2) != hx.sym_size(2)) {
2588 hidden_size = cx.sym_size(2);
2589 proj_size = hx.sym_size(2);
2590 }
2591 at::cuda::OptionalCUDAGuard guard(input.get_device());
2592 auto weight_buf = try_get_weight_buf(
2593 input,
2594 params,
2595 has_biases,
2596 mode,
2597 hidden_size,
2598 proj_size,
2599 num_layers,
2600 bidirectional);
2601 auto& dropout_state = get_dropout_state(dropout_p, train, input.options());
2602 std::unique_lock<DropoutState> lock{dropout_state};
2603 int64_t num_params = has_biases ? 4 : 2;
2604 if (proj_size != 0) {
2605 ++num_params;
2606 }
2607 // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
2608 auto cudnn_output = at::_cudnn_rnn_symint(
2609 input,
2610 params,
2611 num_params,
2612 weight_buf,
2613 hx,
2614 cx,
2615 static_cast<int>(mode),
2616 hidden_size,
2617 proj_size,
2618 num_layers,
2619 batch_first,
2620 dropout_p,
2621 train,
2622 bidirectional,
2623 /*batch_sizes=*/{},
2624 dropout_state.buffer);
2625
2626 return {
2627 std::get<0>(cudnn_output),
2628 pack_hidden<hidden_type>(
2629 std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
2630 }
2631
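// ONE_HIDDEN_RNN stamps out the cudnn dispatch entry points for RNN flavors
// with a single hidden tensor (GRU, RNN-tanh, RNN-relu): a padded-input
// version, a packed-input version, and their stub registrations.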
2632 #define ONE_HIDDEN_RNN(NAME, MODE) \
2633 void NAME##_cudnn( \
2634 Tensor& output, \
2635 Tensor& hy, \
2636 const Tensor& input, \
2637 const Tensor& hx, \
2638 TensorList params, \
2639 bool has_biases, \
2640 int64_t num_layers, \
2641 double dropout_p, \
2642 bool train, \
2643 bool bidirectional, \
2644 bool batch_first) { \
2645 std::tie(output, hy) = _cudnn_impl( \
2646 input, \
2647 hx, \
2648 params, \
2649 has_biases, \
2650 MODE, \
2651 num_layers, \
2652 dropout_p, \
2653 train, \
2654 bidirectional, \
2655 batch_first); \
2656 } \
2657 \
2658 void NAME##_packed_cudnn( \
2659 Tensor& output, \
2660 Tensor& hy, \
2661 const Tensor& data, \
2662 const Tensor& batch_sizes, \
2663 const Tensor& hx, \
2664 TensorList params, \
2665 bool has_biases, \
2666 int64_t num_layers, \
2667 double dropout_p, \
2668 bool train, \
2669 bool bidirectional) { \
2670 std::tie(output, hy) = _cudnn_impl( \
2671 data, \
2672 batch_sizes, \
2673 hx, \
2674 params, \
2675 has_biases, \
2676 MODE, \
2677 num_layers, \
2678 dropout_p, \
2679 train, \
2680 bidirectional); \
2681 } \
2682 \
2683 REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn); \
2684 REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn);
2685
2686 ONE_HIDDEN_RNN(gru, CUDNN_GRU)
2687 ONE_HIDDEN_RNN(rnn_tanh, CUDNN_RNN_TANH)
2688 ONE_HIDDEN_RNN(rnn_relu, CUDNN_RNN_RELU)
2689
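// LSTM gets hand-written entry points because its hidden state is the
// (hx, cx) pair rather than a single tensor.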
2690 void lstm_cudnn(
2691 Tensor& output,
2692 Tensor& hy,
2693 Tensor& cy,
2694 const Tensor& input,
2695 TensorList hx,
2696 TensorList params,
2697 bool has_biases,
2698 int64_t num_layers,
2699 double dropout_p,
2700 bool train,
2701 bool bidirectional,
2702 bool batch_first) {
2703 auto result = _cudnn_impl(
2704 input,
2705 std::make_tuple(hx[0], hx[1]),
2706 params,
2707 has_biases,
2708 CUDNN_LSTM,
2709 num_layers,
2710 dropout_p,
2711 train,
2712 bidirectional,
2713 batch_first);
2714 output = result.first;
2715 hy = std::get<0>(result.second);
2716 cy = std::get<1>(result.second);
2717 }
2718
2719 void lstm_packed_cudnn(
2720 Tensor& output,
2721 Tensor& hy,
2722 Tensor& cy,
2723 const Tensor& data,
2724 const Tensor& batch_sizes,
2725 TensorList hx,
2726 TensorList params,
2727 bool has_biases,
2728 int64_t num_layers,
2729 double dropout_p,
2730 bool train,
2731 bool bidirectional) {
2732 auto result = _cudnn_impl(
2733 data,
2734 batch_sizes,
2735 std::make_tuple(hx[0], hx[1]),
2736 params,
2737 has_biases,
2738 CUDNN_LSTM,
2739 num_layers,
2740 dropout_p,
2741 train,
2742 bidirectional);
2743 output = result.first;
2744 hy = std::get<0>(result.second);
2745 cy = std::get<1>(result.second);
2746 }
2747
2748 REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn);
2749 REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn);
2750
2751 } // namespace
2752
2753 } // namespace native
2754 } // namespace at
2755
2756 #endif // AT_CUDNN_ENABLED()
2757