1 #pragma once 2 3 #include <ATen/cudnn/Descriptors.h> 4 #include <ATen/cudnn/Types.h> 5 #include <ATen/cudnn/Utils.h> 6 #include <ATen/cudnn/cudnn-wrapper.h> 7 8 // Declares utilities used by RNN.cpp and also needed by external consumers 9 namespace at { 10 namespace native { 11 namespace cudnn_rnn { 12 13 TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>> 14 copy_weights_to_flat_buf_views( 15 TensorList weight_arr, 16 int64_t weight_stride0, 17 int64_t input_size, 18 int64_t mode, 19 int64_t hidden_size, 20 int64_t proj_size, 21 int64_t num_layers, 22 bool batch_first, 23 bool bidirectional, 24 const cudnnDataType_t flat_buf_datatype, 25 const TensorOptions& flat_buf_options, 26 bool set_orig_weights_to_flat_buf, 27 bool allow_type_change = false, 28 bool include_bias = true); 29 30 } // namespace cudnn_rnn 31 } // namespace native 32 } // namespace at 33