#include <ATen/ATen.h>
#include <ATen/autocast_mode.h>
#include <c10/util/Exception.h>
#include <torch/library.h>

// Build-configuration macros; AT_ROCM_ENABLED() is used below.
#include <ATen/cuda/CUDAConfig.h>

6 namespace at {
7 namespace autocast {
8
9 /**********************************************************************
10 Autocast wrapper for MIOpen RNNs
11 **********************************************************************/
12 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
miopen_rnn(const Tensor & input_r,TensorList weight,int64_t weight_stride0,const Tensor & hx,const std::optional<Tensor> & cx_opt,int64_t fn_mode,int64_t fn_hidden_size,int64_t fn_num_layers,bool batch_first,double fn_dropout,bool fn_train,bool fn_bidirectional,IntArrayRef fn_batch_sizes,const std::optional<Tensor> & fn_dropout_state_opt)13 miopen_rnn(const Tensor & input_r,
14 TensorList weight,
15 int64_t weight_stride0,
16 const Tensor & hx,
17 const std::optional<Tensor>& cx_opt,
18 int64_t fn_mode,
19 int64_t fn_hidden_size,
20 int64_t fn_num_layers,
21 bool batch_first,
22 double fn_dropout,
23 bool fn_train,
24 bool fn_bidirectional,
25 IntArrayRef fn_batch_sizes,
26 const std::optional<Tensor>& fn_dropout_state_opt) {
27
28 #if AT_ROCM_ENABLED()
29
30 c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
31
32 return at::miopen_rnn(
33 cached_cast(at::kHalf, input_r),
34 cached_cast(at::kHalf, weight),
35 weight_stride0,
36 cached_cast(at::kHalf, hx),
37 cached_cast(at::kHalf, cx_opt),
38 fn_mode,
39 fn_hidden_size,
40 fn_num_layers,
41 batch_first,
42 fn_dropout,
43 fn_train,
44 fn_bidirectional,
45 fn_batch_sizes,
46 fn_dropout_state_opt);
47
48 #else
49 AT_ERROR("autocast::miopen_rnn: ATen not compiled with ROCm enabled");
50 return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // placate the compiler
51 #endif
52
53 }
54
55 // Register Autocast dispatch
56 namespace {
TORCH_LIBRARY_IMPL(aten,Autocast,m)57 TORCH_LIBRARY_IMPL(aten, Autocast, m) {
58 m.impl("miopen_rnn",
59 TORCH_FN((&at::autocast::miopen_rnn)));
60 }
61 } // anonymous namespace
62
63 } // namespace autocast
64 } // namespace at
65