#include <ATen/native/RNN.h>
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>

#include <ATen/TensorUtils.h>
#include <ATen/Dispatch.h>
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/mkldnn_convolution_native.h>
#include <ATen/ops/mkldnn_rnn_layer_backward_native.h>
#include <ATen/ops/mkldnn_rnn_layer_native.h>
#endif

#if !AT_MKLDNN_ENABLED()

namespace at::native {

std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(
    const Tensor& input,
    const Tensor& w0,
    const Tensor& w1,
    const Tensor& w2,
    const Tensor& w3,
    const Tensor& hx_,
    const Tensor& cx_,
    bool reverse,
    IntArrayRef batch_sizes,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool bidirectional,
    bool batch_first,
    bool train) {
  AT_ERROR("mkldnn_rnn_layer: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer_backward(
    const Tensor& input,
    const Tensor& weight0,
    const Tensor& weight1,
    const Tensor& weight2,
    const Tensor& weight3,
    const Tensor& hx_,
    const Tensor& cx_tmp,
    const Tensor& output,
    const Tensor& hy_,
    const Tensor& cy_,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    bool reverse,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool train,
    bool bidirectional,
    at::IntArrayRef batch_sizes,
    bool batch_first,
    const at::Tensor& workspace) {
  AT_ERROR("mkldnn_rnn_layer_backward: ATen not compiled with MKLDNN support");
}

REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub);

} // namespace at::native

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at::native {

struct RNNParams {
  ideep::rnn_kind mode;
  int64_t seq_length;
  int64_t mini_batch;
  int64_t input_size;
  int64_t hidden_size;
  int64_t num_directions;
  int64_t num_layers;
  bool batch_first;
  bool train;
  at::IntArrayRef batch_sizes;
  int64_t num_gates;
  int64_t num_bias_gates;
  RNNParams(
      const at::Tensor& input,
      at::IntArrayRef batch_sizes_,
      int64_t mode_,
      int64_t hidden_size_,
      int64_t num_layers_,
      bool bidirectional,
      bool batch_first_,
      bool train_) {
    mode = static_cast<ideep::rnn_kind>(mode_);
    batch_first = batch_first_;
    seq_length = input.size(0);
    mini_batch = input.size(1);
    input_size = input.size(2);
    hidden_size = hidden_size_;
    num_directions = bidirectional ? 2 : 1;
    num_layers = num_layers_;
    train = train_;
    batch_sizes = batch_sizes_;
    if (mode == ideep::rnn_kind::LSTM) {
      num_gates = 4;
      num_bias_gates = 4;
    } else if (mode == ideep::rnn_kind::GRU) {
      num_gates = 3;
      num_bias_gates = 4;
    } else {
      // RNN_RELU; RNN_TANH
      num_gates = 1;
      num_bias_gates = 1;
    }
  }

  // mkldnn memory descriptors
  using format = ideep::format_tag;
  using desc = ideep::tensor::desc;
  using dtype = ideep::tensor::data_type;
  desc src_layer_desc(int64_t _input_size, dtype dtype) const {
    return {{seq_length, mini_batch, _input_size}, dtype, format::tnc};
  }
  desc src_iter_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  desc src_iter_c_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  // weight logical sizes are given in ldigo order; the format tag selects the physical layout
  desc weights_layer_desc(int64_t _input_size, dtype dtype) const {
    return {{1, 1, _input_size, num_gates, hidden_size}, dtype, format::ldgoi};
  }
  desc weights_layer_ldigo_desc(int64_t _input_size, dtype dtype) const {
    return {{1, 1, _input_size, num_gates, hidden_size}, dtype, format::ldigo};
  }
  desc weights_iter_desc(dtype dtype) const {
    return {{1, 1, hidden_size, num_gates, hidden_size}, dtype, format::ldgoi};
  }
  desc weights_iter_ldigo_desc(dtype dtype) const {
    return {{1, 1, hidden_size, num_gates, hidden_size}, dtype, format::ldigo};
  }
  desc bias_desc(dtype dtype) const {
    return {{1, 1, num_bias_gates, hidden_size}, dtype, format::ldgo};
  }
  desc dst_layer_desc(dtype dtype) const {
    return {{seq_length, mini_batch, hidden_size}, dtype, format::tnc};
  }
  desc dst_iter_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
  desc dst_iter_c_desc(dtype dtype) const {
    return {{1, 1, mini_batch, hidden_size}, dtype, format::ldnc};
  }
};
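
// Example (a sketch for orientation, not exercised by this file): for a
// single-direction LSTM layer with seq_length = T, mini_batch = N,
// input_size = I and hidden_size = H, the descriptors above describe memory
// with the following logical sizes:
//   src_layer:              {T, N, I}       (tnc)
//   src_iter / src_iter_c:  {1, 1, N, H}    (ldnc)
//   weights_layer:          {1, 1, I, 4, H}
//   weights_iter:           {1, 1, H, 4, H}
//   bias:                   {1, 1, 4, H}    (ldgo)
//   dst_layer:              {T, N, H}       (tnc)
//   dst_iter / dst_iter_c:  {1, 1, N, H}    (ldnc)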

template<bool is_single_direction>
std::vector<int64_t> _output_size(const RNNParams& rnn) {
  auto output_channels = is_single_direction ? rnn.hidden_size
                                             : rnn.hidden_size * rnn.num_directions;
  return {rnn.seq_length, rnn.mini_batch, output_channels};
}

// The MKLDNN GRU gate order differs from PyTorch's, which requires a gate shuffle
// (let rt, zt, nt be the reset, update and new gates respectively)
//
//   MKLDNN GRU weight_ih/weight_hh gates order: (zt, rt, nt)
//   PyTorch GRU weight_ih/weight_hh gates order: (rt, zt, nt)
//
// MKLDNN GRU bias has 4 gates instead of 3
//  (PyTorch GRU bias)     (MKLDNN GRU bias)
//
//  bias_ih    bias_hh          bias
//  +-----+    +-----+       +---------+
//  | rt1 |    | rt2 |       | zt1+zt2 |
//  |-----|    |-----|       |---------|
//  | zt1 |    | zt2 |       | rt1+rt2 |
//  |-----|    |-----|       |---------|
//  | nt1 |    | nt2 |       |   nt1   |
//  +-----+    +-----+       |---------|
//                           |   nt2   |
//                           +---------+
//
static Tensor _shuffle_weight(const Tensor& weight, int64_t fn_mode) {
  auto weight_t = weight.contiguous();
  if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
    std::vector<Tensor> gates = weight_t.chunk(3, /*gates*/0);
    return at::cat({gates[1], gates[0], gates[2]}, /*gates*/0);
  }
  return weight_t;
}

static Tensor _shuffle_bias(const Tensor& bias_ih, const Tensor& bias_hh, int64_t fn_mode) {
  if (static_cast<ideep::rnn_kind>(fn_mode) == ideep::rnn_kind::GRU) {
    std::vector<Tensor> b1 = bias_ih.chunk(3, /*output_channels*/0);
    std::vector<Tensor> b2 = bias_hh.chunk(3, /*output_channels*/0);
    return at::cat({b1[1] + b2[1], b1[0] + b2[0], b1[2], b2[2]}, /*output_channels*/0);
  }
  return bias_ih + bias_hh;
}
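
// For example (illustrative only): a PyTorch GRU weight_ih of shape {3*H, I}
// laid out as [rt; zt; nt] comes back from _shuffle_weight as [zt; rt; nt],
// and _shuffle_bias maps (bias_ih, bias_hh) of shape {3*H} each to the 4-gate
// MKLDNN bias [zt1+zt2; rt1+rt2; nt1; nt2] of shape {4*H}, matching the
// diagram above. For non-GRU modes the weights pass through unchanged and the
// two bias tensors are simply summed.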

std::tuple<Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer(const Tensor& input,
    const Tensor& w0,
    const Tensor& w1,
    const Tensor& w2,
    const Tensor& w3,
    const Tensor& hx_,
    const Tensor& cx_,
    bool reverse,
    IntArrayRef batch_sizes,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool bidirectional,
    bool batch_first,
    bool train) {
  RNNParams rnn(
      input,
      batch_sizes,
      mode,
      hidden_size,
      num_layers,
      bidirectional,
      batch_first,
      train);

  auto output_size = _output_size</*is_single_direction*/ true>(rnn);
  auto output = at::empty(output_size, input.options());

  auto hy_ = at::empty(hx_.sizes(), hx_.options());
  auto cy_ = at::empty(cx_.sizes(), cx_.options());

  auto weight_ih = _shuffle_weight(w0, rnn.mode);
  auto weight_hh = _shuffle_weight(w1, rnn.mode);

  // Packed weights will already be in MKLDNN layout; the bias is never packed
  auto bias = has_biases
      ? _shuffle_bias(w2, w3, rnn.mode)
      : at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options().layout(at::Layout::Strided));

  // per-layer input size
  int64_t input_size = input.size(2);
  ideep::tensor w1_, w2_;
  auto x = itensor_view_from_dense(
      input,
      rnn.src_layer_desc(input_size, get_mkldnn_dtype(input)));
  auto hx = itensor_view_from_dense(
      hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_)));
  auto cx = itensor_view_from_dense(
      cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_)));
  auto b = itensor_view_from_dense(
      bias, rnn.bias_desc(get_mkldnn_dtype(bias)));
  auto y = itensor_view_from_dense(
      output, rnn.dst_layer_desc(get_mkldnn_dtype(output)));
  auto hy = itensor_view_from_dense(
      hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_)));
  auto cy = itensor_view_from_dense(
      cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_)));
  w1_ = weight_ih.is_mkldnn()
      ? itensor_from_tensor(weight_ih)
      : itensor_view_from_dense(weight_ih, rnn.weights_layer_desc(input_size, get_mkldnn_dtype(weight_ih)));
  w2_ = weight_hh.is_mkldnn()
      ? itensor_from_tensor(weight_hh)
      : itensor_view_from_dense(weight_hh, rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh)));
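
  // Training path (autograd enabled): query the forward-training primitive
  // descriptor up front so a correctly sized workspace can be allocated; the
  // workspace is returned so mkldnn_rnn_layer_backward can reuse it.
  // Inference path: no workspace is produced.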
  if (at::GradMode::is_enabled()) {
    Tensor workspace = Tensor();
    auto pd = ideep::lstm_forward_training::prepare(
        x, hx, cx, w1_, w2_, b, y, hy, cy, reverse);
    workspace = at::empty(pd.workspace_desc().get_size() / sizeof(uint8_t), input.options().dtype(at::kByte));
    ideep::tensor mkldnn_workspace;
    mkldnn_workspace.init(
        pd.workspace_desc(), workspace.template data_ptr<uint8_t>());
    ideep::lstm_forward_training::compute(
        pd, x, hx, cx, w1_, w2_, b, mkldnn_workspace, y, hy, cy, reverse, ideep::prop_kind::forward_training);
    return std::make_tuple(output, hy_, cy_, workspace);
  } else {
    ideep::lstm_forward_inference::compute(
        x, hx, cx, w1_, w2_, b, y, hy, cy, reverse, ideep::prop_kind::forward_inference);
    return std::make_tuple(output, hy_, cy_, Tensor());
  }
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> mkldnn_rnn_layer_backward(
    const Tensor& input,
    const Tensor& weight0,
    const Tensor& weight1,
    const Tensor& weight2,
    const Tensor& weight3,
    const Tensor& hx_,
    const Tensor& cx_tmp,
    const Tensor& output,
    const Tensor& hy_,
    const Tensor& cy_,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    bool reverse,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool train,
    bool bidirectional,
    at::IntArrayRef batch_sizes,
    bool batch_first,
    const at::Tensor& workspace) {
  const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
  const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
  const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
  if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) {
    return std::make_tuple(Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor());
  }
  auto grad_output = grad_output_r.defined()
      ? grad_output_r.contiguous()
      : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_hy = grad_hy_r.defined()
      ? grad_hy_r.contiguous()
      : at::zeros_like(hx_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cy = cx_tmp.defined()
      ? (grad_cy_r.defined() ? grad_cy_r.contiguous() : at::zeros_like(cx_tmp, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
      : grad_cy_r.contiguous();
  RNNParams rnn(
      input,
      batch_sizes,
      mode,
      hidden_size,
      num_layers,
      bidirectional,
      batch_first,
      train);
  auto output_size = _output_size</*is_single_direction*/ true>(rnn);

  auto weight_ih = _shuffle_weight(weight0, rnn.mode);
  auto weight_hh = _shuffle_weight(weight1, rnn.mode);
  auto bias = has_biases
      ? _shuffle_bias(weight2, weight3, rnn.mode)
      : at::zeros({rnn.num_bias_gates * rnn.hidden_size}, weight_ih.options());

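  // If hx_ and cx_tmp happen to share storage, work on a clone of the cell
  // state so that the hx and cx MKLDNN views created below refer to distinct
  // memory rather than aliasing each other.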
  auto cx_ = hx_.storage().unsafeGetStorageImpl() == cx_tmp.storage().unsafeGetStorageImpl()
      ? at::clone(cx_tmp)
      : cx_tmp;

  // per-layer input size
  int64_t input_size = input.size(2);
  auto x = itensor_view_from_dense(
      input,
      rnn.src_layer_desc(input_size, get_mkldnn_dtype(input.scalar_type())));
  auto hx = itensor_view_from_dense(
      hx_, rnn.src_iter_desc(get_mkldnn_dtype(hx_.scalar_type())));
  auto cx = itensor_view_from_dense(
      cx_, rnn.src_iter_c_desc(get_mkldnn_dtype(cx_.scalar_type())));
  auto w1 = itensor_view_from_dense(
      weight_ih,
      rnn.weights_layer_desc(
          input_size, get_mkldnn_dtype(weight_ih.scalar_type())));
  auto w2 = itensor_view_from_dense(
      weight_hh,
      rnn.weights_iter_desc(get_mkldnn_dtype(weight_hh.scalar_type())));
  auto b = itensor_view_from_dense(
      bias, rnn.bias_desc(get_mkldnn_dtype(bias.scalar_type())));
  auto y = itensor_view_from_dense(
      output, rnn.dst_layer_desc(get_mkldnn_dtype(output.scalar_type())));
  auto hy = itensor_view_from_dense(
      hy_, rnn.dst_iter_desc(get_mkldnn_dtype(hy_.scalar_type())));
  auto cy = itensor_view_from_dense(
      cy_, rnn.dst_iter_c_desc(get_mkldnn_dtype(cy_.scalar_type())));

  // Create diff_* ATen tensors and the corresponding ideep tensors as fp32
  auto diff_x_ =
      at::empty(input.sizes(), input.options().dtype(at::ScalarType::Float));
  auto diff_hx_ =
      at::empty(hx_.sizes(), hx_.options().dtype(at::ScalarType::Float));
  auto diff_cx_ =
      at::empty(cx_.sizes(), cx_.options().dtype(at::ScalarType::Float));
  auto diff_w1_ = at::empty(
      weight_ih.sizes(), weight_ih.options().dtype(at::ScalarType::Float));
  auto diff_w2_ = at::empty(
      weight_hh.sizes(), weight_hh.options().dtype(at::ScalarType::Float));
  auto diff_b_ =
      at::empty(bias.sizes(), bias.options().dtype(at::ScalarType::Float));

  auto diff_x = itensor_view_from_dense(
      diff_x_, rnn.src_layer_desc(input_size, ideep::tensor::data_type::f32));
  auto diff_hx = itensor_view_from_dense(
      diff_hx_, rnn.src_iter_desc(ideep::tensor::data_type::f32));
  auto diff_cx = itensor_view_from_dense(
      diff_cx_, rnn.src_iter_c_desc(ideep::tensor::data_type::f32));
  auto diff_w1 = itensor_view_from_dense(
      diff_w1_,
      rnn.weights_layer_desc(input_size, ideep::tensor::data_type::f32));
  auto diff_w2 = itensor_view_from_dense(
      diff_w2_, rnn.weights_iter_desc(ideep::tensor::data_type::f32));
  auto diff_b = itensor_view_from_dense(
      diff_b_, rnn.bias_desc(ideep::tensor::data_type::f32));

  // Convert grad_y, grad_hy, grad_cy to fp32 for the non-fp32 backward
  ideep::tensor diff_y, diff_hy, diff_cy;
  at::Tensor grad_y_, grad_hy_, grad_cy_;
  if (input.scalar_type() != at::ScalarType::Float) {
    grad_y_ = at::empty(
        grad_output.sizes(),
        grad_output.options().dtype(at::ScalarType::Float));
    grad_y_.copy_(grad_output);
    grad_hy_ = at::empty(
        grad_hy.sizes(), grad_hy.options().dtype(at::ScalarType::Float));
    grad_hy_.copy_(grad_hy);
    grad_cy_ = at::empty(
        grad_cy.sizes(), grad_cy.options().dtype(at::ScalarType::Float));
    grad_cy_.copy_(grad_cy);

    diff_y = itensor_view_from_dense(
        grad_y_, rnn.dst_layer_desc(get_mkldnn_dtype(grad_y_.scalar_type())));
    diff_hy = itensor_view_from_dense(
        grad_hy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_hy_.scalar_type())));
    diff_cy = itensor_view_from_dense(
        grad_cy_, rnn.dst_iter_desc(get_mkldnn_dtype(grad_cy_.scalar_type())));
  } else {
    diff_y = itensor_view_from_dense(
        grad_output, rnn.dst_layer_desc(ideep::tensor::data_type::f32));
    diff_hy = itensor_view_from_dense(
        grad_hy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
    diff_cy = itensor_view_from_dense(
        grad_cy, rnn.dst_iter_desc(ideep::tensor::data_type::f32));
  }

  auto forward_hint = ideep::lstm_forward_training::prepare(x, hx, cx, w1, w2, b, y, hy, cy, reverse);
  ideep::tensor mkldnn_workspace;
  mkldnn_workspace.init(
      forward_hint.workspace_desc(), workspace.template data_ptr<uint8_t>());
  ideep::lstm_backward::compute(forward_hint, x, hx, cx, w1, w2, b, y, hy, cy, diff_y, diff_hy, diff_cy, mkldnn_workspace, diff_x, diff_hx, diff_cx, diff_w1, diff_w2, diff_b, reverse);
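  // oneDNN produces a single fused bias gradient. PyTorch expects separate
  // grads for bias_ih and bias_hh; since this path runs the LSTM primitive,
  // where the fused bias is bias_ih + bias_hh, both per-tensor bias grads are
  // equal to the fused gradient, so the same values are returned twice.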
  auto diff_b2_ = at::clone(diff_b_);
  return std::make_tuple(diff_x_, diff_w1_, diff_w2_, diff_b_, diff_b2_, diff_hx_, diff_cx_);
}

// MKLDNN RNN integration notes:
// I. Memory Formats
//   a. mkldnn uses plain formats for input, hx/cx, output, hy/cy,
//      and possibly blocked formats for weights, depending on shape info.
//   b. All mkldnn memories are created (in plain format) as views on ATen tensors;
//      any weight reorder is handled automatically inside ideep (the mkldnn bridge)
//
// II. MKLDNN Primitive Mapping
//   a. the mkldnn rnn primitive doesn't support training with dropout or padded input sequences.
//   b. we break a single RNN module into { num_layers * num_directions } mkldnn rnn primitives
//      so that these feature gaps can be covered in the future.
//
// TODO: a. training with dropout
//       b. padded sequence input support
//
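// The decomposition described in II.b is implemented by mkldnn_rnn below: it
// loops over layers and directions and issues one at::mkldnn_rnn_layer call
// per (layer, direction) pair, concatenating the two directional outputs
// along the channel dimension for bidirectional runs.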

static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
    const Tensor& input_, TensorList weight, int64_t weight_stride0,
    const Tensor& hx_, const Tensor& cx_,
    int64_t mode, int64_t hidden_size,
    int64_t num_layers, bool has_biases, bool batch_first, double dropout_p,
    bool train, bool bidirectional, IntArrayRef batch_sizes) {
  TORCH_CHECK(batch_sizes.size() == 0, "mkldnn_rnn doesn't support packed input");
  if (static_cast<ideep::rnn_kind>(mode) != ideep::rnn_kind::LSTM) {
    TORCH_CHECK(!cx_.defined(), "mkldnn_rnn: illegal defined cx for non-LSTM RNN");
  }

  auto input = input_;
  if (batch_first) {
    input = input.transpose(0, 1);
  }
  input = input.contiguous();

  auto hx = hx_.contiguous();
  auto cx = cx_.contiguous();

  MatrixRef<Tensor> weights{weight, static_cast<size_t>(weight_stride0)};

  auto num_directions = bidirectional ? 2 : 1;
  auto layer_input = input;
  std::vector<at::Tensor> layer_output(num_directions);
  std::vector<at::Tensor> layer_hy(num_layers * num_directions);
  std::vector<at::Tensor> layer_cy(num_layers * num_directions);
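  // `weights` views the flat weight list as a [num_layers * num_directions,
  // weight_stride0] matrix; hidden/cell states and the hy/cy outputs are
  // indexed the same way via index = layer * num_directions + direction.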
  for (const auto layer : c10::irange(num_layers)) {
    for (const auto direction : c10::irange(num_directions)) {
      const auto index = layer * num_directions + direction;
      auto layer_weights = weights[index];
      TORCH_CHECK(layer_weights.size() == 2 || layer_weights.size() == 4);
      auto layer_hx = hx[index];
      auto layer_cx = cx[index];
      auto reverse = (direction > 0);
      // bias won't be packed
      auto outputs = at::mkldnn_rnn_layer(layer_input, layer_weights[0], layer_weights[1],
          has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
          has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)),
          layer_hx, layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
      layer_output[direction] = std::get<0>(outputs);
      layer_hy[index] = std::get<1>(outputs);
      layer_cy[index] = std::get<2>(outputs);
    }
    layer_input = num_directions == 1 ? layer_output[0]
                                      : at::cat(layer_output, /*output_channels*/-1);
    if (dropout_p != 0 && train && layer < num_layers - 1) {
      layer_input = at::dropout(layer_input, dropout_p, /*train=*/true);
    }
  }
  auto output = layer_input;
  auto hy = at::stack(layer_hy, 0);
  auto cy = at::stack(layer_cy, 0);
  if (batch_first) {
    output = output.transpose(0, 1);
  }
  return std::make_tuple(output, hy, cy);
}

////////////////////////////////////////////////////////////////////////////////
//// MKLDNN dispatch for the generic RNN ops (at::lstm, at::gru, ...)
////////////////////////////////////////////////////////////////////////////////

namespace {

// Helpers for working with different hidden types.
std::tuple<Tensor, Tensor> unpack_hidden(const std::tuple<Tensor, Tensor>& hidden) {
  return hidden;
}

template<typename hidden_type>
hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) {
  static_assert(false && sizeof(hidden_type), "pack_hidden not implemented for this type");
}

template<>
std::tuple<Tensor, Tensor> pack_hidden<std::tuple<Tensor, Tensor>>(const Tensor& hx, const Tensor& cx) {
  return std::make_tuple(hx, cx);
}

template<typename hidden_type>
std::pair<Tensor, hidden_type> mkldnn_impl(
    const Tensor& input, const hidden_type& hidden,
    TensorList params, bool has_biases, ideep::rnn_kind mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);

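  // weight_stride0 below is the number of weight tensors per (layer, direction):
  // {w_ih, w_hh, b_ih, b_hh} with biases, {w_ih, w_hh} without.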
  auto mkldnn_output = mkldnn_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, has_biases, batch_first, dropout_p,
      train, bidirectional, /*batch_sizes*/{});

  return {std::get<0>(mkldnn_output),
          pack_hidden<hidden_type>(std::get<1>(mkldnn_output), std::get<2>(mkldnn_output))};
}

void lstm_mkldnn(Tensor& output, Tensor& hy, Tensor& cy,
    const Tensor& input, TensorList hx, TensorList params, bool has_biases,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  auto result = mkldnn_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
      ideep::rnn_kind::LSTM, num_layers, dropout_p, train, bidirectional, batch_first);
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);
}
} // anonymous namespace

REGISTER_ALL_CPU_DISPATCH(lstm_mkldnn_stub, &lstm_mkldnn);
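
// A minimal usage sketch (illustrative only; the MKLDNN eligibility check
// lives in the generic RNN dispatch, and T/N/I/H/L are placeholder sizes).
// For an eligible CPU LSTM, at::lstm reaches this file via lstm_mkldnn_stub:
//
//   auto input  = at::randn({T, N, I});                      // seq-major fp32
//   auto h0     = at::zeros({L, N, H});
//   auto c0     = at::zeros({L, N, H});
//   // params: flat list of {w_ih, w_hh, b_ih, b_hh} per layer/direction
//   auto result = at::lstm(input, {h0, c0}, params,
//                          /*has_biases=*/true, /*num_layers=*/L,
//                          /*dropout=*/0.0, /*train=*/false,
//                          /*bidirectional=*/false, /*batch_first=*/false);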

} // namespace at::native

#endif // AT_MKLDNN_ENABLED