xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/RNNUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/cudnn/cudnn-wrapper.h>

#include <tuple>
#include <vector>
7 
8 // Declares utilities used by RNN.cpp and also needed by external consumers
9 namespace at {
10 namespace native {
11 namespace cudnn_rnn {
12 
13 TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>>
14 copy_weights_to_flat_buf_views(
15     TensorList weight_arr,
16     int64_t weight_stride0,
17     int64_t input_size,
18     int64_t mode,
19     int64_t hidden_size,
20     int64_t proj_size,
21     int64_t num_layers,
22     bool batch_first,
23     bool bidirectional,
24     const cudnnDataType_t flat_buf_datatype,
25     const TensorOptions& flat_buf_options,
26     bool set_orig_weights_to_flat_buf,
27     bool allow_type_change = false,
28     bool include_bias = true);
29 
30 } // namespace cudnn_rnn
31 } // namespace native
32 } // namespace at
33