#include <ATen/ATen.h>
#include <ATen/autocast_mode.h>
#include <torch/library.h>

// pulls in AT_CUDNN_ENABLED() as defined by cmake
#include <ATen/cuda/CUDAConfig.h>

#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/RNNUtils.h>
#endif


namespace at::autocast {

/********************************************************************************
Autocast wrapper for CuDNN RNNs (the weight reflattening needs special attention)
********************************************************************************/

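// Rough usage sketch (illustrative only, not part of this file's behavior): on a
// cuDNN-enabled CUDA build, a Python-level RNN forward pass run under autocast is
// what lands in the wrapper below, e.g.
//
//   with torch.autocast(device_type="cuda", dtype=torch.float16):
//       out, _ = torch.nn.LSTM(10, 20).cuda()(torch.randn(5, 3, 10, device="cuda"))
//
// The layer sizes and tensor shapes above are arbitrary examples.
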
// To be registered for the "_cudnn_rnn(...)" schema.
// _cudnn_rnn is autograd-exposed (test_autocast_cudnn_rnn in test_cuda.py includes a test to confirm)
std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>
_cudnn_rnn_cast_reflatten(const Tensor & input,
                          TensorList weight,
                          int64_t weight_stride0,
                          const std::optional<Tensor>& weight_buf_opt,
                          const Tensor& hx,
                          const std::optional<Tensor>& cx,
                          int64_t mode,
                          int64_t hidden_size,
                          int64_t proj_size,
                          int64_t num_layers,
                          bool batch_first,
                          double dropout,
                          bool train,
                          bool bidirectional,
                          IntArrayRef batch_sizes,
                          const std::optional<Tensor>& dropout_state) {
#if AT_CUDNN_ENABLED()
  c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);

  for (const auto& t : weight) {
    TORCH_CHECK(weight[0].scalar_type() == t.scalar_type(), "Weight scalar types do not match.");
  }
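  // All weights are about to be flattened into (or matched against) a single typed buffer,
  // so mixed weight dtypes cannot be handled here; the check above surfaces that early.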
  // weight_stride0 is the number of weight tensors per layer and direction, as seen by model.parameters().
  // If bias is enabled, there are 4 such tensors (ih and hh weights, ih and hh biases).
  // If bias is not enabled, there are 2 (ih and hh weights).
  // This organization holds for all RNN types (RNN, GRU, and LSTM).  If an LSTM with projections is
  // used, an additional hr weight is added.
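  // For example, a single-direction LSTM layer with bias exposes
  //   {weight_ih, weight_hh, bias_ih, bias_hh}            -> weight_stride0 == 4
  // and, with projections (proj_size > 0), additionally weight_hr:
  //   {weight_ih, weight_hh, bias_ih, bias_hh, weight_hr} -> weight_stride0 == 5
  // Dropping the biases gives 2 and 3 respectively, which is what the asserts below verify.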
  if (proj_size > 0) {
    TORCH_INTERNAL_ASSERT((weight_stride0 == 3) || (weight_stride0 == 5),
                          "weight_stride0 must be 3 (if no bias) or 5 (if bias) for LSTM with projections.  Received ",
                          weight_stride0);
  } else {
    TORCH_INTERNAL_ASSERT((weight_stride0 == 2) || (weight_stride0 == 4),
                          "weight_stride0 must be 2 (if no bias) or 4 (if bias).  Received ",
                          weight_stride0);
  }


  Tensor weight_buf, redispatch_weight_buf;
  std::vector<Tensor> redispatch_weight;
  // There's an implicit contract here with native/cudnn/RNN.cpp:_cudnn_impl, which calls at::_cudnn_rnn.
  // The code here assumes that if _cudnn_impl passes a weight_buf_opt containing a defined tensor,
  // that tensor is valid flat storage of the weights in their incoming dtype.
  if (weight_buf_opt.has_value()) {
    weight_buf = *weight_buf_opt;
  }
  bool needs_cast_and_flatten = (weight_buf.defined() ?
                                 // weight_buf is valid.  Only change it if it's eligible and not already FP16.
                                 is_eligible(weight_buf) && (weight_buf.scalar_type() != at::kHalf) :
                                 // weight_buf is not valid.  Only create it if other weights are eligible and not already FP16.
                                 is_eligible(weight[0]) && (weight[0].scalar_type() != at::kHalf));
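  // Roughly: cast-and-reflatten happens only when the governing tensor (the existing flat
  // buffer if there is one, otherwise the first weight) is autocast-eligible -- a
  // floating-point CUDA tensor that is not FP64 -- and is not already FP16.  Anything else
  // is passed through unchanged.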
  if (needs_cast_and_flatten) {
    // Casts weight tensors to FP16 and ensures all weights for all layers are views into a large flat buffer,
    // with the right locations and layouts expected by cudnn.
    // This is (and should be) autograd-exposed.
    bool include_bias = true;
    if (weight_stride0 == 2 || (weight_stride0 == 3 && proj_size > 0)) {
      include_bias = false;
    }
    std::tie(redispatch_weight_buf, redispatch_weight) =
        at::native::cudnn_rnn::copy_weights_to_flat_buf_views(
            weight,
            weight_stride0,
            input.size(-1),
            mode,
            hidden_size,
            proj_size,
            num_layers,
            batch_first,
            bidirectional,
            /*flat_buf_datatype=*/at::native::getCudnnDataTypeFromScalarType(at::kHalf), // could just hardcode CUDNN_DATA_HALF
            /*flat_buf_options=*/weight[0].options().dtype(at::kHalf),
            /*set_orig_weights_to_flat_buf=*/false,
            /*allow_type_change=*/true,
            /*include_bias=*/include_bias);
  }
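  // copy_weights_to_flat_buf_views returns the new FP16 flat buffer plus per-parameter views
  // into it; those views (not the original weights) are what get redispatched below, so the
  // cuDNN kernel sees a contiguous half-precision weight layout.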
  return at::_cudnn_rnn(
      cached_cast(at::kHalf, input),
      needs_cast_and_flatten ? TensorList(redispatch_weight) : weight,
      weight_stride0,
      needs_cast_and_flatten ? redispatch_weight_buf : weight_buf,
      cached_cast(at::kHalf, hx),
      cached_cast(at::kHalf, cx),
      mode,
      hidden_size,
      proj_size,
      num_layers,
      batch_first,
      dropout,
      train,
      bidirectional,
      batch_sizes,
      dropout_state);
#else // AT_CUDNN_ENABLED()
  AT_ERROR("autocast::_cudnn_rnn_cast_reflatten: ATen not compiled with cuDNN support");
  return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // never reached, placates the compiler
#endif // AT_CUDNN_ENABLED()
}

namespace {
TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  m.impl("_cudnn_rnn",
         TORCH_FN((&at::autocast::_cudnn_rnn_cast_reflatten)));
}
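// With this registration, any call to aten::_cudnn_rnn that still carries the Autocast
// dispatch key (e.g. an RNN forward pass run under autocast on CUDA) is intercepted by
// _cudnn_rnn_cast_reflatten before reaching the regular CUDA kernel.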
} // anonymous namespace

} // namespace at::autocast