#include <ATen/native/RNN.h>
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>

#include <ATen/TensorUtils.h>
#include <ATen/Dispatch.h>
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/mkldnn_convolution_native.h>
#include <ATen/ops/mkldnn_rnn_layer_backward_native.h>
#include <ATen/ops/mkldnn_rnn_layer_native.h>
#endif

#if !AT_MKLDNN_ENABLED()

namespace at::native {


std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(
    const Tensor& input,
    const Tensor& w0,
    const Tensor& w1,
    const Tensor& w2,
    const Tensor& w3,
    const Tensor& hx_,
    const Tensor& cx_,
    bool reverse,
    IntArrayRef batch_sizes,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool bidirectional,
    bool batch_first,
    bool train) {
  AT_ERROR("mkldnn_rnn_layer: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer_backward(
    const Tensor& input,
    const Tensor& weight0,
    const Tensor& weight1,
    const Tensor& weight2,
    const Tensor& weight3,
    const Tensor& hx_,
    const Tensor& cx_tmp,
    const Tensor& output,
    const Tensor& hy_,
    const Tensor& cy_,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    bool reverse,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool train,
    bool bidirectional,
    at::IntArrayRef batch_sizes,
    bool batch_first,
    const at::Tensor& workspace) {
  AT_ERROR("mkldnn_rnn_layer_backward: ATen not compiled with MKLDNN support");
}

REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub);

} // namespace at::native

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at::native {

struct RNNParams {
  ideep::rnn_kind mode;
  int64_t seq_length;
  int64_t mini_batch;
  int64_t input_size;
  int64_t hidden_size;
  int64_t num_directions;
  int64_t num_layers;
  bool batch_first;
  bool train;
  at::IntArrayRef batch_sizes;
  int64_t num_gates;
  int64_t num_bias_gates;

  RNNParams(
      const at::Tensor& input,
      at::IntArrayRef batch_sizes_,
      int64_t mode_,
      int64_t hidden_size_,
      int64_t num_layers_,
      bool bidirectional,
      bool batch_first_,
      bool train_) {
    mode = static_cast<ideep::rnn_kind>(mode_);
    batch_first = batch_first_;
    seq_length = input.size(0);
    mini_batch = input.size(1);
    input_size = input.size(2);
    hidden_size = hidden_size_;
    num_directions = bidirectional ? 2 : 1;
    num_layers = num_layers_;
    train = train_;
    batch_sizes = batch_sizes_;
    if (mode == ideep::rnn_kind::LSTM) {
      num_gates = 4;
      num_bias_gates = 4;
    } else if (mode == ideep::rnn_kind::GRU) {
      num_gates = 3;
      num_bias_gates = 4;
    } else {
      // RNN_RELU; RNN_TANH
      num_gates = 1;
      num_bias_gates = 1;
    }
  }

  // mkldnn memory descriptors
  using format = ideep::format_tag;
  using desc = ideep::tensor::desc;
  using dtype = ideep::tensor::data_type;
  desc src_layer_desc(int64_t _input_size, dtype dtype) const {
    return {{seq_length, mini_batch, _input_size}, dtype, format::tnc};
  }
  desc src_iter_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  desc src_iter_c_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  // logical size described as ldigo
  desc weights_layer_desc(int64_t _input_size, dtype dtype) const {
    return {{1, 1, _input_size, num_gates, hidden_size}, dtype, format::ldgoi};
  }
  desc weights_layer_ldigo_desc(int64_t _input_size, dtype dtype) const {
    return {{1, 1, _input_size, num_gates, hidden_size}, dtype, format::ldigo};
  }
  desc weights_iter_desc(dtype dtype) const {
    return {{1, 1, hidden_size, num_gates, hidden_size}, dtype, format::ldgoi};
  }
  desc weights_iter_ldigo_desc(dtype dtype) const {
    return {{1, 1, hidden_size, num_gates, hidden_size}, dtype, format::ldigo};
  }
  desc bias_desc(dtype dtype) const {
    return {{1, 1, num_bias_gates, hidden_size}, dtype, format::ldgo};
  }
  desc dst_layer_desc(dtype dtype) const {
    return {{seq_length, mini_batch, hidden_size}, dtype, format::tnc};
  }
  desc dst_iter_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  desc dst_iter_c_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
};
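
// Worked example (illustrative values, not part of the library): for a
// single-direction fp32 LSTM layer with seq_length = 5, mini_batch = 2,
// input_size = 10 and hidden_size = 20 (so num_gates = num_bias_gates = 4),
// the descriptors above resolve to these logical shapes:
//
//   src_layer_desc(10, f32)      -> {5, 2, 10}         tnc
//   src_iter_desc(f32)           -> {1, 1, 2, 20}      ldnc
//   weights_layer_desc(10, f32)  -> {1, 1, 10, 4, 20}  ldgoi
//   weights_iter_desc(f32)       -> {1, 1, 20, 4, 20}  ldgoi
//   bias_desc(f32)               -> {1, 1, 4, 20}      ldgo
//   dst_layer_desc(f32)          -> {5, 2, 20}         tnc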

template<bool is_single_direction>
std::vector<int64_t> _output_size(const RNNParams& rnn) {
  auto output_channels = is_single_direction ? rnn.hidden_size
                                             : rnn.hidden_size * rnn.num_directions;
  return {rnn.seq_length, rnn.mini_batch, output_channels};
}

// MKLDNN GRU gate order differs from PyTorch's, which requires a gate shuffle
// (let rt, zt, nt be the reset, update and new gates respectively):
//
//   MKLDNN GRU weight_ih/weight_hh gate order:  (zt, rt, nt)
//   PyTorch GRU weight_ih/weight_hh gate order: (rt, zt, nt)
//
// MKLDNN GRU bias has 4 gates instead of 3
//  (PyTorch GRU bias)     (MKLDNN GRU bias)
//
//  bias_ih    bias_hh          bias
//  +-----+    +-----+       +---------+
//  | rt1 |    | rt2 |       | zt1+zt2 |
//  |-----|    |-----|       |---------|
//  | zt1 |    | zt2 |       | rt1+rt2 |
//  |-----|    |-----|       |---------|
//  | nt1 |    | nt2 |       |   nt1   |
//  +-----+    +-----+       |---------|
//                           |   nt2   |
//                           +---------+
//
static Tensor _shuffle_weight(const Tensor& weight, int64_t fn_mode) {
  auto weight_t = weight.contiguous();
  if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
    std::vector<Tensor> gates = weight_t.chunk(3, /*gates*/0);
    return at::cat({gates[1], gates[0], gates[2]}, /*gates*/0);
  }
  return weight_t;
}

static Tensor _shuffle_bias(const Tensor& bias_ih, const Tensor& bias_hh, int64_t fn_mode) {
  if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
    std::vector<Tensor> b1 = bias_ih.chunk(3, /*output_channels*/0);
    std::vector<Tensor> b2 = bias_hh.chunk(3, /*output_channels*/0);
    return at::cat({b1[1] + b2[1], b1[0] + b2[0], b1[2], b2[2]}, /*output_channels*/0);
  }
  return bias_ih + bias_hh;
}
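
// Worked example (assumed shapes, not part of the library): for a GRU with
// hidden_size H and input_size I, PyTorch's weight_ih has shape {3*H, I} with
// row blocks [rt | zt | nt]; _shuffle_weight reorders the blocks to MKLDNN's
// [zt | rt | nt], and _shuffle_bias folds the two {3*H} bias vectors into one
// {4*H} MKLDNN bias:
//
//   auto gru = static_cast<int64_t>(ideep::rnn_kind::GRU);
//   auto w_dnn = _shuffle_weight(w_ih, gru);
//   // w_dnn.chunk(3, 0) yields the {zt, rt, nt} blocks of w_ih
//   auto b_dnn = _shuffle_bias(b_ih, b_hh, gru);
//   // b_dnn == cat({zt1+zt2, rt1+rt2, nt1, nt2}), shape {4*H}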

std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(const Tensor& input,
    const Tensor& w0,
    const Tensor& w1,
    const Tensor& w2,
    const Tensor& w3,
    const Tensor& hx_,
    const Tensor& cx_,
    bool reverse,
    IntArrayRef batch_sizes,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool bidirectional,
    bool batch_first,
    bool train) {
  RNNParams rnn(
      input,
      batch_sizes,
      mode,
      hidden_size,
      num_layers,
      bidirectional,
      batch_first,
      train);

  auto output_size = _output_size</*is_single_direction*/ true>(rnn);
  auto output = at::empty(output_size, input.options());

  auto hy_ = at::empty(hx_.sizes(), hx_.options());
  auto cy_ = at::empty(cx_.sizes(), cx_.options());

  auto weight_ih = _shuffle_weight(w0, rnn.mode);
  auto weight_hh = _shuffle_weight(w1, rnn.mode);

  // Packed weight will be in mkldnn layout while bias won't be packed
  auto bias = has_biases
      ? _shuffle_bias(w2, w3, rnn.mode)
      : at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options().layout(at::Layout::Strided));

  // per layer input size
  int64_t input_size = input.size(2);
  ideep::tensor w1_, w2_;
  auto x = itensor_view_from_dense(
      input,
      rnn.src_layer_desc(input_size, get_mkldnn_dtype(input)));
  auto hx = itensor_view_from_dense(
      hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_)));
  auto cx = itensor_view_from_dense(
      cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_)));
  auto b = itensor_view_from_dense(
      bias, rnn.bias_desc(get_mkldnn_dtype(bias)));
  auto y = itensor_view_from_dense(
      output, rnn.dst_layer_desc(get_mkldnn_dtype(output)));
  auto hy = itensor_view_from_dense(
      hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_)));
  auto cy = itensor_view_from_dense(
      cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_)));
  w1_ = weight_ih.is_mkldnn()
      ? itensor_from_tensor(weight_ih)
      : itensor_view_from_dense(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
  w2_ = weight_hh.is_mkldnn()
      ? itensor_from_tensor(weight_hh)
      : itensor_view_from_dense(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));
  if (at::GradMode::is_enabled()) {
    Tensor workspace = Tensor();
    auto pd = ideep::lstm_forward_training::prepare(
        x, hx, cx, w1_, w2_, b, y, hy, cy, reverse);
    workspace = at::empty(pd.workspace_desc().get_size() / sizeof(uint8_t), input.options().dtype(at::kByte));
    ideep::tensor mkldnn_workspace;
    mkldnn_workspace.init(
        pd.workspace_desc(), workspace.template data_ptr<uint8_t>());
    ideep::lstm_forward_training::compute(
        pd, x, hx, cx, w1_, w2_, b, mkldnn_workspace, y, hy, cy, reverse, ideep::prop_kind::forward_training);
    return std::make_tuple(output, hy_, cy_, workspace);
  } else {
    ideep::lstm_forward_inference::compute(
        x, hx, cx, w1_, w2_, b, y, hy, cy, reverse, ideep::prop_kind::forward_inference);
    return std::make_tuple(output, hy_, cy_, Tensor());
  }
}
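
// Call sketch (hypothetical single-direction LSTM; the real call site is
// mkldnn_rnn below): the op returns (output, hy, cy, workspace), where the
// workspace is only materialized under grad mode and is what
// mkldnn_rnn_layer_backward consumes; in inference it is an undefined Tensor().
//
//   auto outs = at::mkldnn_rnn_layer(
//       x /*{T, N, I}*/, w_ih, w_hh, b_ih, b_hh, h0, c0,
//       /*reverse=*/false, /*batch_sizes=*/{},
//       /*mode=*/static_cast<int64_t>(ideep::rnn_kind::LSTM),
//       /*hidden_size=*/H, /*num_layers=*/1, /*has_biases=*/true,
//       /*bidirectional=*/false, /*batch_first=*/false, /*train=*/false);
//   auto y = std::get<0>(outs);  // {T, N, H}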

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer_backward(
    const Tensor& input,
    const Tensor& weight0,
    const Tensor& weight1,
    const Tensor& weight2,
    const Tensor& weight3,
    const Tensor& hx_,
    const Tensor& cx_tmp,
    const Tensor& output,
    const Tensor& hy_,
    const Tensor& cy_,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    bool reverse,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool train,
    bool bidirectional,
    at::IntArrayRef batch_sizes,
    bool batch_first,
    const at::Tensor& workspace) {
  const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] {return Tensor();});
  const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] {return Tensor();});
  const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] {return Tensor();});
  if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) {
    return std::make_tuple(Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor());
  }
  auto grad_output = grad_output_r.defined() ? grad_output_r.contiguous() : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_hy = grad_hy_r.defined() ? grad_hy_r.contiguous() : at::zeros_like(hx_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cy = cx_tmp.defined() ? (grad_cy_r.defined() ? grad_cy_r.contiguous() : at::zeros_like(cx_tmp, LEGACY_CONTIGUOUS_MEMORY_FORMAT)) : grad_cy_r.contiguous();
  RNNParams rnn(
      input,
      batch_sizes,
      mode,
      hidden_size,
      num_layers,
      bidirectional,
      batch_first,
      train);
  auto output_size = _output_size</*is_single_direction*/ true>(rnn);

  auto weight_ih = _shuffle_weight(weight0, rnn.mode);
  auto weight_hh = _shuffle_weight(weight1, rnn.mode);
  auto bias = has_biases
      ? _shuffle_bias(weight2, weight3, rnn.mode)
      : at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options());

  auto cx_ = hx_.storage().unsafeGetStorageImpl() == cx_tmp.storage().unsafeGetStorageImpl() ? at::clone(cx_tmp) : cx_tmp;

  // per layer input size
  int64_t input_size = input.size(2);
  auto x = itensor_view_from_dense(
      input,
      rnn.src_layer_desc(input_size, get_mkldnn_dtype(input.scalar_type())));
  auto hx = itensor_view_from_dense(
      hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_.scalar_type())));
  auto cx = itensor_view_from_dense(
      cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_.scalar_type())));
  auto w1 = itensor_view_from_dense(
      weight_ih,
      rnn.weights_layer_desc(
          input_size, get_mkldnn_dtype(weight_ih.scalar_type())));
  auto w2 = itensor_view_from_dense(
      weight_hh,
      rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh.scalar_type())));
  auto b = itensor_view_from_dense(
      bias, rnn.bias_desc(get_mkldnn_dtype(bias.scalar_type())));
  auto y = itensor_view_from_dense(
      output, rnn.dst_layer_desc(get_mkldnn_dtype(output.scalar_type())));
  auto hy = itensor_view_from_dense(
      hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_.scalar_type())));
  auto cy = itensor_view_from_dense(
      cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_.scalar_type())));

  // Create diff_* ATen tensors and corresponding ideep tensors as fp32
  auto diff_x_ =
      at::empty(input.sizes(), input.options().dtype(at::ScalarType::Float));
  auto diff_hx_ =
      at::empty(hx_.sizes(), hx_.options().dtype(at::ScalarType::Float));
  auto diff_cx_ =
      at::empty(cx_.sizes(), cx_.options().dtype(at::ScalarType::Float));
  auto diff_w1_ = at::empty(
      weight_ih.sizes(), weight_ih.options().dtype(at::ScalarType::Float));
  auto diff_w2_ = at::empty(
      weight_hh.sizes(), weight_hh.options().dtype(at::ScalarType::Float));
  auto diff_b_ =
      at::empty(bias.sizes(), bias.options().dtype(at::ScalarType::Float));

  auto diff_x = itensor_view_from_dense(
      diff_x_, rnn.src_layer_desc(input_size, ideep::tensor::data_type::f32));
  auto diff_hx = itensor_view_from_dense(
      diff_hx_, rnn.src_iter_desc(ideep::tensor::data_type::f32));
  auto diff_cx = itensor_view_from_dense(
      diff_cx_, rnn.src_iter_c_desc(ideep::tensor::data_type::f32));
  auto diff_w1 = itensor_view_from_dense(
      diff_w1_,
      rnn.weights_layer_desc(input_size, ideep::tensor::data_type::f32));
  auto diff_w2 = itensor_view_from_dense(
      diff_w2_, rnn.weights_iter_desc(ideep::tensor::data_type::f32));
  auto diff_b = itensor_view_from_dense(
      diff_b_, rnn.bias_desc(ideep::tensor::data_type::f32));

  // Convert grad_y, grad_hy, grad_cy to fp32 in non-fp32 backward
  ideep::tensor diff_y, diff_hy, diff_cy;
  at::Tensor grad_y_, grad_hy_, grad_cy_;
  if (input.scalar_type() != at::ScalarType::Float) {
    grad_y_ = at::empty(
        grad_output.sizes(),
        grad_output.options().dtype(at::ScalarType::Float));
    grad_y_.copy_(grad_output);
    grad_hy_ = at::empty(
        grad_hy.sizes(), grad_hy.options().dtype(at::ScalarType::Float));
    grad_hy_.copy_(grad_hy);
    grad_cy_ = at::empty(
        grad_cy.sizes(), grad_cy.options().dtype(at::ScalarType::Float));
    grad_cy_.copy_(grad_cy);

    diff_y = itensor_view_from_dense(
        grad_y_, rnn.dst_layer_desc(get_mkldnn_dtype(grad_y_.scalar_type())));
    diff_hy = itensor_view_from_dense(
        grad_hy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_hy_.scalar_type())));
    diff_cy = itensor_view_from_dense(
        grad_cy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_cy_.scalar_type())));
  } else {
    diff_y = itensor_view_from_dense(
        grad_output, rnn.dst_layer_desc(ideep::tensor::data_type::f32));
    diff_hy = itensor_view_from_dense(
        grad_hy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
    diff_cy = itensor_view_from_dense(
        grad_cy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
  }

  auto forward_hint = ideep::lstm_forward_training::prepare(x, hx, cx, w1, w2, b, y, hy, cy, reverse);
  ideep::tensor mkldnn_workspace;
  mkldnn_workspace.init(
      forward_hint.workspace_desc(), workspace.template data_ptr<uint8_t>());
  ideep::lstm_backward::compute(
      forward_hint, x, hx, cx, w1, w2, b, y, hy, cy,
      diff_y, diff_hy, diff_cy, mkldnn_workspace,
      diff_x, diff_hx, diff_cx, diff_w1, diff_w2, diff_b, reverse);
  auto diff_b2_ = at::clone(diff_b_);
  return std::make_tuple(diff_x_, diff_w1_, diff_w2_, diff_b_, diff_b2_, diff_hx_, diff_cx_);
}

// MKLDNN RNN integration notes:
// I. Memory Formats
//   a. mkldnn uses plain formats for input, hx/cx, output and hy/cy,
//      and may use blocked formats for weights depending on shape info.
//   b. All mkldnn memories are created (in plain format) as views on ATen tensors;
//      the weight reorder (if any) is handled automatically inside ideep (the mkldnn bridge).
//
// II. MKLDNN Primitive Mapping
//   a. The mkldnn rnn primitive doesn't support training with dropout or padded input sequences.
//   b. A single RNN module is broken into { num_layers * num_directions } mkldnn rnn primitives
//      so that these feature gaps can be covered in the future.
//
// TODO: a. training with dropout
//       b. padded sequence input support
//

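// Illustrative decomposition (assumed configuration): a bidirectional, 2-layer
// LSTM is broken by mkldnn_rnn below into num_layers * num_directions = 4
// single-direction mkldnn_rnn_layer calls, indexed as
// layer * num_directions + direction:
//
//   index 0: layer 0, forward     index 1: layer 0, reverse
//   index 2: layer 1, forward     index 3: layer 1, reverse
//
// The two directions of a layer are concatenated along the channel dimension to
// form the next layer's input, and the per-index hy/cy results are stacked back
// into a {num_layers * num_directions, N, H} tensor at the end.
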
static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
    const Tensor& input_, TensorList weight, int64_t weight_stride0,
    const Tensor& hx_, const Tensor& cx_,
    int64_t mode, int64_t hidden_size,
    int64_t num_layers, bool has_biases, bool batch_first, double dropout_p,
    bool train, bool bidirectional, IntArrayRef batch_sizes) {
  TORCH_CHECK(batch_sizes.size() == 0, "mkldnn_rnn doesn't support packed input");
  if (static_cast<ideep::rnn_kind>(mode) != ideep::rnn_kind::LSTM) {
    TORCH_CHECK(!cx_.defined(), "mkldnn_rnn: illegal defined cx for non-LSTM RNN");
  }

  auto input = input_;
  if (batch_first) {
    input = input.transpose(0, 1);
  }
  input = input.contiguous();

  auto hx = hx_.contiguous();
  auto cx = cx_.contiguous();

  MatrixRef<Tensor> weights{weight, static_cast<size_t>(weight_stride0)};

  auto num_directions = bidirectional ? 2 : 1;
  auto layer_input = input;
  std::vector<at::Tensor> layer_output(num_directions);
  std::vector<at::Tensor> layer_hy(num_layers * num_directions);
  std::vector<at::Tensor> layer_cy(num_layers * num_directions);
  for (const auto layer: c10::irange(num_layers)) {
    for (const auto direction: c10::irange(num_directions)) {
      const auto index = layer * num_directions + direction;
      auto layer_weights = weights[index];
      TORCH_CHECK(layer_weights.size() == 2 || layer_weights.size() == 4);
      auto layer_hx = hx[index];
      auto layer_cx = cx[index];
      auto reverse = (direction > 0);
      // bias won't be packed
      auto outputs = at::mkldnn_rnn_layer(
          layer_input, layer_weights[0], layer_weights[1],
          has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
          has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)),
          layer_hx, layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers,
          has_biases, bidirectional, batch_first, train);
      layer_output[direction] = std::get<0>(outputs);
      layer_hy[index] = std::get<1>(outputs);
      layer_cy[index] = std::get<2>(outputs);
    }
    layer_input = num_directions == 1 ? layer_output[0]
                                      : at::cat(layer_output, /*output_channels*/-1);
    if (dropout_p != 0 && train && layer < num_layers - 1) {
      layer_input = at::dropout(layer_input, dropout_p, /*train=*/true);
    }
  }
  auto output = layer_input;
  auto hy = at::stack(layer_hy, 0);
  auto cy = at::stack(layer_cy, 0);
  if (batch_first) {
    output = output.transpose(0, 1);
  }
  return std::make_tuple(output, hy, cy);
}

////////////////////////////////////////////////////////////////////////////////
//// MKLDNN dispatch for the generic RNN ops (at::lstm, at::gru, ...)
////////////////////////////////////////////////////////////////////////////////

namespace {

// Helpers for working with different hidden types.
std::tuple<Tensor, Tensor> unpack_hidden(const std::tuple<Tensor, Tensor>& hidden) {
  return hidden;
}

template<typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
  static_assert(false && sizeof(hidden_type), "pack_hidden not implemented for this type");
}

template<>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor& hx, const Tensor& cx) {
  return std::make_tuple(hx, cx);
}

template<typename hidden_type>
std::pair<Tensor, hidden_type> mkldnn_impl(
    const Tensor& input, const hidden_type& hidden,
    TensorList params, bool has_biases, ideep::rnn_kind mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

  auto mkldnn_output = mkldnn_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, has_biases, batch_first, dropout_p,
      train, bidirectional, /*batch_sizes*/{});

  return {std::get<0>(mkldnn_output),
          pack_hidden<hidden_type>(std::get<1>(mkldnn_output), std::get<2>(mkldnn_output))};
}

void lstm_mkldnn(Tensor& output, Tensor& hy, Tensor& cy,
    const Tensor& input, TensorList hx, TensorList params, bool has_biases,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto result = mkldnn_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
      ideep::rnn_kind::LSTM, num_layers, dropout_p, train, bidirectional, batch_first);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}
} // anonymous namespace

REGISTER_ALL_CPU_DISPATCH(lstm_mkldnn_stub, &lstm_mkldnn);
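
// Usage sketch (hypothetical shapes; whether this path is taken also depends on
// dispatch checks in ATen's generic RNN code, not shown in this file): a plain
// CPU at::lstm call is what ultimately reaches lstm_mkldnn via lstm_mkldnn_stub.
//
//   auto input = at::randn({T, N, I});            // seq-first fp32 input
//   auto h0 = at::zeros({num_layers, N, H});
//   auto c0 = at::zeros({num_layers, N, H});
//   // params: {w_ih, w_hh, b_ih, b_hh} per layer and direction
//   auto out = at::lstm(input, {h0, c0}, params,
//                       /*has_biases=*/true, num_layers, /*dropout=*/0.0,
//                       /*train=*/false, /*bidirectional=*/false,
//                       /*batch_first=*/false);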

} // namespace at::native

#endif // AT_MKLDNN_ENABLED