xref: /aosp_15_r20/external/pytorch/aten/src/ATen/miopen/AutocastRNN.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/autocast_mode.h>
3 #include <ATen/cuda/CUDAConfig.h>
4 #include <torch/library.h>
5 
6 namespace at {
7 namespace autocast {
8 
9 /**********************************************************************
10 Autocast wrapper for MIOpen RNNs
11 **********************************************************************/
12 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
miopen_rnn(const Tensor & input_r,TensorList weight,int64_t weight_stride0,const Tensor & hx,const std::optional<Tensor> & cx_opt,int64_t fn_mode,int64_t fn_hidden_size,int64_t fn_num_layers,bool batch_first,double fn_dropout,bool fn_train,bool fn_bidirectional,IntArrayRef fn_batch_sizes,const std::optional<Tensor> & fn_dropout_state_opt)13 miopen_rnn(const Tensor & input_r,
14            TensorList weight,
15            int64_t weight_stride0,
16            const Tensor & hx,
17            const std::optional<Tensor>& cx_opt,
18            int64_t fn_mode,
19            int64_t fn_hidden_size,
20            int64_t fn_num_layers,
21            bool batch_first,
22            double fn_dropout,
23            bool fn_train,
24            bool fn_bidirectional,
25            IntArrayRef fn_batch_sizes,
26            const std::optional<Tensor>& fn_dropout_state_opt) {
27 
28 #if AT_ROCM_ENABLED()
29 
30     c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
31 
32     return at::miopen_rnn(
33                 cached_cast(at::kHalf, input_r),
34                 cached_cast(at::kHalf, weight),
35                 weight_stride0,
36                 cached_cast(at::kHalf, hx),
37                 cached_cast(at::kHalf, cx_opt),
38                 fn_mode,
39                 fn_hidden_size,
40                 fn_num_layers,
41                 batch_first,
42                 fn_dropout,
43                 fn_train,
44                 fn_bidirectional,
45                 fn_batch_sizes,
46                 fn_dropout_state_opt);
47 
48 #else
49     AT_ERROR("autocast::miopen_rnn: ATen not compiled with ROCm enabled");
50     return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // placate the compiler
51 #endif
52 
53 }
54 
55 // Register Autocast dispatch
56 namespace {
TORCH_LIBRARY_IMPL(aten,Autocast,m)57 TORCH_LIBRARY_IMPL(aten, Autocast, m) {
58   m.impl("miopen_rnn",
59          TORCH_FN((&at::autocast::miopen_rnn)));
60 }
61 } // anonymous namespace
62 
63 } // namespace autocast
64 } // namespace at
65