xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/RNN.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Config.h>
3 #include <ATen/MatrixRef.h>
4 #include <ATen/TensorUtils.h>
5 #include <ATen/core/Tensor.h>
6 #include <ATen/cuda/CUDAConfig.h>
7 #include <ATen/cuda/CUDAEvent.h>
8 #include <ATen/cuda/Exceptions.h>
9 #include <ATen/native/RNN.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/accumulate.h>
12 #include <c10/util/irange.h>
13 #include <ATen/cuda/CUDAGraphsUtils.cuh>
14 
15 #ifndef AT_PER_OPERATOR_HEADERS
16 #include <ATen/Functions.h>
17 #include <ATen/NativeFunctions.h>
18 #else
19 #include <ATen/ops/_cudnn_init_dropout_state.h>
20 #include <ATen/ops/_cudnn_init_dropout_state_native.h>
21 #include <ATen/ops/_cudnn_rnn.h>
22 #include <ATen/ops/_cudnn_rnn_backward_native.h>
23 #include <ATen/ops/_cudnn_rnn_flatten_weight_native.h>
24 #include <ATen/ops/_cudnn_rnn_native.h>
25 #include <ATen/ops/empty.h>
26 #include <ATen/ops/zeros.h>
27 #include <ATen/ops/zeros_like.h>
28 #endif
29 
30 #if !AT_CUDNN_ENABLED()
31 
32 namespace at {
33 namespace native {
34 
35 // See Note [ATen preprocessor philosophy]
36 
37 Tensor _cudnn_rnn_flatten_weight(
38     TensorList weight_arr,
39     int64_t weight_stride0,
40     int64_t input_size,
41     int64_t fn_mode,
42     int64_t fn_hidden_size,
43     int64_t fn_proj_size,
44     int64_t fn_num_layers,
45     bool batch_first,
46     bool fn_bidirectional) {
47   AT_ERROR("_cudnn_rnn_flatten_weight: ATen not compiled with cuDNN support");
48 }
49 
50 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
51     const Tensor& input_r,
52     TensorList weight,
53     int64_t weight_stride0,
54     const std::optional<Tensor>& weight_buf_r_opt,
55     const Tensor& hx,
56     const std::optional<Tensor>& cx_opt,
57     int64_t fn_mode,
58     int64_t fn_hidden_size,
59     int64_t fn_proj_size,
60     int64_t fn_num_layers,
61     bool batch_first,
62     double fn_dropout,
63     bool fn_train,
64     bool fn_bidirectional,
65     IntArrayRef fn_batch_sizes,
66     const std::optional<Tensor>& fn_dropout_state_opt) {
67   AT_ERROR("_cudnn_rnn: ATen not compiled with cuDNN support");
68 }
69 
70 std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
71     const Tensor& input,
72     TensorList weight,
73     int64_t weight_stride0,
74     const Tensor& weight_buf,
75     const Tensor& hx,
76     const std::optional<Tensor>& cx_opt,
77     const Tensor& output,
78     const std::optional<Tensor>& grad_output_r_opt,
79     const std::optional<Tensor>& grad_hy_r_opt,
80     const std::optional<Tensor>& grad_cy_r_opt,
81     int64_t mode,
82     int64_t hidden_size,
83     int64_t proj_size,
84     int64_t num_layers,
85     bool batch_first,
86     double dropout,
87     bool train,
88     bool bidirectional,
89     IntArrayRef batch_sizes,
90     const std::optional<Tensor>& dropout_state_opt,
91     const Tensor& reserve,
92     std::array<bool, 4> output_mask) {
93   AT_ERROR("_cudnn_rnn_backward: ATen not compiled with cuDNN support");
94 }
95 
96 Tensor _cudnn_init_dropout_state(
97     double dropout,
98     bool train,
99     int64_t dropout_seed,
100     std::optional<ScalarType> dtype,
101     std::optional<Layout> layout,
102     std::optional<Device> device,
103     std::optional<bool> pin_memory) {
104   // See [Note: hacky wrapper removal for TensorOptions]
105   TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
106       pin_memory);
107 
108   AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support");
109 }
110 
111 } // namespace native
112 } // namespace at
113 
114 #else // AT_CUDNN_ENABLED()
115 
116 #include <ATen/native/cudnn/RNNUtils.h>
117 
118 namespace at {
119 namespace native {
120 
121 namespace {
122 // DropoutDescriptor
123 
124 struct DropoutDescriptorParams {
125   bool train;
126   double dropout;
127   Tensor dropout_state;
128   DropoutDescriptorParams() = default;
129   void set(bool train_, double dropout_, Tensor dropout_state_) {
130     train = train_;
131     dropout = dropout_;
132     dropout_state = dropout_state_;
133   }
134   DropoutDescriptor descriptor(cudnnHandle_t handle) const {
135     auto dropout_p = train ? dropout : 0;
136     DropoutDescriptor dropout_desc;
137     if (dropout_p == 0) {
138       dropout_desc.set_no_dropout(handle);
139     } else {
140       dropout_desc.set(handle, dropout_p, dropout_state);
141     }
142     return dropout_desc;
143   }
144 };
145 
146 // RNNDescriptor
147 
148 struct RNNDescriptorParams {
149 #ifdef USE_CUDNN_RNN_V8_API
150   int64_t input_size;
151   bool packed;
152 #endif
153   int64_t hidden_size;
154   int64_t proj_size;
155   int64_t num_layers;
156   cudnnDirectionMode_t bidirectional;
157   cudnnRNNMode_t mode;
158   cudnnDataType_t datatype;
159   cudnnDataType_t input_datatype;
160   cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
161   cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
162 
163   int64_t num_directions() const {
164     return bidirectional ? 2 : 1;
165   }
166 
167   void set_mode(int64_t fn_mode) {
168     switch (fn_mode) {
169       case CUDNN_RNN_RELU:
170         mode = CUDNN_RNN_RELU;
171         break;
172       case CUDNN_RNN_TANH:
173         mode = CUDNN_RNN_TANH;
174         break;
175       case CUDNN_LSTM:
176         mode = CUDNN_LSTM;
177         break;
178       case CUDNN_GRU:
179         mode = CUDNN_GRU;
180         break;
181       default: {
182         std::ostringstream oss;
183         oss << "unrecognized cuDNN RNN mode " << fn_mode;
184         AT_ERROR(oss.str());
185       }
186     }
187   }
188 
189   void set_bidirectional(bool fn_bidirectional) {
190     bidirectional =
191         fn_bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
192   }
193 
194   void set_algo(cudnnRNNAlgo_t algo) {
195     this->algo = algo;
196   }
197 
198 #ifndef USE_CUDNN_RNN_V8_API
199   void set(
200       int64_t mode,
201       int64_t hidden_size,
202       int64_t proj_size,
203       int64_t num_layers,
204       bool bidirectional,
205       cudnnDataType_t datatype,
206       cudnnDataType_t input_datatype){
207 #else
208   void set(
209       int64_t mode,
210       int64_t input_size,
211       bool packed,
212       int64_t hidden_size,
213       int64_t proj_size,
214       int64_t num_layers,
215       bool bidirectional,
216       cudnnDataType_t datatype,
217       cudnnDataType_t input_datatype) {
218 #endif
219       this->set_mode(mode);
220 #ifdef USE_CUDNN_RNN_V8_API
221   this->input_size = input_size;
222   this->packed = packed;
223 #endif
224   this->hidden_size = hidden_size;
225   this->proj_size = proj_size;
226   this->num_layers = num_layers;
227   this->set_bidirectional(bidirectional);
228   this->datatype = datatype;
229   this->input_datatype = input_datatype;
230 }
231 
232 RNNDescriptor
233 descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
234   RNNDescriptor rnn_desc;
235 #ifndef USE_CUDNN_RNN_V8_API
236   rnn_desc.set(
237       handle,
238       hidden_size,
239       proj_size,
240       num_layers,
241       std::move(dropout_desc),
242       input_mode,
243       bidirectional,
244       mode,
245       datatype,
246       input_datatype,
247       algo,
248       at::globalContext().allowTF32CuDNN());
249 #else
250     rnn_desc.set(
251         handle,
252         input_size,
253         packed,
254         hidden_size,
255         proj_size,
256         num_layers,
257         std::move(dropout_desc),
258         input_mode,
259         bidirectional,
260         mode,
261         datatype,
262         input_datatype,
263         algo,
264         at::globalContext().allowTF32CuDNN());
265 #endif
266   return rnn_desc;
267 }
268 
269 // In some cases, a use of RNNDescriptor does not rely on the
270 // DropoutDescriptor.  In this case, we fake up a no-dropout
271 // descriptor to make the RNN descriptor initialization go through.
272 // This is used by _cudnn_rnn_flatten_weight, which needs an
273 // RNNDescriptor for get_parameters(), but does not actually need
274 // a fully initialized dropout descriptor.  This lets us avoid
275 // having to pass the dropout state to flatten, which has no business
276 // knowing what the dropout state is.
277 RNNDescriptor descriptor(cudnnHandle_t handle) const {
278   DropoutDescriptor dropout_desc;
279   dropout_desc.set_no_dropout(handle);
280   return descriptor(handle, std::move(dropout_desc));
281 }
282 };
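// Illustrative usage sketch (editor's illustration, not upstream code): the
// flatten path builds the descriptor without any dropout state, while the
// forward/backward paths thread the real dropout descriptor through:
//
//   RNNDescriptorParams rnn;
//   rnn.set(/*mode=*/CUDNN_LSTM, /*...remaining fields...*/);
//   auto flat_desc = rnn.descriptor(handle);  // fakes a no-dropout descriptor
//   auto full_desc = rnn.descriptor(handle, dropout.descriptor(handle));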
283 
284 // TensorDescriptor list
285 #ifndef USE_CUDNN_RNN_V8_API
286 std::vector<TensorDescriptor> rnn_descriptor_sequence(
287     const Tensor& tensor,
288     IntArrayRef batch_sizes) {
289   std::vector<TensorDescriptor> descriptors(batch_sizes.size());
290   size_t i = 0;
291   // To be mutated in the loop
292   auto batch_tensor_size = tensor.sizes().vec();
293   for (auto batch_size : batch_sizes) {
294     batch_tensor_size[0] = batch_size;
295     // NB: cuDNN RNN API does not support 2d descriptors, so we
296     // must pad it out to 3d.
297     descriptors[i].set(
298         getCudnnDataType(tensor), batch_tensor_size, tensor.strides(), 3);
299     i++;
300   }
301   return descriptors;
302 }
303 
304 std::vector<TensorDescriptor> rnn_descriptor(const Tensor& tensor, int64_t N) {
305   std::vector<TensorDescriptor> descriptors(N);
306   for (const auto i : c10::irange(N)) {
307     descriptors[i].set(tensor, 5);
308   }
309   return descriptors;
310 }
311 #else
312 auto rnn_descriptor_sequence(
313     const Tensor& tensor,
314     uint32_t batch_size,
315     const IntArrayRef batch_sizes,
316     uint32_t seq_len,
317     uint32_t vector_size) { // packed case
318   RNNDataDescriptor r;
319   std::vector<int> seqLengthArray(batch_size, 1);
320   // cuDNN wants the sequence lengths for a packed batch as if they
321   // were unpacked, e.g., for the
322   // Sequence 1: ABCD
323   // Sequence 2: EF
324   // Sequence 3: G
325   // case below, this would be [4, 2, 1] (has length == mini_batch)
326   // TODO(eqy): There's probably a smarter way to do this than O(SN)
327   for (auto it = batch_sizes.begin(); it != batch_sizes.end(); it++) {
328     // everyone starts at sequence length 1 so we skip an iteration
329     if (it == batch_sizes.begin()) {
330       continue;
331     }
332     for (const auto idx : c10::irange(*it)) {
333       seqLengthArray[idx]++;
334     }
335   }
336   r.set(
337       tensor,
338       CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
339       seq_len,
340       batch_size,
341       vector_size,
342       seqLengthArray.data());
343   return r;
344 }
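// Worked example (illustrative): for the packed batch sketched in the big
// comment below (Sequence 1: ABCD, Sequence 2: EF, Sequence 3: G),
// batch_sizes == [3, 2, 1, 1] and batch_size == 3.  Every entry starts at 1;
// the loop then increments the first two entries (for B, F), then the first
// one (for C), then the first one again (for D), yielding
// seqLengthArray == [4, 2, 1].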
345 
346 auto rnn_descriptor(
347     const Tensor& tensor,
348     uint32_t batch_size,
349     uint32_t seq_len,
350     uint32_t vector_size) {
351   RNNDataDescriptor r;
352   // NB: Looks like even if batch_first is true here we always want
353   // SEQ_MAJOR_UNPACKED, because the input appears to be transposed if it is
354   // batch-major
355   const auto layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
356   std::vector<int32_t> seqLengthArray(batch_size, seq_len);
357   r.set(
358       tensor, layout, seq_len, batch_size, vector_size, seqLengthArray.data());
359   return r;
360 }
361 #endif
362 
363 // The best way to understand the meaning of the values stored in
364 // this struct is to consider each of the possible ways our
365 // input can be structured.
366 //
367 // Suppose you want to run RNN on the following variable
368 // length inputs:
369 //
370 //    Sequence 1: ABCD
371 //    Sequence 2: EF
372 //    Sequence 3: G
373 //
374 // (Let _ be padding when we have non-packed representations.)
375 //
376 // # Packed input (batch_sizes is non-empty)
377 //
378 //  input_size
379 // +------+                    +
380 // | A    |                    |
381 // | E    | mini_batch =       |
382 // | G    | batch_sizes[0] = 3 |
383 // +------+                    |
384 // | B    |                    | batch_sizes_sum = 7
385 // | F    | batch_sizes[1] = 2 |
386 // +------+                    |
387 // | C    | batch_sizes[2] = 1 |
388 // +------+                    |
389 // | D    | batch_sizes[3] = 1 |
390 // +------+                    +
391 //
392 //              (seq_length = 4)
393 //
394 //    input.size() = batch_sizes_sum x input_size
395 //
396 // # Unpacked input (batch_first = false)
397 //
398 //  mini_batch = 3
399 // +-------+
400 // | A E G |
401 // | B F _ | seq_length = 4
402 // | C _ _ |
403 // | D _ _ |
404 // +-------+
405 //    ...    input_size
406 // +-------+
407 //
408 //    input.size() = seq_length x mini_batch x input_size
409 //
410 // # Unpacked input (batch_first = true)
411 //
412 //  seq_length = 4
413 // +---------+
414 // | A B C D |
415 // | E F _ _ | mini_batch = 3
416 // | G _ _ _ |
417 // +---------+
418 //     ...     input_size
419 // +---------+
420 //
421 //    input.size() = mini_batch x seq_length x input_size
422 //
423 struct TensorDescriptorListParams {
424   IntArrayRef batch_sizes;
425   int64_t seq_length;
426   int64_t mini_batch;
427   // NB: this is not input.size(), which is an IntArrayRef; instead, this is
428   // the size of the inner-most dimension.  In NLP applications, this is usually
429   // the size of the embedding.  You can also think of this as the size
430   // of the "channel" dimension (at risk of confusing vision researchers :)
431   int64_t input_size;
432   // Only valid when is_input_packed()
433   int64_t batch_sizes_sum; // == sum(batch_sizes)
434 
435   bool is_input_packed() const {
436     return batch_sizes.size() != 0;
437   }
438 
439   void set(
440       IntArrayRef input_sizes,
441       IntArrayRef batch_sizes_,
442       bool batch_first) {
443     batch_sizes = batch_sizes_;
444     if (is_input_packed()) {
445       seq_length = batch_sizes.size();
446       mini_batch = batch_sizes[0];
447       // NB: When input is packed, the mini_batch size is NOT the size
448       // of the outer dimension
449       batch_sizes_sum = input_sizes[0];
450       input_size = input_sizes[1];
451     } else {
452       if (batch_first) {
453         seq_length = input_sizes[1];
454         mini_batch = input_sizes[0];
455       } else {
456         seq_length = input_sizes[0];
457         mini_batch = input_sizes[1];
458       }
459       input_size = input_sizes[2];
460       // TODO: Actually, would this make ASAN's job harder catching
461       // an uninitialized access?
462       batch_sizes_sum = -1; // something bogus in case we access it
463     }
464   }
465 #ifndef USE_CUDNN_RNN_V8_API
466   // TODO: check x for consistency with input_size?
467   std::vector<TensorDescriptor> descriptors(Tensor x) const {
468     auto is_input_packed = batch_sizes.size() != 0;
469     if (is_input_packed) {
470       return rnn_descriptor_sequence(x, batch_sizes);
471     } else {
472       return rnn_descriptor(x[0], seq_length);
473     }
474   }
475 #else
476   auto descriptors(Tensor x) const {
477     auto is_input_packed = batch_sizes.size() != 0;
478     if (is_input_packed) {
479       return rnn_descriptor_sequence(
480           x, mini_batch, batch_sizes, seq_length, x.size(-1));
481     } else {
482       return rnn_descriptor(x, mini_batch, seq_length, x.size(-1));
483     }
484   }
485 #endif
486 };
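// Worked example (illustrative), using the sequences from the comment above:
//   packed:   set({7, input_size}, /*batch_sizes=*/{3, 2, 1, 1}, /*batch_first=*/false)
//             -> seq_length == 4, mini_batch == 3, batch_sizes_sum == 7
//   unpacked: set({4, 3, input_size}, /*batch_sizes=*/{}, /*batch_first=*/false)
//             -> seq_length == 4, mini_batch == 3, batch_sizes_sum == -1 (bogus, unused)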
487 
488 // Everything together
489 
490 struct RNNParams {
491   DropoutDescriptorParams dropout;
492   RNNDescriptorParams rnn;
493   TensorDescriptorListParams tensors;
494 };
495 
496 // NB: Doesn't include the weight descriptor
497 struct RNNDescriptors {
498   RNNDescriptor rnn_desc;
499   // NB: this won't actually lay out the tensor descriptor pointers
500   // in the right way, so you'll have to preprocess them
501 #ifndef USE_CUDNN_RNN_V8_API
502   std::vector<TensorDescriptor> x_descs;
503   std::vector<TensorDescriptor> y_descs;
504 #else
505   RNNDataDescriptor x_descs;
506   RNNDataDescriptor y_descs;
507 #endif
508   TensorDescriptor hx_desc;
509   TensorDescriptor hy_desc;
510   TensorDescriptor cx_desc;
511   TensorDescriptor cy_desc;
512 
513   RNNDescriptors(
514       const RNNParams& fn,
515       cudnnHandle_t handle,
516       Tensor x,
517       Tensor y,
518       Tensor hx,
519       Tensor cx) {
520     rnn_desc = fn.rnn.descriptor(handle, fn.dropout.descriptor(handle));
521     x_descs = fn.tensors.descriptors(x);
522     y_descs = fn.tensors.descriptors(y);
523     hx_desc.set(hx, 5);
524     hy_desc.set(hx, 5);
525     if (cx.defined()) {
526       cx_desc.set(cx, 5);
527       cy_desc.set(cx, 5);
528     }
529   }
530 
531   // TODO: This is annoying, having to put the cudnnTensorDescriptor_t
532   // in a contiguous array...
533   std::vector<cudnnTensorDescriptor_t> get_descs(
534       const std::vector<TensorDescriptor>& descs) {
535     std::vector<cudnnTensorDescriptor_t> r;
536     r.reserve(descs.size());
537     for (auto& desc : descs) {
538       r.emplace_back(desc.desc());
539     }
540     return r;
541   }
542 #ifndef USE_CUDNN_RNN_V8_API
543   std::vector<cudnnTensorDescriptor_t> get_x_descs() {
544     return get_descs(x_descs);
545   }
546 
547   std::vector<cudnnTensorDescriptor_t> get_y_descs() {
548     return get_descs(y_descs);
549   }
550 #endif
551 };
552 
553 int64_t get_num_weights(
554     cudnnHandle_t handle,
555     const RNNDescriptor& rnn_desc,
556 #ifndef USE_CUDNN_RNN_V8_API
557     const TensorDescriptor& x_desc,
558 #endif
559     cudnnDataType_t datatype) {
560   size_t weight_size;
561 #ifndef USE_CUDNN_RNN_V8_API
562   AT_CUDNN_CHECK(cudnnGetRNNParamsSize(
563       handle, rnn_desc.desc(), x_desc.desc(), &weight_size, datatype));
564 #else
565   AT_CUDNN_CHECK(
566       cudnnGetRNNWeightSpaceSize(handle, rnn_desc.desc(), &weight_size));
567 #endif
568   auto elem_size = dataSize(datatype);
569   TORCH_INTERNAL_ASSERT(
570       weight_size % elem_size == 0,
571       "cudnnGetRNNParamsSize returned nonsensical weight_size");
572   return weight_size / elem_size;
573 }
574 
575 int64_t _num_linear_layers(cudnnRNNMode_t mode) {
576   switch (mode) {
577     case CUDNN_LSTM:
578       return 8;
579     case CUDNN_GRU:
580       return 6;
581     case CUDNN_RNN_RELU:
582       return 2;
583     case CUDNN_RNN_TANH:
584       return 2;
585     default:
586       AT_ERROR("unknown cuDNN RNN mode ", mode);
587   }
588 }
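// Note (editor's gloss): these counts are cuDNN's per-layer "linear layers",
// i.e. one input-to-hidden and one hidden-to-hidden matrix per gate: LSTM has
// 4 gates (8 matrices), GRU has 3 (6), and plain RELU/TANH RNNs have a single
// transformation (2).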
589 
590 void add_projection_weights(
591     cudnnHandle_t handle,
592     const RNNDescriptor& rnn_desc,
593 #ifndef USE_CUDNN_RNN_V8_API
594     const TensorDescriptor& x_desc,
595     const FilterDescriptor& w_desc,
596 #endif
597     const Tensor& weight_buf,
598     int64_t layer,
599     std::vector<Tensor>& params) {
600   void* matrix_pointer = nullptr;
601   // assuming it's an LSTM, which has 8 "linear layers" (4 input-to-hidden and
602   // 4 hidden-to-hidden weight matrices, each with its own bias)
603   int64_t linear_id = 8;
604 #ifndef USE_CUDNN_RNN_V8_API
605   FilterDescriptor lin_layer_mat_desc;
606   AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams(
607       /*handle=*/handle,
608       /*rnnDesc=*/rnn_desc.desc(),
609       /*layer=*/layer,
610       /*xDesc=*/x_desc.desc(),
611       /*wDesc=*/w_desc.desc(),
612       /*w=*/weight_buf.data_ptr(),
613       /*linLayerID=*/linear_id,
614       /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
615       /*linLayerMat=*/&matrix_pointer));
616 #else
617   TensorDescriptor lin_layer_mat_desc;
618   AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
619       /*handle=*/handle,
620       /*rnnDesc=*/rnn_desc.desc(),
621       /*layer=*/layer,
622       /*wDesc=*/weight_buf.numel() * weight_buf.element_size(),
623       /*w=*/weight_buf.data_ptr(),
624       /*linLayerID=*/linear_id,
625       /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
626       /*linLayerMat=*/&matrix_pointer,
627       nullptr,
628       nullptr));
629 #endif
630 
631   cudnnDataType_t data_type;
632 #ifndef USE_CUDNN_RNN_V8_API
633   cudnnTensorFormat_t format;
634 #else
635   int stride_dim_a[5];
636 #endif
637   int nb_dims;
638   constexpr int min_dim = 3;
639   int filter_dim_a[min_dim];
640 #ifndef USE_CUDNN_RNN_V8_API
641   AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor(
642       lin_layer_mat_desc.desc(),
643       min_dim,
644       &data_type,
645       &format,
646       &nb_dims,
647       filter_dim_a));
648 #else
649   AT_CUDNN_CHECK(cudnnGetTensorNdDescriptor(
650       lin_layer_mat_desc.desc(),
651       min_dim,
652       &data_type,
653       &nb_dims,
654       filter_dim_a,
655       stride_dim_a));
656 #endif
657 
658   TORCH_INTERNAL_ASSERT(
659       nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim  = ", min_dim);
660   auto elem_size = dataSize(getCudnnDataType(weight_buf));
661   auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr();
662   TORCH_INTERNAL_ASSERT(
663       offset_bytes % elem_size == 0,
664       "offset_bytes = ",
665       offset_bytes,
666       "; elem_size = ",
667       elem_size);
668   size_t offset = offset_bytes / elem_size;
669 
670   int mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
671   // Generate a new parameter tensor which is a view into the weight_buf.
672   std::initializer_list<int64_t> size = {mat_numel, 1};
673   Tensor param = at::empty({0}, weight_buf.options())
674                      .set_(weight_buf.storage(), offset, size);
675   params.emplace_back(std::move(param));
676 }
677 
678 /*
679   Returns weight and bias tensors for each layer of the RNN. These tensors
680   are views on the underlying weight buffer allocated by CuDNN.
681 
682   Note: for LSTM and GRU, which have multiple parameters of each type (4 and 3,
683   respectively), these parameters are concatenated along the first dimension.
684         These parameters are returned in a consistent order by CuDNN:
685             (input, forget, cell, output) for LSTM
686             (reset, update, new) for GRU
687   Args:
688       fn: The RNN function object holding the RNN state
689       handle: a CuDNN handle
690       weight_buf: a 1D tensor containing the CuDNN-allocated weight (or
691           grad_weight) buffer
692   Returns:
693       parameters: [(weight_ih, weight_hh, bias_ih, bias_hh)*], with length equal to num_layers,
694           represented as a pair of a vector and an outer-dimension stride (NB: can't return MatrixRef because we need to allocate the underlying tensor)
695 */
696 std::pair<std::vector<Tensor>, size_t> // stride0
697 get_parameters(
698     cudnnHandle_t handle,
699     const RNNDescriptorParams& rnn,
700     const RNNDescriptor& rnn_desc,
701 #ifndef USE_CUDNN_RNN_V8_API
702     const TensorDescriptor& x_desc,
703     const FilterDescriptor& w_desc,
704 #endif
705     const Tensor& weight_buf,
706     bool include_bias = true) {
707 #ifndef USE_CUDNN_RNN_V8_API
708   auto cudnn_methods = {
709       cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams};
710 #else
711   auto cudnn_methods = {true, false};
712 #endif
713   std::vector<Tensor> params;
714   int64_t num_linear_layers = _num_linear_layers(rnn.mode);
715   int64_t num_layers = rnn.num_directions() * rnn.num_layers;
716   size_t cur_offset = 0;
717   size_t global_layer_params_count = 0;
718   for (const auto layer : c10::irange(num_layers)) {
719     size_t layer_params_count = 0;
720     for (auto cudnn_method : cudnn_methods) {
721       for (const auto linear_id : c10::irange(num_linear_layers)) {
722         void* matrix_pointer;
723 #ifndef USE_CUDNN_RNN_V8_API
724         FilterDescriptor lin_layer_mat_desc;
725         AT_CUDNN_CHECK(cudnn_method(
726             handle,
727             rnn_desc.desc(),
728             layer,
729             x_desc.desc(),
730             w_desc.desc(),
731             weight_buf.data_ptr(),
732             linear_id,
733             lin_layer_mat_desc.mut_desc(),
734             &matrix_pointer));
735 #else
736         TensorDescriptor lin_layer_mat_desc;
737         for (int stateless = 0; stateless < 100; stateless++) {
738           if (cudnn_method) { // matrix
739             AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
740                 handle,
741                 rnn_desc.desc(),
742                 layer,
743                 weight_buf.numel() * weight_buf.element_size(),
744                 weight_buf.data_ptr(),
745                 linear_id,
746                 lin_layer_mat_desc.mut_desc(),
747                 &matrix_pointer,
748                 nullptr,
749                 nullptr));
750           } else { // bias
751             AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
752                 handle,
753                 rnn_desc.desc(),
754                 layer,
755                 weight_buf.numel() * weight_buf.element_size(),
756                 weight_buf.data_ptr(),
757                 linear_id,
758                 nullptr,
759                 nullptr,
760                 lin_layer_mat_desc.mut_desc(),
761                 &matrix_pointer));
762           }
763         }
764 #endif
765         cudnnDataType_t data_type;
766 #ifndef USE_CUDNN_RNN_V8_API
767         cudnnTensorFormat_t format;
768 #else
769         int stride_dim_a[5];
770 #endif
771         int nb_dims;
772         constexpr int min_dim = 3;
773         int filter_dim_a[min_dim];
774 #ifndef USE_CUDNN_RNN_V8_API
775         AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor(
776             lin_layer_mat_desc.desc(),
777             min_dim,
778             &data_type,
779             &format,
780             &nb_dims,
781             filter_dim_a));
782 #else
783         AT_CUDNN_CHECK(cudnnGetTensorNdDescriptor(
784             lin_layer_mat_desc.desc(),
785             min_dim,
786             &data_type,
787             &nb_dims,
788             filter_dim_a,
789             stride_dim_a));
790 #endif
791 
792         TORCH_INTERNAL_ASSERT(
793             nb_dims <= min_dim,
794             "nb_dims = ",
795             nb_dims,
796             "; min_dim  = ",
797             min_dim);
798         auto elem_size = dataSize(getCudnnDataType(weight_buf));
799         auto offset_bytes =
800             (char*)matrix_pointer - (char*)weight_buf.data_ptr();
801         TORCH_INTERNAL_ASSERT(
802             offset_bytes % elem_size == 0,
803             "offset_bytes = ",
804             offset_bytes,
805             "; elem_size = ",
806             elem_size);
807         size_t offset = offset_bytes / elem_size;
808         // for all the RNN types provided by CUDNN, all the ih weights
809         // are the same size and are allocated in a contiguous chunk
810         // (same for the hh weights, and the ih and hh biases).
811         // Since we're storing all the weights in a single tensor anyway,
812         // might as well merge the CUDNN ones into a single tensor as well
813         int mat_numel =
814             c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
815         if (linear_id == 0 || linear_id == num_linear_layers / 2) {
816           // We could also exclude bias params by restricting cudnn_methods to
817           // just { cudnnGetRNNLinLayerMatrixParams } at the very top.  However,
818           // to do so would throw off the cur_offset accounting, which is currently
819           // a strict and informative check that all params are laid out the way
820           // we think they are.  If include_bias is false, I'd rather keep full
821           // cur_offset checks rather than save some CPU overhead by skipping
822           // the cudnn_method = cudnnGetRNNLinLayerBiasParams iteration.
823 #ifndef USE_CUDNN_RNN_V8_API
824           if (include_bias || cudnn_method != cudnnGetRNNLinLayerBiasParams) {
825 #else
826           if (include_bias || cudnn_method) {
827 #endif
828             // Generate a new parameter tensor which is a view into the
829             // weight_buf.
830             std::initializer_list<int64_t> size = {
831                 mat_numel * num_linear_layers / 2, 1};
832             Tensor param = at::empty({0}, weight_buf.options())
833                                .set_(weight_buf.storage(), offset, size);
834             params.emplace_back(std::move(param));
835             layer_params_count++;
836           }
837         } else {
838           TORCH_INTERNAL_ASSERT(
839               cur_offset == offset,
840               "cur_offset = ",
841               cur_offset,
842               "; offset = ",
843               offset);
844         }
845         cur_offset = offset + mat_numel;
846       }
847     } // for cudnn_method
848     if (rnn.proj_size != 0) {
849 #ifndef USE_CUDNN_RNN_V8_API
850       add_projection_weights(
851           handle, rnn_desc, x_desc, w_desc, weight_buf, layer, params);
852 #else
853       add_projection_weights(handle, rnn_desc, weight_buf, layer, params);
854 #endif
855       layer_params_count++;
856     }
857 
858     if (layer == 0) {
859       global_layer_params_count = layer_params_count;
860     } else {
861       TORCH_INTERNAL_ASSERT(
862           global_layer_params_count == layer_params_count,
863           "global_layer_params_count = ",
864           global_layer_params_count,
865           "; layer_params_count = ",
866           layer_params_count);
867     }
868   } // for layer
869   return std::make_pair(params, global_layer_params_count);
870 }
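// Shape of the result (illustrative): with include_bias == true and no
// projections, each (layer, direction) contributes
// {weight_ih, weight_hh, bias_ih, bias_hh}, so the returned stride0 is 4;
// with proj_size != 0 an extra weight_hr view is appended per layer, giving
// stride0 == 5 (matching the w_ih, w_hh, b_ih, b_hh, w_hr layout noted in
// _viewOrCopyParams below).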
871 
872 // This is a lightweight version of the method above used to quickly get the
873 // expected parameter offsets.
874 std::vector<void*> get_expected_data_ptrs(
875     const Tensor& weight_buf,
876     cudnnHandle_t handle,
877     const RNNDescriptorParams& rnn,
878     const RNNDescriptor& rnn_desc,
879     const TensorDescriptor& x_desc,
880     cudnnDataType_t datatype) {
881 #ifndef USE_CUDNN_RNN_V8_API
882   FilterDescriptor w_desc;
883   w_desc.set(weight_buf, 3);
884 #endif
885 
886   int64_t num_linear_layers = _num_linear_layers(rnn.mode);
887   int64_t num_dir_layers = rnn.num_directions() * rnn.num_layers;
888 #ifndef USE_CUDNN_RNN_V8_API
889   const auto cudnn_methods = {
890       cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams};
891 #else
892   const auto cudnn_methods = {true, false};
893 #endif
894   std::vector<void*> data_ptrs;
895   if (rnn.proj_size != 0) {
896     data_ptrs.reserve(num_dir_layers * (2 * 2 + 1));
897   } else {
898     data_ptrs.reserve(num_dir_layers * 2 * 2);
899   }
900   for (const auto layer : c10::irange(num_dir_layers)) {
901     for (auto cudnn_method : cudnn_methods) {
902       // This API returns a separate pointer for the weight of every gate,
903       // but we represent them as a single tensor, so we're only interested
904       // in a very limited subset of possible values.
905       const std::array<int64_t, 2> linear_offsets = {0, num_linear_layers / 2};
906       for (int64_t linear_id : linear_offsets) {
907         void* matrix_pointer;
908 #ifndef USE_CUDNN_RNN_V8_API
909         FilterDescriptor lin_layer_mat_desc;
910         AT_CUDNN_CHECK(cudnn_method(
911             handle,
912             rnn_desc.desc(),
913             layer,
914             x_desc.desc(),
915             w_desc.desc(),
916             weight_buf.data_ptr(),
917             linear_id,
918             lin_layer_mat_desc.mut_desc(),
919             &matrix_pointer));
920 #else
921         TensorDescriptor lin_layer_mat_desc;
922         if (cudnn_method) { // matrix
923           AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
924               handle,
925               rnn_desc.desc(),
926               layer,
927               weight_buf.numel() * weight_buf.element_size(),
928               weight_buf.data_ptr(),
929               linear_id,
930               lin_layer_mat_desc.mut_desc(),
931               &matrix_pointer,
932               nullptr,
933               nullptr));
934         } else { // bias
935           AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
936               handle,
937               rnn_desc.desc(),
938               layer,
939               weight_buf.numel() * weight_buf.element_size(),
940               weight_buf.data_ptr(),
941               linear_id,
942               nullptr,
943               nullptr,
944               lin_layer_mat_desc.mut_desc(),
945               &matrix_pointer));
946         }
947 #endif
948         data_ptrs.push_back(matrix_pointer);
949       }
950     }
951     if (rnn.proj_size != 0) {
952       // assuming it's an LSTM, which has 8 "linear layers" (4 input-to-hidden and
953       // 4 hidden-to-hidden weight matrices, each with its own bias)
954       int64_t linear_id = 8;
955       void* matrix_pointer;
956 #ifndef USE_CUDNN_RNN_V8_API
957       FilterDescriptor lin_layer_mat_desc;
958       AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams(
959           handle,
960           rnn_desc.desc(),
961           layer,
962           x_desc.desc(),
963           w_desc.desc(),
964           weight_buf.data_ptr(),
965           linear_id,
966           lin_layer_mat_desc.mut_desc(),
967           &matrix_pointer));
968 #else
969       TensorDescriptor lin_layer_mat_desc;
970 
971       AT_CUDNN_CHECK(cudnnGetRNNWeightParams(
972           handle,
973           rnn_desc.desc(),
974           layer,
975           weight_buf.numel() * weight_buf.element_size(),
976           weight_buf.data_ptr(),
977           linear_id,
978           lin_layer_mat_desc.mut_desc(),
979           &matrix_pointer,
980           nullptr,
981           nullptr));
982 #endif
983       data_ptrs.push_back(matrix_pointer);
984     }
985   }
986   return data_ptrs;
987 }
988 
989 void _viewOrCopyOneParam(
990     const Tensor& param_from,
991     const Tensor& param_to,
992     bool copy,
993     bool allow_type_change = false) {
994   // if copying, allow_type_change may be true or false.
995   // if viewing, allow_type_change must be false.
996   TORCH_INTERNAL_ASSERT(
997       copy || !allow_type_change, "if viewing, type change is not allowed.");
998   TORCH_INTERNAL_ASSERT(
999       allow_type_change || (param_from.scalar_type() == param_to.scalar_type()),
1000       "parameter types mismatch");
1001   if (copy) {
1002     param_to.copy_(param_from.view_as(param_to));
1003   } else {
1004     param_from.resize_as_(param_to);
1005   }
1006 }
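// Note (editor's gloss): despite the name, the "view" branch does not create a
// new view; it reshapes param_from in place (resize_as_) so its geometry
// matches param_to, whereas the "copy" branch copies param_from's data into
// param_to.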
1007 
1008 void _viewOrCopyParams(
1009     MatrixRef<Tensor> params_from,
1010     MatrixRef<Tensor> params_to,
1011     bool copy,
1012     bool allow_type_change = false) {
1013   TORCH_INTERNAL_ASSERT(
1014       params_from.size(0) == params_to.size(0), "number of layers mismatch");
1015   for (const auto i : c10::irange(params_from.size(0))) {
1016     auto layer_params_from = params_from[i];
1017     auto layer_params_to = params_to[i];
1018     // NOTE: these lists have all weights before all biases, so if the layer
1019     // doesn't use biases, iteration will terminate once layer_params_from ends
1020     // and ignore them.
1021 
1022     // NOTE: there is an exception from the above statement. If LSTMs with
1023     // projections are used, weights layout will be w_ih, w_hh, b_ih, b_hh,
1024     // w_hr. So we need to handle the no-bias case specially, as we must copy
1025     // 0->0, 1->1, 2->4. This case can be uniquely identified by checking if
1026     // number of defined parameters for each layer is 3.
1027     if (layer_params_from.size() == 3 && layer_params_to.size() != 3) {
1028       _viewOrCopyOneParam(
1029           layer_params_from[0], layer_params_to[0], copy, allow_type_change);
1030       _viewOrCopyOneParam(
1031           layer_params_from[1], layer_params_to[1], copy, allow_type_change);
1032       _viewOrCopyOneParam(
1033           layer_params_from[2], layer_params_to[4], copy, allow_type_change);
1034       continue;
1035     }
1036     if (layer_params_to.size() == 3 && layer_params_from.size() != 3) {
1037       _viewOrCopyOneParam(
1038           layer_params_from[0], layer_params_to[0], copy, allow_type_change);
1039       _viewOrCopyOneParam(
1040           layer_params_from[1], layer_params_to[1], copy, allow_type_change);
1041       _viewOrCopyOneParam(
1042           layer_params_from[4], layer_params_to[2], copy, allow_type_change);
1043       continue;
1044     }
1045     for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
1046          a != layer_params_from.end() && b != layer_params_to.end();
1047          ++a, ++b) {
1048       _viewOrCopyOneParam(*a, *b, copy, allow_type_change);
1049     }
1050   }
1051 }
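// Illustration of the special case above (LSTM with projections, no biases):
//   params_from[i] == {w_ih, w_hh, w_hr}              (3 entries)
//   params_to[i]   == {w_ih, w_hh, b_ih, b_hh, w_hr}  (5 entries)
// so slots 0 and 1 map directly and slot 2 maps to slot 4 (and vice versa in
// the symmetric branch).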
1052 
1053 void _copyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
1054   _viewOrCopyParams(params_from, params_to, true);
1055 }
1056 
1057 void _viewParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
1058   _viewOrCopyParams(params_from, params_to, false);
1059 }
1060 
1061 std::vector<int64_t> _input_size(const TensorDescriptorListParams& tensors) {
1062   if (tensors.is_input_packed()) {
1063     return {tensors.batch_sizes_sum, tensors.input_size};
1064   } else {
1065     return {tensors.seq_length, tensors.mini_batch, tensors.input_size};
1066   }
1067 }
1068 
1069 std::vector<int64_t> _hidden_size(
1070     const RNNDescriptorParams& rnn,
1071     const TensorDescriptorListParams& tensors) {
1072   if (rnn.proj_size != 0) {
1073     return {
1074         rnn.num_layers * rnn.num_directions(),
1075         tensors.mini_batch,
1076         rnn.proj_size};
1077   } else {
1078     return {
1079         rnn.num_layers * rnn.num_directions(),
1080         tensors.mini_batch,
1081         rnn.hidden_size};
1082   }
1083 }
1084 
1085 std::vector<int64_t> _cell_size(
1086     const RNNDescriptorParams& rnn,
1087     const TensorDescriptorListParams& tensors) {
1088   return {
1089       rnn.num_layers * rnn.num_directions(),
1090       tensors.mini_batch,
1091       rnn.hidden_size};
1092 }
1093 
1094 std::vector<int64_t> _output_size(
1095     const RNNDescriptorParams& rnn,
1096     const TensorDescriptorListParams& tensors) {
1097   auto out_size = rnn.hidden_size;
1098   if (rnn.proj_size != 0) {
1099     out_size = rnn.proj_size;
1100   }
1101   if (tensors.is_input_packed()) {
1102     return {tensors.batch_sizes_sum, out_size * rnn.num_directions()};
1103   } else {
1104     return {
1105         tensors.seq_length,
1106         tensors.mini_batch,
1107         out_size * rnn.num_directions()};
1108   }
1109 }
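// Note (editor's gloss): for LSTMs with projections the hidden and output
// sizes use proj_size (h_t is projected down), while the cell state always
// keeps hidden_size; that is why _cell_size exists separately from
// _hidden_size/_output_size.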
1110 
1111 inline bool use_persist_common_heuristics(
1112     const RNNDescriptorParams& rnn,
1113     const TensorDescriptorListParams& tensors) {
1114   return rnn.num_layers == 1 && rnn.hidden_size <= 1024 &&
1115       rnn.num_directions() == 1 && rnn.hidden_size % 128 == 0 &&
1116       tensors.input_size % 128 == 0;
1117 }
1118 
1119 inline bool use_persist_device_heuristics(
1120     const RNNDescriptorParams& rnn,
1121     const TensorDescriptorListParams& tensors) {
1122   auto bsize = tensors.mini_batch;
1123   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
1124   if (prop->major == 7) {
1125     if (prop->minor == 5) {
1126       // Excludes Turing from using persistent rnn.
1127       return false;
1128     } else {
1129       // technically, batch size should be a multiple of 8, but there are quite a
1130       // few multiple-of-8 batch sizes that give bad perf, so weed them out
1131       return ((bsize % 16 == 0 && bsize != 80 && bsize != 112) || bsize == 8) &&
1132           ((tensors.seq_length >= 40 && bsize <= 128) ||
1133            (tensors.seq_length >= 20 && bsize <= 96) ||
1134            (tensors.seq_length >= 10 && bsize <= 32));
1135     }
1136   } else if (prop->major >= 8 && prop->multiProcessorCount >= 98) {
1137     // SM count check excludes A30 (similar issue to A40)
1138     if (prop->minor == 6) {
1139       // Excludes sm_86 GPU devices from using persistent rnn.
1140       // This is because there are some edge cases that will throw exceptions
1141       // with cudnn 8.0.5 on Nvidia A40 GPU.
1142       return false;
1143     }
1144     // Based on tests by Vasily Volkov and xwang233.  Vasily only tried bsize <=
1145     // 128, so conservatively enable persistence for bsize <= 128 only.
1146     // TODO:  Run more tests for bsize > 128.
1147     if (rnn.mode == CUDNN_GRU) {
1148       // Persistent GRU performance is flakier than other RNN types.  Exclude
1149       // them for now.
1150       // TODO:  Write a more refined GRU heuristic.
1151       return false;
1152     } else if (rnn.mode == CUDNN_LSTM) {
1153       // Persistent LSTMs are comparable to or better than non-persistent for
1154       // bsize <= 128.
1155       return (bsize % 8 == 0) && (bsize <= 128);
1156     } else {
1157       // Persistent RNN_RELU and TANH show poor performance when bsize >= 96 AND
1158       // hidden size >= 896.
1159       return (bsize % 8 == 0) && (bsize <= 128) &&
1160           (bsize < 96 || rnn.hidden_size < 896);
1161     }
1162   } else {
1163     return false;
1164   }
1165 }
1166 
1167 inline bool use_rnn_persist_small_h(
1168     const RNNDescriptorParams& rnn,
1169     const TensorDescriptorListParams& tensors,
1170     bool forward) {
1171   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
1172   if (prop->major < 6)
1173     return false;
1174 
1175   if (forward) {
1176     if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) {
1177       return rnn.hidden_size <= 384;
1178     }
1179     if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) {
1180       return rnn.hidden_size <= 192;
1181     }
1182   } else /* backward */ {
1183     if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) {
1184       return rnn.hidden_size <= 256;
1185     }
1186     if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) {
1187       return rnn.hidden_size <= 128;
1188     }
1189   }
1190 
1191   return false;
1192 }
1193 
1194 cudnnRNNAlgo_t get_algo(
1195     const RNNDescriptorParams& rnn,
1196     const TensorDescriptorListParams& tensors,
1197     const Tensor input,
1198     bool forward) {
1199   // LSTM with projections only works with standard algorithm
1200   if (rnn.proj_size != 0) {
1201     return CUDNN_RNN_ALGO_STANDARD;
1202   }
1203 
1204   // Persistent algos typically don't work for packed inputs with sequence
1205   // lengths that vary across batch elements, and will return
1206   // CUDNN_STATUS_NOT_SUPPORTED if attempted. See
1207   // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#features-of-rnn-functions
1208   if (!tensors.is_input_packed()) {
1209     auto cudnnDataType = getCudnnDataType(input);
1210     if (cudnnDataType != CUDNN_DATA_DOUBLE) {
1211       if (use_rnn_persist_small_h(rnn, tensors, forward)) {
1212         return CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H;
1213       }
1214     }
1215     if (cudnnDataType == CUDNN_DATA_HALF) {
1216       if (use_persist_common_heuristics(rnn, tensors) &&
1217           use_persist_device_heuristics(rnn, tensors)) {
1218         return CUDNN_RNN_ALGO_PERSIST_STATIC;
1219       }
1220     }
1221   }
1222 
1223   return CUDNN_RNN_ALGO_STANDARD;
1224 }
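// Selection order above, in short: projections force CUDNN_RNN_ALGO_STANDARD;
// otherwise, for unpacked non-double inputs a small hidden size may pick
// CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H, half-precision inputs passing both
// persistence heuristics pick CUDNN_RNN_ALGO_PERSIST_STATIC, and everything
// else falls back to STANDARD.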
1225 
1226 cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
1227   if (dtype == CUDNN_DATA_HALF) {
1228     return CUDNN_DATA_FLOAT;
1229   }
1230   return dtype;
1231 }
1232 
1233 } // anonymous namespace
1234 
1235 // Utilities exposed in RNNUtils.h
1236 namespace cudnn_rnn {
1237 
1238 TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>>
1239 copy_weights_to_flat_buf_views(
1240     TensorList weight_arr,
1241     int64_t weight_stride0,
1242     int64_t input_size,
1243     int64_t mode,
1244     int64_t hidden_size,
1245     int64_t proj_size,
1246     int64_t num_layers,
1247     bool batch_first,
1248     bool bidirectional,
1249     const cudnnDataType_t flat_buf_datatype,
1250     const TensorOptions& flat_buf_options,
1251     bool set_orig_weights_to_flat_buf,
1252     bool allow_type_change /*=false*/,
1253     bool include_bias /*=true*/) {
1254   // flat_buf_datatype is accepted as a separate argument (rather than extracted
1255   // from flat_buf_options) because to extract flat_buf_datatype from
1256   // flat_buf_options, we'd need to say auto flat_buf_datatype =
1257   // getCudnnDataTypeFromScalarType(typeMetaToScalarType(options.dtype()));
1258   // typeMetaToScalarType is a surprisingly nontrivial function.  We should
1259   // avoid it if we can.
1260   TORCH_CHECK(
1261       weight_arr.size() > 0,
1262       "copy_weights_to_flat_buf_views: cannot flatten empty weight list");
1263 
1264   RNNDescriptorParams rnn;
1265   rnn.set(
1266       mode,
1267 #ifdef USE_CUDNN_RNN_V8_API
1268       input_size,
1269       false, // eqy: bogus as we do not know if the input is packed here
1270              // but it should not affect the weights (which is what we are
1271              // interested in)
1272 #endif
1273       hidden_size,
1274       proj_size,
1275       num_layers,
1276       bidirectional,
1277       promote_rnn_math_type(flat_buf_datatype),
1278       flat_buf_datatype);
1279 
1280   auto handle = getCudnnHandle();
1281   RNNDescriptor rnn_desc = rnn.descriptor(handle);
1282 
1283   TensorGeometry x_geom({1, input_size});
1284   TensorDescriptor x_desc;
1285   // Why do we pad to 5 dims here (and elsewhere)?
1286   // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNForwardTraining
1287   // expects descriptors padded to 3 dimensions.
1288   x_desc.set(flat_buf_datatype, x_geom.sizes(), x_geom.strides(), 5);
1289 
1290   auto num_weights =
1291 #ifndef USE_CUDNN_RNN_V8_API
1292       get_num_weights(handle, rnn_desc, x_desc, flat_buf_datatype);
1293 #else
1294       get_num_weights(handle, rnn_desc, flat_buf_datatype);
1295 #endif
1296   auto weight_buf = at::zeros(num_weights, flat_buf_options);
1297 
1298 #ifndef USE_CUDNN_RNN_V8_API
1299   FilterDescriptor w_desc;
1300   w_desc.set(weight_buf, 3);
1301 #endif
1302 
1303   // Slice off views into weight_buf
1304   auto [params_arr, params_stride0] = get_parameters(
1305 #ifndef USE_CUDNN_RNN_V8_API
1306       handle, rnn, rnn_desc, x_desc, w_desc, weight_buf, include_bias);
1307 #else
1308       handle, rnn, rnn_desc, weight_buf, include_bias);
1309 #endif
1310   MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
1311       params{params_arr, params_stride0};
1312 
1313   // Copy weights
1314   _viewOrCopyParams(weight, params, /*copy=*/true, allow_type_change);
1315   if (set_orig_weights_to_flat_buf) {
1316     // Update the storage
1317     for (const auto i : c10::irange(weight.size(0))) {
1318       // There is a special case for LSTM with projections and no bias,
1319       // where weight copy is done in 0->0, 1->1, 2->4 layout
1320       if (weight[i].size() == 3 && params[i].size() == 5) {
1321         weight[i][0].set_(params[i][0].view_as(weight[i][0]));
1322         weight[i][1].set_(params[i][1].view_as(weight[i][1]));
1323         weight[i][2].set_(params[i][4].view_as(weight[i][2]));
1324       } else {
1325         for (auto orig_param_it = weight[i].begin(),
1326                   new_param_it = params[i].begin();
1327              orig_param_it != weight[i].end() &&
1328              new_param_it != params[i].end();
1329              orig_param_it++, new_param_it++) {
1330           auto orig_param = *orig_param_it, new_param = *new_param_it;
1331           orig_param.set_(new_param.view_as(orig_param));
1332         }
1333       }
1334     }
1335   }
1336 
1337   return std::make_tuple(weight_buf, params_arr);
1338 }
1339 
1340 } // namespace cudnn_rnn
1341 
1342 using namespace cudnn_rnn;
1343 
1344 // NB: does inplace update into TensorList
1345 // It would be a relatively simple matter to refactor this into multiple
1346 // functions, only one of which does an inplace update, but we leave this
1347 // for future work
1348 Tensor _cudnn_rnn_flatten_weight(
1349     TensorList weight_arr,
1350     int64_t weight_stride0,
1351     int64_t input_size,
1352     int64_t fn_mode,
1353     int64_t fn_hidden_size,
1354     int64_t fn_proj_size,
1355     int64_t fn_num_layers,
1356     bool batch_first,
1357     bool fn_bidirectional) {
1358   // returns flat weight_buf
1359   return std::get<0>(copy_weights_to_flat_buf_views(
1360       weight_arr,
1361       weight_stride0,
1362       input_size,
1363       fn_mode,
1364       fn_hidden_size,
1365       fn_proj_size,
1366       fn_num_layers,
1367       batch_first,
1368       fn_bidirectional,
1369       /*flat_buf_datatype=*/getCudnnDataType(weight_arr[0]),
1370       /*flat_buf_options=*/weight_arr[0].options(),
1371       /*set_orig_weights_to_flat_buf=*/true));
1372 }
1373 
1374 const char* WEIGHT_FORMAT_WARN =
1375     "RNN module weights are not part of single contiguous "
1376     "chunk of memory. This means they need to be compacted "
1377     "at every call, possibly greatly increasing memory usage. "
1378     "To compact weights again call flatten_parameters().";
1379 
1380 // NB: when fn_batch_sizes is empty, that means no batch sizes were specified
1381 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
1382     const Tensor& input_r,
1383     TensorList weight,
1384     int64_t weight_stride0,
1385     const std::optional<Tensor>& weight_buf_r_opt,
1386     const Tensor& hx,
1387     const std::optional<Tensor>& cx_opt,
1388     int64_t fn_mode,
1389     int64_t fn_hidden_size,
1390     int64_t fn_proj_size,
1391     int64_t fn_num_layers,
1392     bool batch_first,
1393     double fn_dropout,
1394     bool fn_train,
1395     bool fn_bidirectional,
1396     IntArrayRef fn_batch_sizes,
1397     const std::optional<Tensor>& fn_dropout_state_opt) {
1398   // See [Note: hacky wrapper removal for optional tensor]
1399   c10::MaybeOwned<Tensor> weight_buf_r_maybe_owned =
1400       at::borrow_from_optional_tensor(weight_buf_r_opt);
1401   const Tensor& weight_buf_r = *weight_buf_r_maybe_owned;
1402   const Tensor& cx = c10::value_or_else(cx_opt, [] { return Tensor(); });
1403   const Tensor& fn_dropout_state =
1404       c10::value_or_else(fn_dropout_state_opt, [] { return Tensor(); });
1405 
1406   check_attributes(input_r, weight, {hx, cx}, /*check_dtype=*/true);
1407   auto input = input_r;
1408   auto weight_buf = weight_buf_r;
1409   if (!weight_buf.defined()) {
1410     TORCH_WARN(WEIGHT_FORMAT_WARN);
1411   }
1412   if (fn_dropout_state.defined()) {
1413     auto input_arg = TensorArg(input, "input", 1);
1414     auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15);
1415     checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
1416   }
1417   RNNParams fn;
1418   auto datatype = getCudnnDataType(input);
1419 #ifndef USE_CUDNN_RNN_V8_API
1420   fn.rnn.set(
1421       fn_mode,
1422       fn_hidden_size,
1423       fn_proj_size,
1424       fn_num_layers,
1425       fn_bidirectional,
1426       promote_rnn_math_type(datatype),
1427       datatype);
1428 #else
1429   auto input_size = input_r.size(-1);
1430   auto packed = fn_batch_sizes.size() != 0;
1431   fn.rnn.set(
1432       fn_mode,
1433       input_size,
1434       packed,
1435       fn_hidden_size,
1436       fn_proj_size,
1437       fn_num_layers,
1438       fn_bidirectional,
1439       promote_rnn_math_type(datatype),
1440       datatype);
1441 #endif
1442   fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
1443   fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
1444 
1445   // TODO: Set device to input
1446 
1447   if (fn.rnn.mode != CUDNN_LSTM) {
1448     TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
1449   }
1450 
1451   // TODO: can batch_first be a wrapper around this function?
1452   auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
1453   if (batch_first && !is_input_packed) {
1454     input = input.transpose(0, 1);
1455   }
1456 
1457   auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
1458   auto cell_size = _cell_size(fn.rnn, fn.tensors);
1459   auto output_size = _output_size(fn.rnn, fn.tensors);
1460 
1461   TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
1462   TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");
1463 
1464   auto x = input.contiguous();
1465   auto output = at::empty(output_size, input.options());
1466   auto hy = at::empty(hidden_size, hx.options());
1467   Tensor cy;
1468   if (cx.defined()) {
1469     cy = at::empty(cell_size, cx.options());
1470   } else {
1471     cy = at::empty(
1472         {0}, hx.options()); // NB: Not allowed to return undefined tensors
1473   }
1474   auto y = output;
1475 
1476   auto handle = getCudnnHandle();
1477   cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, true);
1478   fn.rnn.set_algo(algo);
1479   RNNDescriptors descs(fn, handle, x, y, hx, cx);
1480 
1481 #ifndef USE_CUDNN_RNN_V8_API
1482   FilterDescriptor w_desc;
1483 #endif
1484   if (!weight_buf.defined()) {
1485 #ifndef USE_CUDNN_RNN_V8_API
1486     auto num_weights =
1487         get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
1488 #else
1489     auto num_weights = get_num_weights(handle, descs.rnn_desc, datatype);
1490 #endif
1491     weight_buf = at::empty(num_weights, x.options());
1492 #ifndef USE_CUDNN_RNN_V8_API
1493     w_desc.set(weight_buf, 3);
1494 #endif
1495     weight_buf.zero_();
1496 #ifndef USE_CUDNN_RNN_V8_API
1497     auto [params, params_stride0] = get_parameters(
1498         handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf);
1499 #else
1500     auto [params, params_stride0] =
1501         get_parameters(handle, fn.rnn, descs.rnn_desc, weight_buf);
1502 #endif
1503     _copyParams(
1504         MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
1505         MatrixRef<Tensor>{params, params_stride0});
1506   } else {
1507 #ifndef USE_CUDNN_RNN_V8_API
1508     w_desc.set(weight_buf, 3);
1509 #endif
1510   }
1511 
1512   TORCH_CHECK(
1513       !cx.defined() || cx.sizes().equals(cell_size),
1514       "Expected cell size ",
1515       IntArrayRef{cell_size},
1516       ", got ",
1517       cx.sizes());
1518   size_t workspace_size;
1519 #ifndef USE_CUDNN_RNN_V8_API
1520   auto x_descs_arr = descs.get_x_descs();
1521   auto y_descs_arr = descs.get_y_descs();
1522 #else
1523   auto& x_descs_arr = descs.x_descs;
1524   auto& y_descs_arr = descs.y_descs;
1525 #endif
1526 #ifndef USE_CUDNN_RNN_V8_API
1527   AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
1528       handle,
1529       descs.rnn_desc.desc(),
1530       fn.tensors.seq_length,
1531       x_descs_arr.data(),
1532       &workspace_size));
1533 #endif
1534   Tensor workspace;
1535   Tensor reserve;
1536   // NB: Previously, the test was for fn.requires_grad, but we don't have
1537   // this information.  Use 'train' as a proxy.
1538   if (fn_train) {
1539     size_t reserve_size;
1540 #ifndef USE_CUDNN_RNN_V8_API
1541     AT_CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(
1542         handle,
1543         descs.rnn_desc.desc(),
1544         fn.tensors.seq_length,
1545         x_descs_arr.data(),
1546         &reserve_size));
1547 #else
1548     AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
1549         handle,
1550         descs.rnn_desc.desc(),
1551         CUDNN_FWD_MODE_TRAINING,
1552         x_descs_arr.desc(),
1553         &workspace_size,
1554         &reserve_size));
1555 #endif
1556     workspace = at::empty(workspace_size, input.options().dtype(kByte));
1557     reserve = at::empty(reserve_size, input.options().dtype(kByte));
1558 #ifndef USE_CUDNN_RNN_V8_API
1559     AT_CUDNN_CHECK(cudnnRNNForwardTraining(
1560         handle,
1561         descs.rnn_desc.desc(),
1562         fn.tensors.seq_length,
1563         x_descs_arr.data(),
1564         x.data_ptr(),
1565         descs.hx_desc.desc(),
1566         hx.data_ptr(),
1567         descs.cx_desc.desc(),
1568         cx.defined() ? cx.data_ptr() : nullptr,
1569         w_desc.desc(),
1570         weight_buf.data_ptr(),
1571         y_descs_arr.data(),
1572         y.data_ptr(),
1573         descs.hy_desc.desc(),
1574         hy.data_ptr(),
1575         descs.cy_desc.desc(),
1576         cy.defined() ? cy.data_ptr() : nullptr,
1577         workspace.data_ptr(),
1578         workspace.size(0),
1579         reserve.mutable_data_ptr(),
1580         reserve.size(0)));
1581 #else
1582     AT_CUDNN_CHECK(cudnnRNNForward(
1583         handle,
1584         descs.rnn_desc.desc(),
1585         CUDNN_FWD_MODE_TRAINING,
1586         nullptr,
1587         x_descs_arr.desc(),
1588         x.data_ptr(),
1589         y_descs_arr.desc(),
1590         y.data_ptr(),
1591         descs.hx_desc.desc(),
1592         hx.data_ptr(),
1593         hy.data_ptr(),
1594         descs.cx_desc.desc(),
1595         cx.defined() ? cx.data_ptr() : nullptr,
1596         cy.defined() ? cy.data_ptr() : nullptr,
1597         weight_buf.numel() * weight_buf.element_size(),
1598         weight_buf.data_ptr(),
1599         workspace.size(0),
1600         workspace.data_ptr(),
1601         reserve.size(0),
1602         reserve.mutable_data_ptr()));
1603 #endif
1604   } else { // inference
1605 #ifdef USE_CUDNN_RNN_V8_API
1606     AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
1607         handle,
1608         descs.rnn_desc.desc(),
1609         CUDNN_FWD_MODE_INFERENCE,
1610         x_descs_arr.desc(),
1611         &workspace_size,
1612         NULL));
1613 #endif
1614     workspace = at::empty(workspace_size, input.options().dtype(kByte));
1615     reserve = at::empty({0}, input.options().dtype(kByte));
1616 #ifndef USE_CUDNN_RNN_V8_API
1617     AT_CUDNN_CHECK(cudnnRNNForwardInference(
1618         handle,
1619         descs.rnn_desc.desc(),
1620         fn.tensors.seq_length,
1621         x_descs_arr.data(),
1622         x.data_ptr(),
1623         descs.hx_desc.desc(),
1624         hx.data_ptr(),
1625         descs.cx_desc.desc(),
1626         cx.defined() ? cx.data_ptr() : nullptr,
1627         w_desc.desc(),
1628         weight_buf.data_ptr(),
1629         y_descs_arr.data(),
1630         y.data_ptr(),
1631         descs.hy_desc.desc(),
1632         hy.data_ptr(),
1633         descs.cy_desc.desc(),
1634         cy.defined() ? cy.data_ptr() : nullptr,
1635         workspace.data_ptr(),
1636         workspace.size(0)));
1637 #else
1638     AT_CUDNN_CHECK(cudnnRNNForward(
1639         handle,
1640         descs.rnn_desc.desc(),
1641         CUDNN_FWD_MODE_INFERENCE,
1642         nullptr,
1643         x_descs_arr.desc(),
1644         x.data_ptr(),
1645         y_descs_arr.desc(),
1646         y.data_ptr(),
1647         descs.hx_desc.desc(),
1648         hx.data_ptr(),
1649         hy.data_ptr(),
1650         descs.cx_desc.desc(),
1651         cx.defined() ? cx.data_ptr() : nullptr,
1652         cy.defined() ? cy.data_ptr() : nullptr,
1653         weight_buf.numel() * weight_buf.element_size(),
1654         weight_buf.data_ptr(),
1655         workspace.size(0),
1656         workspace.data_ptr(),
1657         reserve.size(0),
1658         reserve.mutable_data_ptr()));
1659 #endif
1660   }
1661 
1662   if (batch_first && !is_input_packed) {
1663     output.transpose_(0, 1);
1664   }
1665 
1666   return std::make_tuple(output, hy, cy, reserve, weight_buf);
1667 }
1668 
1669 std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
1670     const Tensor& input_r,
1671     const Tensor& weight_buf,
1672     const Tensor& hx,
1673     const Tensor& cx,
1674     const Tensor& output_r,
1675     const Tensor& grad_output_r,
1676     const Tensor& grad_hy,
1677     const Tensor& grad_cy,
1678     int64_t fn_mode,
1679     int64_t fn_hidden_size,
1680     int64_t fn_proj_size,
1681     int64_t fn_num_layers,
1682     bool batch_first,
1683     double fn_dropout,
1684     bool fn_train,
1685     bool fn_bidirectional,
1686     IntArrayRef fn_batch_sizes,
1687     const Tensor& fn_dropout_state,
1688     const Tensor& fn_reserve,
1689     std::array<bool, 3> output_mask) {
1690   auto input = input_r;
1691   auto grad_output = grad_output_r;
1692   auto output = output_r;
1693 
1694   RNNParams fn;
1695   auto datatype = getCudnnDataType(input);
1696 #ifndef USE_CUDNN_RNN_V8_API
1697   fn.rnn.set(
1698       fn_mode,
1699       fn_hidden_size,
1700       fn_proj_size,
1701       fn_num_layers,
1702       fn_bidirectional,
1703       promote_rnn_math_type(datatype),
1704       datatype);
1705 #else
1706   auto cudnn_input_size = input_r.size(-1);
1707   auto packed = fn_batch_sizes.size() != 0;
1708   fn.rnn.set(
1709       fn_mode,
1710       cudnn_input_size,
1711       packed,
1712       fn_hidden_size,
1713       fn_proj_size,
1714       fn_num_layers,
1715       fn_bidirectional,
1716       promote_rnn_math_type(datatype),
1717       datatype);
1718 #endif
1719   fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
1720   fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
1721 
1722   // TODO: Set device to input
1723   auto handle = getCudnnHandle();
1724 
1725   if (fn.rnn.mode != CUDNN_LSTM) {
1726     TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
1727   }
1728 
1729   auto is_input_packed = fn_batch_sizes.size() != 0;
1730   if (batch_first && !is_input_packed) {
1731     input = input.transpose(0, 1);
1732     grad_output = grad_output.transpose(0, 1);
1733     output = output.transpose(0, 1);
1734   }
1735 
1736   auto input_size = _input_size(fn.tensors);
1737   auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
1738   auto cell_size = _cell_size(fn.rnn, fn.tensors);
1739   auto output_size = _output_size(fn.rnn, fn.tensors);
1740 
1741   TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
1742   TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");
1743 
1744   auto x = input.contiguous();
1745   auto dy = grad_output.contiguous();
1746   auto y = output;
1747   auto w = weight_buf;
1748   auto dx = at::empty(
1749       input.sizes(), input.options()); // TODO: more compact way of saying this
1750   auto dhy = grad_hy.contiguous().view(hidden_size);
1751   auto dcy =
1752       grad_cy.defined() ? grad_cy.contiguous().view(cell_size) : Tensor();
1753   auto dhx = at::empty(hidden_size, hx.options());
1754   TORCH_INTERNAL_ASSERT(
1755       cx.defined() || !output_mask[2],
1756       "illegally required grad of cx for non-LSTM RNN");
1757   auto dcx = cx.defined() ? at::empty(cell_size, cx.options()) : Tensor();
1758 
1759   TORCH_CHECK(
1760       fn_train, "cudnn RNN backward can only be called in training mode");
1761 
1762   TORCH_CHECK(
1763       input.sizes().equals(input_size),
1764       "Expected input size ",
1765       IntArrayRef{input_size},
1766       ", got ",
1767       input.sizes());
1768   TORCH_CHECK(
1769       output.sizes().equals(output_size),
1770       "Expected output size ",
1771       IntArrayRef{output_size},
1772       ", got ",
1773       output.sizes());
1774 
1775   TORCH_CHECK(
1776       !hx.defined() || hx.sizes().equals(hidden_size),
1777       "Expected hidden size ",
1778       IntArrayRef{hidden_size},
1779       ", got ",
1780       hx.sizes());
1781   TORCH_CHECK(
1782       !cx.defined() || cx.sizes().equals(cell_size),
1783       "Expected cell size ",
1784       IntArrayRef{cell_size},
1785       ", got ",
1786       cx.sizes());
1787   TORCH_CHECK(
1788       !dhy.defined() || dhy.sizes().equals(hidden_size),
1789       "Expected d_hidden size ",
1790       IntArrayRef{hidden_size},
1791       ", got ",
1792       dhy.sizes());
1793   TORCH_CHECK(
1794       !dcy.defined() || dcy.sizes().equals(cell_size),
1795       "Expected d_cell size ",
1796       IntArrayRef{cell_size},
1797       ", got ",
1798       dcy.sizes());
1799 
1800   TORCH_CHECK(
1801       dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
1802       "Gradients aren't CUDA tensors");
1803 
1804   cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false);
1805   fn.rnn.set_algo(algo);
1806   RNNDescriptors descs(fn, handle, x, y, hx, cx);
1807 
1808 #ifndef USE_CUDNN_RNN_V8_API
1809   FilterDescriptor w_desc;
1810   w_desc.set(weight_buf, 3);
1811 #endif
1812 
1813   size_t workspace_size;
1814 #ifndef USE_CUDNN_RNN_V8_API
1815   auto x_descs_arr = descs.get_x_descs();
1816   auto y_descs_arr = descs.get_y_descs();
1817   AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
1818       handle,
1819       descs.rnn_desc.desc(),
1820       fn.tensors.seq_length,
1821       x_descs_arr.data(),
1822       &workspace_size));
1823 #else
1824   auto& x_descs_arr = descs.x_descs;
1825   auto& y_descs_arr = descs.y_descs;
1826   AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
1827       handle,
1828       descs.rnn_desc.desc(),
1829       CUDNN_FWD_MODE_TRAINING,
1830       x_descs_arr.desc(),
1831       &workspace_size,
1832       NULL));
1833 #endif
1834   // TODO: put this in the correct device???
1835   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
1836 #ifndef USE_CUDNN_RNN_V8_API
1837   AT_CUDNN_CHECK(cudnnRNNBackwardData(
1838       handle,
1839       descs.rnn_desc.desc(),
1840       fn.tensors.seq_length,
1841       y_descs_arr.data(),
1842       y.data_ptr(),
1843       y_descs_arr.data(),
1844       dy.data_ptr(),
1845       descs.hy_desc.desc(),
1846       dhy.data_ptr(),
1847       descs.cy_desc.desc(),
1848       cx.defined() ? dcy.data_ptr() : nullptr,
1849       w_desc.desc(),
1850       w.data_ptr(),
1851       descs.hx_desc.desc(),
1852       hx.data_ptr(),
1853       descs.cx_desc.desc(),
1854       cx.defined() ? cx.data_ptr() : nullptr,
1855       x_descs_arr.data(),
1856       dx.data_ptr(),
1857       descs.hx_desc.desc(),
1858       dhx.data_ptr(),
1859       descs.cx_desc.desc(),
1860       cx.defined() ? dcx.data_ptr() : nullptr,
1861       workspace.data_ptr(),
1862       workspace.size(0),
1863       fn_reserve.data_ptr(),
1864       fn_reserve.size(0)));
1865 #else
1866   AT_CUDNN_CHECK(cudnnRNNBackwardData_v8(
1867       handle,
1868       descs.rnn_desc.desc(),
1869       nullptr,
1870       y_descs_arr.desc(),
1871       y.data_ptr(),
1872       dy.data_ptr(),
1873       x_descs_arr.desc(),
1874       dx.data_ptr(),
1875       descs.hx_desc.desc(),
1876       hx.data_ptr(),
1877       dhy.data_ptr(),
1878       dhx.data_ptr(),
1879       descs.cx_desc.desc(),
1880       cx.defined() ? cx.data_ptr() : nullptr,
1881       cx.defined() ? dcy.data_ptr() : nullptr,
1882       cx.defined() ? dcx.data_ptr() : nullptr,
1883       weight_buf.numel() * weight_buf.element_size(),
1884       weight_buf.data_ptr(),
1885       workspace.size(0),
1886       workspace.data_ptr(),
1887       fn_reserve.size(0),
1888       fn_reserve.data_ptr()));
1889 #endif
1890   if (batch_first && !is_input_packed) {
1891     dx = dx.transpose_(0, 1);
1892   }
1893 
1894   return std::make_tuple(dx, dhx, dcx);
1895 }
1896 
1897 // NB: This MUST BE CALLED AFTER _cudnn_rnn_backward_input.
1898 // We'll give a user-friendly combined function...
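// A sketch of the required ordering (it mirrors the combined dispatcher
// _cudnn_rnn_backward further below; the names are the functions in this file):
//
//   auto [dx, dhx, dcx] = _cudnn_rnn_backward_input(...);  // runs first and
//                                                          // mutates reserve
//   auto dw = _cudnn_rnn_backward_weight(...);             // only afterwards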
1899 std::vector<Tensor> _cudnn_rnn_backward_weight(
1900     // TODO: I think tensor geometry sufficient for weight_buf/weight
1901     const Tensor& input_r,
1902     TensorList weight_arr,
1903     int64_t weight_stride0,
1904     const Tensor& weight_buf,
1905     const Tensor& hx,
1906     const Tensor& cx,
1907     const Tensor& output_r,
1908     int64_t fn_mode,
1909     int64_t fn_hidden_size,
1910     int64_t fn_proj_size,
1911     int64_t fn_num_layers,
1912     bool batch_first,
1913     double fn_dropout,
1914     bool fn_train,
1915     bool fn_bidirectional,
1916     IntArrayRef fn_batch_sizes,
1917     const Tensor& fn_dropout_state,
1918     const Tensor& fn_reserve) {
1919   MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)};
1920   auto input = input_r;
1921   auto output = output_r;
1922 
1923   RNNParams fn;
1924   auto datatype = getCudnnDataType(input);
1925 #ifndef USE_CUDNN_RNN_V8_API
1926   fn.rnn.set(
1927       fn_mode,
1928       fn_hidden_size,
1929       fn_proj_size,
1930       fn_num_layers,
1931       fn_bidirectional,
1932       promote_rnn_math_type(datatype),
1933       datatype);
1934 #else
1935   auto cudnn_input_size = input_r.size(-1);
1936   auto packed = fn_batch_sizes.size() != 0;
1937   fn.rnn.set(
1938       fn_mode,
1939       cudnn_input_size,
1940       packed,
1941       fn_hidden_size,
1942       fn_proj_size,
1943       fn_num_layers,
1944       fn_bidirectional,
1945       promote_rnn_math_type(datatype),
1946       datatype);
1947 #endif
1948   fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
1949   fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
1950 
1951   auto handle = getCudnnHandle();
1952 
1953   if (fn.rnn.mode != CUDNN_LSTM) {
1954     TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
1955   }
1956 
1957   auto is_input_packed = fn_batch_sizes.size() != 0;
1958   if (batch_first && !is_input_packed) {
1959     input = input.transpose(0, 1);
1960     output = output.transpose(0, 1);
1961   }
1962 
1963   auto input_size = _input_size(fn.tensors);
1964   auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
1965 
1966   TORCH_CHECK(
1967       fn_train, "cudnn RNN backward can only be called in training mode");
1968 
1969   TORCH_CHECK(
1970       input.sizes().equals(input_size),
1971       "Expected input size ",
1972       IntArrayRef{input_size},
1973       ", got ",
1974       input.sizes());
1975   TORCH_CHECK(
1976       !hx.defined() || hx.sizes().equals(hidden_size),
1977       "Expected hidden size ",
1978       IntArrayRef{hidden_size},
1979       ", got ",
1980       hx.sizes());
1981 
1982   // TODO: the above were the only checks in rnn.py, but it doesn't seem
1983   // like these checks are enough
1984 
1985   TORCH_CHECK(hx.is_contiguous(), "rnn: hx is not contiguous");
1986   TORCH_CHECK(!cx.defined() || cx.is_contiguous(), "rnn: cx is not contiguous");
1987 
1988   auto x = input.contiguous();
1989   const auto& y = output;
1990   auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());
1991 
1992   cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false);
1993   fn.rnn.set_algo(algo);
1994   RNNDescriptors descs(fn, handle, x, y, hx, cx);
1995 
1996 #ifndef USE_CUDNN_RNN_V8_API
1997   FilterDescriptor w_desc;
1998   w_desc.set(weight_buf, 3);
1999 #endif
2000 
2001   size_t workspace_size;
2002 #ifndef USE_CUDNN_RNN_V8_API
2003   auto x_descs_arr = descs.get_x_descs();
2004   auto y_descs_arr = descs.get_y_descs();
2005   AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
2006       handle,
2007       descs.rnn_desc.desc(),
2008       fn.tensors.seq_length,
2009       x_descs_arr.data(),
2010       &workspace_size));
2011 #else
2012   auto& x_descs_arr = descs.x_descs;
2013   auto& y_descs_arr = descs.y_descs;
2014   AT_CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(
2015       handle,
2016       descs.rnn_desc.desc(),
2017       CUDNN_FWD_MODE_TRAINING,
2018       x_descs_arr.desc(),
2019       &workspace_size,
2020       NULL));
2021 #endif
2022   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
2023 #ifndef USE_CUDNN_RNN_V8_API
2024   AT_CUDNN_CHECK(cudnnRNNBackwardWeights(
2025       handle,
2026       descs.rnn_desc.desc(),
2027       fn.tensors.seq_length,
2028       x_descs_arr.data(),
2029       x.data_ptr(),
2030       descs.hx_desc.desc(),
2031       hx.data_ptr(),
2032       y_descs_arr.data(),
2033       y.data_ptr(),
2034       workspace.data_ptr(),
2035       workspace.size(0),
2036       w_desc.desc(),
2037       dw.data_ptr(),
2038       fn_reserve.data_ptr(),
2039       fn_reserve.size(0)));
2040 #else
2041   AT_CUDNN_CHECK(cudnnRNNBackwardWeights_v8(
2042       handle,
2043       descs.rnn_desc.desc(),
2044       CUDNN_WGRAD_MODE_ADD,
2045       nullptr,
2046       x_descs_arr.desc(),
2047       x.data_ptr(),
2048       descs.hx_desc.desc(),
2049       hx.data_ptr(),
2050       y_descs_arr.desc(),
2051       y.data_ptr(),
2052       weight_buf.numel() * weight_buf.element_size(),
2053       dw.data_ptr(),
2054       workspace.size(0),
2055       workspace.data_ptr(),
2056       fn_reserve.size(0),
2057       fn_reserve.data_ptr()));
2058 #endif
2059 
2060 #ifndef USE_CUDNN_RNN_V8_API
2061   auto [grad_params_arr, grad_params_stride0] = get_parameters(
2062       handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, dw);
2063 #else
2064   auto [grad_params_arr, grad_params_stride0] =
2065       get_parameters(handle, fn.rnn, descs.rnn_desc, dw);
2066 #endif
2067   if (grad_params_stride0 == static_cast<size_t>(weight_stride0)) {
2068     _viewParams(
2069         MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
2070         MatrixRef<Tensor>{weight_arr, static_cast<size_t>(weight_stride0)});
2071     return grad_params_arr;
2072   } else {
2073     std::vector<Tensor> grad_weight_arr;
2074     grad_weight_arr.reserve(weight.numel());
2075     for (const auto& w : weight_arr) {
2076       grad_weight_arr.emplace_back(at::empty(w.sizes(), w.options()));
2077     }
2078     _copyParams(
2079         MatrixRef<Tensor>{grad_params_arr, grad_params_stride0},
2080         MatrixRef<Tensor>{
2081             grad_weight_arr, static_cast<size_t>(weight_stride0)});
2082     return grad_weight_arr;
2083   }
2084 }
2085 
2086 // We need this dispatcher because _cudnn_rnn_backward_weight has a stringent
2087 // ordering requirement with _cudnn_rnn_backward_input
2088 std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
2089     const Tensor& input,
2090     TensorList weight,
2091     int64_t weight_stride0,
2092     const Tensor& weight_buf,
2093     const Tensor& hx,
2094     const std::optional<Tensor>& cx_opt,
2095     const Tensor& output,
2096     const std::optional<Tensor>& grad_output_r_opt,
2097     const std::optional<Tensor>& grad_hy_r_opt,
2098     const std::optional<Tensor>& grad_cy_r_opt,
2099     int64_t mode,
2100     int64_t hidden_size,
2101     int64_t proj_size,
2102     int64_t num_layers,
2103     bool batch_first,
2104     double dropout,
2105     bool train,
2106     bool bidirectional,
2107     IntArrayRef batch_sizes,
2108     const std::optional<Tensor>& dropout_state_opt,
2109     const Tensor& reserve,
2110     std::array<bool, 4> output_mask) {
2111   // See [Note: hacky wrapper removal for optional tensor]
2112   c10::MaybeOwned<Tensor> cx_maybe_owned =
2113       at::borrow_from_optional_tensor(cx_opt);
2114   const Tensor& cx = *cx_maybe_owned;
2115   const Tensor& grad_output_r =
2116       c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
2117   const Tensor& grad_hy_r =
2118       c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
2119   const Tensor& grad_cy_r =
2120       c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
2121   const Tensor& dropout_state =
2122       c10::value_or_else(dropout_state_opt, [] { return Tensor(); });
2123 
2124   if (!grad_output_r.defined() && !grad_hy_r.defined() &&
2125       !grad_cy_r.defined()) {
2126     return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>(
2127         Tensor(), Tensor(), Tensor(), std::vector<Tensor>(weight.size()));
2128   }
2129 
2130   auto grad_output = grad_output_r.defined()
2131       ? grad_output_r
2132       : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
2133   auto grad_hy = grad_hy_r.defined()
2134       ? grad_hy_r
2135       : at::zeros_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
2136   auto grad_cy = cx.defined()
2137       ? (grad_cy_r.defined()
2138              ? grad_cy_r
2139              : at::zeros_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
2140       : grad_cy_r;
2141 
2142   // NB: unconditionally compute this gradient, because it mutates reserve
2143   auto [dx, dhx, dcx] = at::native::_cudnn_rnn_backward_input(
2144       input,
2145       weight_buf,
2146       hx,
2147       cx,
2148       output,
2149       grad_output,
2150       grad_hy,
2151       grad_cy,
2152       mode,
2153       hidden_size,
2154       proj_size,
2155       num_layers,
2156       batch_first,
2157       dropout,
2158       train,
2159       bidirectional,
2160       batch_sizes,
2161       dropout_state,
2162       reserve,
2163       {output_mask[0], output_mask[1], output_mask[2]});
2164   std::vector<Tensor> dw;
2165   if (output_mask[3]) {
2166     dw = at::native::_cudnn_rnn_backward_weight(
2167         input,
2168         weight,
2169         weight_stride0,
2170         weight_buf,
2171         hx,
2172         cx,
2173         output,
2174         mode,
2175         hidden_size,
2176         proj_size,
2177         num_layers,
2178         batch_first,
2179         dropout,
2180         train,
2181         bidirectional,
2182         batch_sizes,
2183         dropout_state,
2184         reserve);
2185   }
2186   return std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>>{
2187       dx, dhx, dcx, dw};
2188 }
2189 
2190 // TODO: I am not sure if we actually need the 'dropout' and 'train' parameters
2191 // to initialize just the state tensor
2192 //
2193 // NB: You can have any color you like, as long as it's a CUDA byte
2194 // tensor.  Why does this function take a TensorOptions at all in that case?
2195 // This is a factory function: it produces tensors but takes no tensors
2196 // as input.  The codegen currently assumes that ALL factory functions
2197 // take TensorOptions, so it's just a lot easier for this function to
2198 // be bound if it also does it.
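// For reference, get_dropout_state() further below reaches this function
// through the dispatcher roughly as follows (a sketch; 'seed' and 'options'
// are whatever the caller has in scope):
//
//   auto state = at::_cudnn_init_dropout_state(
//       dropout_p, train, seed, options.dtype(at::kByte));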
2199 Tensor _cudnn_init_dropout_state(
2200     double dropout,
2201     bool train,
2202     int64_t dropout_seed,
2203     std::optional<ScalarType> dtype,
2204     std::optional<Layout> layout,
2205     std::optional<Device> device,
2206     std::optional<bool> pin_memory) {
2207   // See [Note: hacky wrapper removal for TensorOptions]
2208   TensorOptions options =
2209       TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
2210           pin_memory);
2211 
2212   auto handle = getCudnnHandle();
2213   DropoutDescriptor dropout_desc;
2214   auto dropout_p = train ? dropout : 0;
2215   dropout_desc.initialize_rng(handle, dropout_p, dropout_seed, options);
2216   return dropout_desc.state;
2217 }
2218 
2219 ////////////////////////////////////////////////////////////////////////////////
2220 // CUDA dispatch for the generic RNN ops (at::lstm, at::gru, ...)
2221 ////////////////////////////////////////////////////////////////////////////////
2222 
2223 namespace {
2224 
2225 // Helpers for working with different hidden types.
2226 std::tuple<Tensor, Tensor> unpack_hidden(const Tensor& hidden) {
2227   return std::make_tuple(hidden, at::Tensor{});
2228 }
2229 
2230 std::tuple<Tensor, Tensor> unpack_hidden(
2231     const std::tuple<Tensor, Tensor>& hidden) {
2232   return hidden;
2233 }
2234 
2235 template <typename hidden_type>
2236 hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
2237   static_assert(
2238       false && sizeof(hidden_type),
2239       "pack_hidden not implemented for this type");
2240 }
2241 
2242 template <>
2243 Tensor pack_hidden<Tensor>(const Tensor& hx, const Tensor& cx) {
2244   AT_ASSERT(cx.numel() == 0);
2245   return hx;
2246 }
2247 
2248 template <>
2249 std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(
2250     const Tensor& hx,
2251     const Tensor& cx) {
2252   return std::make_tuple(hx, cx);
2253 }
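// The unpack_hidden/pack_hidden overloads above let _cudnn_impl below be
// written once over a generic hidden_type: a plain Tensor for GRU/RNN
// (cx stays undefined) and a (hx, cx) tuple for LSTM.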
2254 
2255 /**
2256  * Note [DropoutState and CUDA graph capture]
2257  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2258  * (1) Telling a capturing stream to wait on an event recorded in a
2259  *     non-capturing stream is an error.
2260  * (2) Telling a non-capturing stream to wait on an event recorded during
2261  *     capture is also an error.
2262  *
2263  * So DropoutState's usage syncs could error if an RNN with dropout is called
2264  * in an uncaptured region
2265  * then called in a captured region (triggering 1), or called in a captured
2266  * region then called in an uncaptured region (triggering 2).
2267  *
2268  * To prevent 1 and 2, lock() only syncs on the last usage event if it was
2269  * recorded in the same
2270  * capture state as the current state (which also means the same graph, if
2271  * capture is in progress).
2272  *
2273  * The solution should be safe as long as capture obeys the following
2274  * restrictions:
2275  *  - Only one capture may be underway at a time in a given process.
2276  *  - While a capture is underway, no calls to eager ops on noncapturing streams
2277  *    (on any thread)
2278  *    may interleave with the captured ops.
2279  *
2280  * TODO: As people experiment with capture, keep an eye out for use cases that
2281  *       might need to
2282  * relax those restrictions.
2283  *
2284  * See https://github.com/pytorch/pytorch/pull/56433 for more discussion.
2285  */
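// A minimal usage sketch (this mirrors how _cudnn_impl below consumes the
// dropout state; DropoutState itself is defined right after this note):
//
//   auto& dropout_state = get_dropout_state(dropout_p, train, input.options());
//   std::unique_lock<DropoutState> lock{dropout_state};  // calls lock()
//   // ... launch the cuDNN RNN using dropout_state.buffer ...
//   // ~unique_lock calls unlock(), which records the completion event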
2286 
2287 struct DropoutState {
2288   // Both buffer and event are lazily instantiated when a dropout state is
2289   // needed for the first time. Note that in this case needed != used, as we
2290   // don't need a buffer to e.g. run RNNs in test mode.
2291   at::Tensor buffer;
2292   std::optional<cuda::CUDAEvent> event;
2293   std::mutex mutex;
2294 #if !defined(USE_ROCM)
2295   // cudaStreamGetCaptureInfo will never give back a capture id of 0, so 0 can
2296   // serve as a sentinel value that capture was not underway.
2297   cuda::CaptureId_t capture_id_last_lock = 0;
2298   cuda::CaptureId_t capture_id_last_unlock = 0;
2299 #endif
2300 
2301   // Every time we use a dropout state, we need to synchronize with its event,
2302   // to make sure all previous uses finish running before this one starts. Once
2303   // we're done, we record the event to allow others to synchronize with this
2304   // kernel. Those events are really needed only for inter-stream sync on a
2305   // single GPU. I doubt anyone will want to run cuDNN RNNs in parallel on a
2306   // single GPU, so they should end up being complete no-ops.
2307   void lock() {
2308     // NB: We can't ignore the lock even when event is undefined, because
2309     // someone could then define it before we get to unlock().
2310     mutex.lock();
2311     if (event) {
2312 #if !defined(USE_ROCM)
2313       // See Note [DropoutState and CUDA graph capture]
2314       cudaStreamCaptureStatus status;
2315       AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
2316           cuda::getCurrentCUDAStream(), &status, &capture_id_last_lock));
2317       if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
2318         capture_id_last_lock = 0;
2319       }
2320       if (capture_id_last_lock == capture_id_last_unlock) {
2321         event->block(cuda::getCurrentCUDAStream());
2322       }
2323 #else
2324       event->block(cuda::getCurrentCUDAStream());
2325 #endif
2326     }
2327   }
2328 
2329   void unlock() {
2330     if (event) {
2331       event->record();
2332 #if !defined(USE_ROCM)
2333       // See Note [DropoutState and CUDA graph capture]
2334       cudaStreamCaptureStatus status;
2335       AT_CUDA_CHECK(cudaStreamGetCaptureInfo(
2336           cuda::getCurrentCUDAStream(), &status, &capture_id_last_unlock));
2337       if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
2338         capture_id_last_unlock = 0;
2339       }
2340       TORCH_INTERNAL_ASSERT(capture_id_last_unlock == capture_id_last_lock);
2341 #endif
2342     }
2343     mutex.unlock();
2344   }
2345 };
2346 
2347 DropoutState& get_dropout_state(
2348     double dropout_p,
2349     bool train,
2350     TensorOptions options) {
2351   // Each state is slightly over 2MB and initialized lazily, so it's fine to
2352   // cache them.
2353   static std::vector<DropoutState> dropout_state_cache{
2354       static_cast<size_t>(cuda::getNumGPUs())};
2355   static std::mutex state_cache_mut;
2356 
2357   AT_ASSERT(options.device().is_cuda());
2358   auto device = options.device().index();
2359 
2360   std::unique_lock<std::mutex> lock{state_cache_mut};
2361   auto& state = dropout_state_cache.at(device);
2362   if (train && dropout_p > 0) {
2363     const auto& gen =
2364         at::detail::getCUDAHooks().getDefaultCUDAGenerator(device);
2365     auto gen_impl = gen.get<at::CUDAGeneratorImpl>();
2366     bool reset_rnn_state = gen_impl->reset_rnn_state();
2367     if (!state.buffer.defined() || reset_rnn_state) {
2368       std::unique_lock<std::mutex> lock{state.mutex};
2369       int64_t seed =
2370           at::empty({}, options.dtype(at::kLong)).random_(gen).item<int64_t>();
2371       state.buffer = at::_cudnn_init_dropout_state(
2372           dropout_p, train, seed, options.dtype(at::kByte));
2373       // NB: CUDA binds the event to a device at creation time, so we can
2374       // initialize it only now, when we know we're on the correct device.
2375       if (!state.event.has_value()) {
2376         state.event.emplace();
2377       }
2378     }
2379   }
2380   return state;
2381 }
2382 
2383 Tensor try_get_weight_buf(
2384     const Tensor& input,
2385     TensorList parameters,
2386     bool has_biases,
2387     cudnnRNNMode_t mode,
2388     c10::SymInt hidden_size,
2389     c10::SymInt proj_size,
2390     int64_t num_layers,
2391     bool bidirectional) {
2392   // Prepare all relevant descriptors
2393   auto handle = getCudnnHandle();
2394   auto& any_param = parameters.at(0);
2395   auto datatype = getCudnnDataType(any_param);
2396 
2397   // Something very naughty is happening here.  try_get_weight_buf
2398   // is called from _cudnn_impl, which is a *composite*.  In other words,
2399   // inside the composite function we need to query cudnn to figure out how big
2400   // the weight buf actually is going to be.  This clearly cannot be done
2401   // symbolically.  For now, we insert guards here; but once we have the black
2402   // box handling for dynamic shapes, we could also hypothetically infer out
2403   // the relationships
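  // (The guard_int() calls below are where those guards are inserted: each
  // SymInt size is specialized to a concrete integer so cuDNN can be queried
  // eagerly.)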
2404   RNNDescriptorParams rnn;
2405 #ifndef USE_CUDNN_RNN_V8_API
2406   rnn.set(
2407       mode,
2408       hidden_size.guard_int(__FILE__, __LINE__),
2409       proj_size.guard_int(__FILE__, __LINE__),
2410       num_layers,
2411       bidirectional,
2412       promote_rnn_math_type(datatype),
2413       datatype);
2414 #else
2415   auto cudnn_input_size = input.size(-1);
2416   auto packed = false; // eqy: bogus, as we do not know if the input is packed
2417                        // here; again, it should not affect the weights
2418   rnn.set(
2419       mode,
2420       cudnn_input_size,
2421       packed,
2422       hidden_size.guard_int(__FILE__, __LINE__),
2423       proj_size.guard_int(__FILE__, __LINE__),
2424       num_layers,
2425       bidirectional,
2426       promote_rnn_math_type(datatype),
2427       datatype);
2428 #endif
2429   RNNDescriptor rnn_desc = rnn.descriptor(handle);
2430 
2431   TensorGeometry x_geom({1, input.sym_size(-1).guard_int(__FILE__, __LINE__)});
2432   TensorDescriptor x_desc;
2433   // datatype for x_desc comes from any_param, not input.
2434   // try_get_weight_buf's job is to check "is the weight buffer correctly laid
2435   // out for us to run it with input of the same datatype?"
2436   x_desc.set(datatype, x_geom.sizes(), x_geom.strides(), 5);
2437 
2438 #ifndef USE_CUDNN_RNN_V8_API
2439   auto num_params = get_num_weights(handle, rnn_desc, x_desc, datatype);
2440 #else
2441   auto num_params = get_num_weights(handle, rnn_desc, datatype);
2442 #endif
2443 
2444   // Try to get parameter storage
2445   auto param_storage = any_param.storage();
2446   auto weight_buf = at::empty({0}, any_param.options()).set_(param_storage);
2447   if (weight_buf.size(0) < num_params) {
2448     return {};
2449   } else if (weight_buf.size(0) > num_params) {
2450     weight_buf = weight_buf.narrow(0, 0, num_params);
2451   }
2452 
2453   // Get and check data pointers
2454   auto expected_data_ptrs = get_expected_data_ptrs(
2455       weight_buf, handle, rnn, rnn_desc, x_desc, datatype);
2456 
2457   int64_t num_parameters = parameters.size();
2458   int64_t num_ptrs = expected_data_ptrs.size();
2459   if (proj_size != 0) {
2460     AT_ASSERT(num_parameters % (has_biases ? 5 : 3) == 0);
2461     AT_ASSERT(num_ptrs % 5 == 0);
2462     if (has_biases) {
2463       AT_ASSERT(num_ptrs == num_parameters);
2464       for (const auto i : c10::irange(num_parameters)) {
2465         if (expected_data_ptrs[i] != parameters[i].data_ptr())
2466           return {};
2467       }
2468     } else {
2469       AT_ASSERT(num_parameters % 3 == 0);
2470       AT_ASSERT(num_ptrs == num_parameters * 5 / 3);
2471       for (int64_t param_i = 0, ptr_i = 0; ptr_i < num_ptrs;
2472            ptr_i += 5, param_i += 3) {
2473         if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr())
2474           return {};
2475         if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr())
2476           return {};
2477         if (expected_data_ptrs[ptr_i + 4] != parameters[param_i + 2].data_ptr())
2478           return {};
2479       }
2480     }
2481   } else {
2482     AT_ASSERT(num_ptrs == (num_parameters * (has_biases ? 1 : 2)));
2483     AT_ASSERT(num_parameters % (has_biases ? 4 : 2) == 0);
2484     for (int64_t param_i = 0, ptr_i = 0; ptr_i < num_ptrs;
2485          ptr_i += (has_biases ? 2 : 4), param_i += 2) {
2486       if (expected_data_ptrs[ptr_i] != parameters[param_i].data_ptr())
2487         return {};
2488       if (expected_data_ptrs[ptr_i + 1] != parameters[param_i + 1].data_ptr())
2489         return {};
2490     }
2491   }
2492   if (!parameters[num_parameters - 1].is_contiguous())
2493     return {};
2494   return weight_buf;
2495 }
2496 
2497 template <typename hidden_type>
2498 std::pair<Tensor, hidden_type> _cudnn_impl(
2499     const Tensor& input,
2500     const Tensor& _batch_sizes,
2501     const hidden_type& hidden,
2502     TensorList params,
2503     bool has_biases,
2504     cudnnRNNMode_t mode,
2505     int64_t num_layers,
2506     double dropout_p,
2507     bool train,
2508     bool bidirectional) {
2509   auto [hx, cx] = unpack_hidden(hidden);
2510   auto hidden_size = hx.sym_size(2);
2511   SymInt proj_size = 0;
2512   // For LSTM models with projections hidden size could be different
2513   if (cx.defined() && cx.sym_size(2) != hx.sym_size(2)) {
2514     hidden_size = cx.sym_size(2);
2515     proj_size = hx.sym_size(2);
2516   }
2517 
2518   // TODO:  try_get_weight_buf returns a Tensor, but _cudnn_rnn below takes a
2519   // std::optional<Tensor> in weight_buf's slot.  Do we want try_get_weight_buf
2520   // to return a std::optional<Tensor> instead of a defined or undefined Tensor?
2521   at::cuda::OptionalCUDAGuard guard(input.get_device());
2522   auto weight_buf = try_get_weight_buf(
2523       input,
2524       params,
2525       has_biases,
2526       mode,
2527       hidden_size,
2528       proj_size,
2529       num_layers,
2530       bidirectional);
2531 
2532   TORCH_CHECK(_batch_sizes.dim() == 1, "batch_sizes tensor should be 1D");
2533   IntArrayRef batch_sizes{
2534       _batch_sizes.data_ptr<int64_t>(),
2535       static_cast<size_t>(_batch_sizes.size(0))};
2536 
2537   auto& dropout_state = get_dropout_state(dropout_p, train, input.options());
2538   std::unique_lock<DropoutState> lock{dropout_state};
2539   int64_t num_params = has_biases ? 4 : 2;
2540   if (proj_size != 0) {
2541     ++num_params;
2542   }
2543   auto sym_batch_sizes = c10::SymIntArrayRef(
2544       reinterpret_cast<const c10::SymInt*>(batch_sizes.data()),
2545       batch_sizes.size());
2546   // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
2547   auto cudnn_output = at::_cudnn_rnn_symint(
2548       input,
2549       params,
2550       num_params,
2551       weight_buf,
2552       hx,
2553       cx,
2554       static_cast<int>(mode),
2555       hidden_size,
2556       proj_size,
2557       num_layers,
2558       /*batch_first=*/false,
2559       dropout_p,
2560       train,
2561       bidirectional,
2562       sym_batch_sizes,
2563       dropout_state.buffer);
2564 
2565   return {
2566       std::get<0>(cudnn_output),
2567       pack_hidden<hidden_type>(
2568           std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
2569 }
2570 
2571 template <typename hidden_type>
2572 std::pair<Tensor, hidden_type> _cudnn_impl(
2573     const Tensor& input,
2574     const hidden_type& hidden,
2575     TensorList params,
2576     bool has_biases,
2577     cudnnRNNMode_t mode,
2578     int64_t num_layers,
2579     double dropout_p,
2580     bool train,
2581     bool bidirectional,
2582     bool batch_first) {
2583   auto [hx, cx] = unpack_hidden(hidden);
2584   auto hidden_size = hx.sym_size(2);
2585   c10::SymInt proj_size = 0;
2586   // For LSTM models with projections hidden size could be different
2587   if (cx.defined() && cx.sym_size(2) != hx.sym_size(2)) {
2588     hidden_size = cx.sym_size(2);
2589     proj_size = hx.sym_size(2);
2590   }
2591   at::cuda::OptionalCUDAGuard guard(input.get_device());
2592   auto weight_buf = try_get_weight_buf(
2593       input,
2594       params,
2595       has_biases,
2596       mode,
2597       hidden_size,
2598       proj_size,
2599       num_layers,
2600       bidirectional);
2601   auto& dropout_state = get_dropout_state(dropout_p, train, input.options());
2602   std::unique_lock<DropoutState> lock{dropout_state};
2603   int64_t num_params = has_biases ? 4 : 2;
2604   if (proj_size != 0) {
2605     ++num_params;
2606   }
2607   // cudnn_output = std::tuple<output, hy, cy, reserve, new_weight_buf>
2608   auto cudnn_output = at::_cudnn_rnn_symint(
2609       input,
2610       params,
2611       num_params,
2612       weight_buf,
2613       hx,
2614       cx,
2615       static_cast<int>(mode),
2616       hidden_size,
2617       proj_size,
2618       num_layers,
2619       batch_first,
2620       dropout_p,
2621       train,
2622       bidirectional,
2623       /*batch_sizes=*/{},
2624       dropout_state.buffer);
2625 
2626   return {
2627       std::get<0>(cudnn_output),
2628       pack_hidden<hidden_type>(
2629           std::get<1>(cudnn_output), std::get<2>(cudnn_output))};
2630 }
2631 
2632 #define ONE_HIDDEN_RNN(NAME, MODE)                          \
2633   void NAME##_cudnn(                                        \
2634       Tensor& output,                                       \
2635       Tensor& hy,                                           \
2636       const Tensor& input,                                  \
2637       const Tensor& hx,                                     \
2638       TensorList params,                                    \
2639       bool has_biases,                                      \
2640       int64_t num_layers,                                   \
2641       double dropout_p,                                     \
2642       bool train,                                           \
2643       bool bidirectional,                                   \
2644       bool batch_first) {                                   \
2645     std::tie(output, hy) = _cudnn_impl(                     \
2646         input,                                              \
2647         hx,                                                 \
2648         params,                                             \
2649         has_biases,                                         \
2650         MODE,                                               \
2651         num_layers,                                         \
2652         dropout_p,                                          \
2653         train,                                              \
2654         bidirectional,                                      \
2655         batch_first);                                       \
2656   }                                                         \
2657                                                             \
2658   void NAME##_packed_cudnn(                                 \
2659       Tensor& output,                                       \
2660       Tensor& hy,                                           \
2661       const Tensor& data,                                   \
2662       const Tensor& batch_sizes,                            \
2663       const Tensor& hx,                                     \
2664       TensorList params,                                    \
2665       bool has_biases,                                      \
2666       int64_t num_layers,                                   \
2667       double dropout_p,                                     \
2668       bool train,                                           \
2669       bool bidirectional) {                                 \
2670     std::tie(output, hy) = _cudnn_impl(                     \
2671         data,                                               \
2672         batch_sizes,                                        \
2673         hx,                                                 \
2674         params,                                             \
2675         has_biases,                                         \
2676         MODE,                                               \
2677         num_layers,                                         \
2678         dropout_p,                                          \
2679         train,                                              \
2680         bidirectional);                                     \
2681   }                                                         \
2682                                                             \
2683   REGISTER_CUDA_DISPATCH(NAME##_cudnn_stub, &NAME##_cudnn); \
2684   REGISTER_CUDA_DISPATCH(NAME##_packed_cudnn_stub, &NAME##_packed_cudnn);
2685 
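// Each instantiation below expands to a NAME##_cudnn and a NAME##_packed_cudnn
// function and registers both with the CUDA dispatcher, e.g. gru_cudnn and
// gru_packed_cudnn for the first line.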
2686 ONE_HIDDEN_RNN(gru, CUDNN_GRU)
2687 ONE_HIDDEN_RNN(rnn_tanh, CUDNN_RNN_TANH)
2688 ONE_HIDDEN_RNN(rnn_relu, CUDNN_RNN_RELU)
2689 
2690 void lstm_cudnn(
2691     Tensor& output,
2692     Tensor& hy,
2693     Tensor& cy,
2694     const Tensor& input,
2695     TensorList hx,
2696     TensorList params,
2697     bool has_biases,
2698     int64_t num_layers,
2699     double dropout_p,
2700     bool train,
2701     bool bidirectional,
2702     bool batch_first) {
2703   auto result = _cudnn_impl(
2704       input,
2705       std::make_tuple(hx[0], hx[1]),
2706       params,
2707       has_biases,
2708       CUDNN_LSTM,
2709       num_layers,
2710       dropout_p,
2711       train,
2712       bidirectional,
2713       batch_first);
2714   output = result.first;
2715   hy = std::get<0>(result.second);
2716   cy = std::get<1>(result.second);
2717 }
2718 
2719 void lstm_packed_cudnn(
2720     Tensor& output,
2721     Tensor& hy,
2722     Tensor& cy,
2723     const Tensor& data,
2724     const Tensor& batch_sizes,
2725     TensorList hx,
2726     TensorList params,
2727     bool has_biases,
2728     int64_t num_layers,
2729     double dropout_p,
2730     bool train,
2731     bool bidirectional) {
2732   auto result = _cudnn_impl(
2733       data,
2734       batch_sizes,
2735       std::make_tuple(hx[0], hx[1]),
2736       params,
2737       has_biases,
2738       CUDNN_LSTM,
2739       num_layers,
2740       dropout_p,
2741       train,
2742       bidirectional);
2743   output = result.first;
2744   hy = std::get<0>(result.second);
2745   cy = std::get<1>(result.second);
2746 }
2747 
2748 REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn);
2749 REGISTER_CUDA_DISPATCH(lstm_packed_cudnn_stub, &lstm_packed_cudnn);
2750 
2751 } // namespace
2752 
2753 } // namespace native
2754 } // namespace at
2755 
2756 #endif // AT_CUDNN_ENABLED()
2757