#include <ATen/ATen.h>
#include <ATen/autocast_mode.h>
#include <torch/library.h>

// pulls in AT_CUDNN_ENABLED() as defined by cmake
#include <ATen/cuda/CUDAConfig.h>

#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/RNNUtils.h>
#endif

namespace at::autocast {

/********************************************************************************
Autocast wrapper for CuDNN RNNs (the weight reflattening needs special attention)
********************************************************************************/

// To be registered for the "_cudnn_rnn(...)" schema.
// _cudnn_rnn is autograd-exposed (test_autocast_cudnn_rnn in test_cuda.py includes a test to confirm)
std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>
_cudnn_rnn_cast_reflatten(const Tensor& input,
                          TensorList weight,
                          int64_t weight_stride0,
                          const std::optional<Tensor>& weight_buf_opt,
                          const Tensor& hx,
                          const std::optional<Tensor>& cx,
                          int64_t mode,
                          int64_t hidden_size,
                          int64_t proj_size,
                          int64_t num_layers,
                          bool batch_first,
                          double dropout,
                          bool train,
                          bool bidirectional,
                          IntArrayRef batch_sizes,
                          const std::optional<Tensor>& dropout_state) {
#if AT_CUDNN_ENABLED()
  c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
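  // Excluding the Autocast key here means the at::_cudnn_rnn redispatch at the
  // bottom of this function will not re-enter this wrapper.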

  for (const auto& t : weight) {
    TORCH_CHECK(weight[0].scalar_type() == t.scalar_type(), "Weight scalar types do not match.");
  }
  // weight_stride0 is the number of weight tensors per layer and direction, as seen by model.parameters().
  // If bias is enabled, there are 4 such tensors (ih and hh weights, ih and hh biases).
  // If bias is not enabled, there are 2 (ih and hh weights).
  // This organization holds for all rnn types (RNN, GRU, and LSTM). If an LSTM with projections is
  // used, an additional hr weight is added per layer and direction.
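  // For illustration (one plausible case, not exhaustive): a single-layer
  // unidirectional LSTM with bias exposes its parameters per layer/direction as
  // [w_ih, w_hh, b_ih, b_hh], so weight_stride0 == 4; with projections enabled,
  // w_hr is appended, giving weight_stride0 == 5.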
  if (proj_size > 0) {
    TORCH_INTERNAL_ASSERT((weight_stride0 == 3) || (weight_stride0 == 5),
                          "weight_stride0 must be 3 (if no bias) or 5 (if bias) for LSTM with projections. Received ",
                          weight_stride0);
  } else {
    TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4),
                          "weight_stride0 must be 2 (if no bias) or 4 (if bias). Received ",
                          weight_stride0);
  }

  Tensor weight_buf, redispatch_weight_buf;
  std::vector<Tensor> redispatch_weight;
  // There's an implicit contract here with native/cudnn/RNN.cpp:_cudnn_impl, which calls at::_cudnn_rnn.
  // Code here assumes that if _cudnn_impl passes a weight_buf_opt containing a defined tensor, that tensor
  // is valid flat storage of the weights in their incoming dtype.
  if (weight_buf_opt.has_value()) {
    weight_buf = *weight_buf_opt;
  }
  bool needs_cast_and_flatten = (weight_buf.defined() ?
                                 // weight_buf is valid. Only change it if it's eligible and not already FP16.
                                 is_eligible(weight_buf) && (weight_buf.scalar_type() != at::kHalf) :
                                 // weight_buf is not valid. Only create it if other weights are eligible and not already FP16.
                                 is_eligible(weight[0]) && (weight[0].scalar_type() != at::kHalf));
  if (needs_cast_and_flatten) {
    // Casts weight tensors to FP16 and ensures all weights for all layers are views into a large flat buffer,
    // with the right locations and layouts expected by cudnn.
    // This is (and should be) autograd-exposed.
    bool include_bias = true;
    if (weight_stride0 == 2 || (weight_stride0 == 3 && proj_size > 0)) {
      include_bias = false;
    }
    std::tie(redispatch_weight_buf, redispatch_weight) =
        at::native::cudnn_rnn::copy_weights_to_flat_buf_views(
            weight,
            weight_stride0,
            input.size(-1),
            mode,
            hidden_size,
            proj_size,
            num_layers,
            batch_first,
            bidirectional,
            /*flat_buf_datatype=*/at::native::getCudnnDataTypeFromScalarType(at::kHalf), // could just hardcode CUDNN_DATA_HALF
            /*flat_buf_options=*/weight[0].options().dtype(at::kHalf),
            /*set_orig_weights_to_flat_buf=*/false,
            /*allow_type_change=*/true,
            /*include_bias=*/include_bias);
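    // Reader's note (interpretation of the flag above): set_orig_weights_to_flat_buf=false
    // leaves the caller's original (typically FP32) parameters untouched; the FP16
    // flat buffer and its views exist only for the redispatch below.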
  }
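  // Redispatch below autocast: pass the freshly casted views/buffer if we made
  // them, otherwise forward the caller's tensors unchanged.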
  return at::_cudnn_rnn(
      cached_cast(at::kHalf, input),
      needs_cast_and_flatten ? TensorList(redispatch_weight) : weight,
      weight_stride0,
      needs_cast_and_flatten ? redispatch_weight_buf : weight_buf,
      cached_cast(at::kHalf, hx),
      cached_cast(at::kHalf, cx),
      mode,
      hidden_size,
      proj_size,
      num_layers,
      batch_first,
      dropout,
      train,
      bidirectional,
      batch_sizes,
      dropout_state);
#else // AT_CUDNN_ENABLED()
  AT_ERROR("autocast::_cudnn_rnn_cast_reflatten: ATen not compiled with cuDNN support");
  return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // never reached, placates the compiler
#endif // AT_CUDNN_ENABLED()
}

namespace {
TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  m.impl("_cudnn_rnn",
         TORCH_FN((&at::autocast::_cudnn_rnn_cast_reflatten)));
}
} // anonymous namespace
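
// Illustrative usage (a sketch of typical caller code, not part of this file):
// with the registration above, a cuDNN LSTM run under autocast routes its
// forward pass through _cudnn_rnn_cast_reflatten. From Python:
//
//   rnn = torch.nn.LSTM(10, 20).cuda()
//   x = torch.randn(5, 3, 10, device="cuda")
//   with torch.autocast("cuda"):
//       out, _ = rnn(x)  # out.dtype == torch.float16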

} // namespace at::autocast