#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at { namespace native {

#if AT_MKLDNN_ENABLED()

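// Convert an opaque MKLDNN (oneDNN) tensor back into a plain, strided CPU
// tensor, optionally casting to `dtype`. This is the native kernel behind
// to_dense() for tensors with the MkldnnCPU layout.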
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {
  TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float ||
              mkldnn_tensor.scalar_type() == ScalarType::BFloat16 ||
              mkldnn_tensor.scalar_type() == ScalarType::Half ||
              mkldnn_tensor.scalar_type() == ScalarType::Byte ||
              mkldnn_tensor.scalar_type() == ScalarType::Char,
              "mkldnn_to_dense expects float, bfloat16, half, uint8, or int8 tensor input");
  ideep::tensor& stensor = itensor_from_mkldnn(mkldnn_tensor);
  auto dims = stensor.get_dims();
  auto data_type = dtype.has_value() ? dtype.value() : mkldnn_tensor.scalar_type();
  TORCH_CHECK(data_type == ScalarType::Float ||
              data_type == ScalarType::BFloat16 ||
              data_type == ScalarType::Half ||
              data_type == ScalarType::Byte ||
              data_type == ScalarType::Char,
              "mkldnn tensor can only be converted to a float, bfloat16, half, uint8, or int8 cpu tensor");
  if (mkldnn_tensor.scalar_type() == ScalarType::Byte || mkldnn_tensor.scalar_type() == ScalarType::Char) {
    // For int8/uint8 input, the data type must not change.
    TORCH_CHECK(mkldnn_tensor.scalar_type() == data_type,
            "For int8/uint8 mkldnn_tensor input, the data type must not change.");
  }
  // NOTE: ideep::tensor dims are int32_t, but at::empty needs int64_t sizes.
  Tensor cpu_tensor = at::empty(
    std::vector<int64_t>(dims.begin(), dims.end()),
    mkldnn_tensor.options().layout(c10::kStrided).dtype(data_type));
  if (stensor.is_empty()) return cpu_tensor;
  auto pub_tensor =
      data_type == ScalarType::Float
      ? stensor.to_public(cpu_tensor.template data_ptr<float>(),
                          ideep::tensor::data_type::f32)
      : (data_type == ScalarType::BFloat16
         ? stensor.to_public(cpu_tensor.template data_ptr<BFloat16>(),
                         ideep::tensor::data_type::bf16)
         : (data_type == ScalarType::Half
            ? stensor.to_public(cpu_tensor.template data_ptr<Half>(),
                            ideep::tensor::data_type::f16)
          : (data_type == ScalarType::Byte
              ? stensor.to_public(cpu_tensor.template data_ptr<uint8_t>(),
                              ideep::tensor::data_type::u8)
              : stensor.to_public(cpu_tensor.template data_ptr<int8_t>(),
                              ideep::tensor::data_type::s8)
            )
           )
      );
  cpu_tensor.as_strided_(dims, pub_tensor.get_strides());
  // Make sure NC11 strides follow the formula of a contiguous tensor.
  return cpu_tensor.contiguous().resize_(dims, c10::MemoryFormat::Contiguous);
}

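// Convert a plain, strided CPU tensor into an opaque MKLDNN (oneDNN) tensor in
// a public (plain) format; weight-specific blocked formats are produced by the
// reorder helpers below. This backs Tensor::to_mkldnn(). A minimal round-trip
// sketch, with a hypothetical contiguous float CPU tensor `t`:
//   auto opaque = at::native::dense_to_mkldnn(t, /*dtype=*/std::nullopt);
//   auto back   = at::native::mkldnn_to_dense(opaque, std::nullopt, std::nullopt);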
Tensor dense_to_mkldnn(const Tensor& cpu_tensor, std::optional<ScalarType> dtype) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(),
             "dense_to_mkldnn expects CPU tensor input");
  TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
             "dense_to_mkldnn expects strided tensor input");
  TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float ||
              cpu_tensor.scalar_type() == ScalarType::BFloat16 ||
              cpu_tensor.scalar_type() == ScalarType::Half ||
              cpu_tensor.scalar_type() == ScalarType::Byte ||
              cpu_tensor.scalar_type() == ScalarType::Char,
             "dense_to_mkldnn expects float, bfloat16, half, uint8, or int8 tensor input");
  TORCH_CHECK(cpu_tensor.dim() <= 5,
             "Can't convert a cpu tensor with more than 5 dimensions");
  // NOTE: forbid direct conversion from a non-contiguous (or channels-last) tensor to `ideep::tensor`.
  auto cpu_tensor_cont = cpu_tensor.contiguous();
  auto data_type = dtype.has_value() ? dtype.value() : cpu_tensor.scalar_type();
  if (cpu_tensor.scalar_type() == ScalarType::Byte || cpu_tensor.scalar_type() == ScalarType::Char) {
    // For int8/uint8 input, the data type must not change.
    TORCH_CHECK(cpu_tensor.scalar_type() == data_type,
            "For int8/uint8 cpu_tensor input, the data type must not change.");
  }
  TORCH_CHECK(data_type == ScalarType::Float ||
              data_type == ScalarType::BFloat16 ||
              data_type == ScalarType::Half ||
              data_type == ScalarType::Byte ||
              data_type == ScalarType::Char,
              "cpu tensor can only be converted to a float, bfloat16, half, uint8, or int8 mkldnn tensor");
  Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), data_type,
                                      cpu_tensor_cont.options().layout_opt(), cpu_tensor_cont.options().device_opt(),
                                      cpu_tensor_cont.options().pinned_memory_opt());
  ideep::tensor& dtensor = itensor_from_mkldnn(mkldnn_tensor);
  if (cpu_tensor.scalar_type() == ScalarType::Float) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::f32,
                      (cpu_tensor_cont.template data_ptr<float>()));
  } else if (cpu_tensor.scalar_type() == ScalarType::BFloat16) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::bf16,
                      cpu_tensor_cont.template data_ptr<BFloat16>());
  } else if (cpu_tensor.scalar_type() == ScalarType::Half) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::f16,
                      cpu_tensor_cont.template data_ptr<Half>());
  } else if (cpu_tensor.scalar_type() == ScalarType::Byte) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::u8,
                      cpu_tensor_cont.template data_ptr<uint8_t>());
  } else {
    TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Char,
            "Expected int8 input for cpu_tensor");
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::s8,
                      cpu_tensor_cont.template data_ptr<int8_t>());
  }
  return mkldnn_tensor;
}

// MKLDNN keeps conv2d weights in a special non-public (blocked) format;
// dense_to_mkldnn only converts a dense tensor to an mkldnn tensor in the
// public format. The ideep conv kernel implicitly reorders the weight if it is
// not already in the optimized format, and as of this writing that on-the-fly
// reorder costs roughly 20% in performance.
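// A minimal C++-level prepack sketch (hypothetical `weight`: a contiguous fp32
// CPU conv weight; the conv parameters below are illustrative only):
//   auto packed = at::native::mkldnn_reorder_conv2d_weight(
//       weight, /*padding=*/{0, 0}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
//       /*groups=*/1, /*input_size=*/std::nullopt);
// Doing this once at module-prepack time avoids the per-call reorder noted above.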
Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_conv2d_weight");
  const auto padding_expanded = expand_param_if_needed(padding, "padding", 2);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2);

  ideep::dims src_dims = ideep::dims();
  bool is_channels_last = false;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (input_size.has_value()) {
    src_dims = input_size.value().vec();
    // If an input size is given, always use channels last.
    is_channels_last = true;
    memory_format = at::MemoryFormat::ChannelsLast;
  }

  auto self_ = self.is_mkldnn() ? self : self.contiguous(memory_format);
  auto w = itensor_from_tensor(self_);

  // A legacy mkldnn conv2d jitted module may contain a 5-d weight with an extra
  // dimension when groups > 1, i.e. [g, o/g, i, h, w] instead of [o, i, h, w].
  // Ideally we would reorder the weight back during serialization; for backward
  // compatibility, squash the first two dims (g * o/g) back to the original form.
  if (w.ndims() == 5) {
    auto wdims = w.get_dims();
    w.reshape({wdims[0] * wdims[1], wdims[2], wdims[3], wdims[4]});
  }

  auto desc = ideep::convolution_forward::expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_expanded,
      dilation_expanded,
      groups,
      ideep::algorithm::convolution_direct,
      ideep::prop_kind::forward,
      w.get_data_type(),
      src_dims,
      ideep::attr_t(),
      is_channels_last);
  ideep::tensor result;
  result.init(desc);
  result.feed_from(w);

  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}

Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_conv3d_weight");
  const auto padding_expanded = expand_param_if_needed(padding, "padding", 3);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", 3);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 3);

  ideep::dims src_dims = ideep::dims();
  bool is_channels_last = false;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (input_size.has_value()) {
    src_dims = input_size.value().vec();
    // If an input size is given, always use channels last.
    is_channels_last = true;
    memory_format = at::MemoryFormat::ChannelsLast3d;
  }

  auto self_ = self.is_mkldnn() ? self : self.contiguous(memory_format);
  auto w = itensor_from_tensor(self_);

  auto desc = ideep::convolution_forward::expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_expanded,
      dilation_expanded,
      groups,
      ideep::algorithm::convolution_direct,
      ideep::prop_kind::forward,
      w.get_data_type(),
      src_dims,
      ideep::attr_t(),
      is_channels_last);
  ideep::tensor result;
  result.init(desc);
  result.feed_from(w);

  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
}

static Tensor mkldnn_reorder_conv_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK((self.dim() == 4 || self.dim() == 5), "mkldnn_reorder_conv_weight only supports conv2d and conv3d");
  if (self.dim() == 4) {
    return at::native::mkldnn_reorder_conv2d_weight(self, padding, stride, dilation, groups, input_size);
  } else {
    return at::native::mkldnn_reorder_conv3d_weight(self, padding, stride, dilation, groups, input_size);
  }
}

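// Prepack a 2-D linear weight into the blocked format oneDNN expects for
// inner_product. The optional batch size is a hint: if given, the format query
// below assumes a src of shape [batch_size, in_features]; otherwise the query
// is made without a concrete src shape.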
static Tensor mkldnn_reorder_linear_weight(
    const Tensor& self,
    std::optional<int64_t> batch_size_opt) {
  mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_linear_weight");
  auto out_features = self.size(0);
  auto in_features = self.size(1);
  auto self_ = self.contiguous();
  auto w = itensor_from_tensor(self_);
  ideep::dims input_size;
  auto dtype = w.get_data_type();
  if (batch_size_opt.has_value()) {
    input_size = {batch_size_opt.value(), in_features};
  }
  auto packed_desc = ideep::inner_product_forward::expected_weights_desc(
      {out_features, in_features},
      input_size,
      /* weight dtype */ dtype,
      /* src dtype */ dtype);
  ideep::tensor result;
  result.init(packed_desc);
  result.feed_from(w);
  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
}

static ideep::tensor::desc get_conv_transpose_expected_weights_desc(
    const ideep::tensor::dims& weights_dims,
    ideep::tensor::data_type w_dtype,
    const ideep::tensor::dims& strides,
    const ideep::tensor::dims& padding_l,
    const ideep::tensor::dims& padding_r,
    const ideep::tensor::dims& dilates,
    int groups,
    bool channels_last,
    ideep::algorithm aalgorithm,
    ideep::data_type x_dtype,
    const ideep::dims& src_dims) {
  if (channels_last) {
    return ideep::convolution_transpose_forward::expected_weights_desc<true>(
        weights_dims,
        w_dtype,
        strides,
        padding_l,
        padding_r,
        dilates,
        groups,
        aalgorithm,
        ideep::prop_kind::forward,
        src_dims);
  } else {
    return ideep::convolution_transpose_forward::expected_weights_desc<false>(
        weights_dims,
        w_dtype,
        strides,
        padding_l,
        padding_r,
        dilates,
        groups,
        aalgorithm,
        ideep::prop_kind::forward,
        src_dims);
  }
}

static Tensor mkldnn_reorder_conv_transpose_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(
      (self.dim() == 4 || self.dim() == 5),
      "mkldnn_reorder_conv_transpose_weight only supports conv_transpose2d and conv_transpose3d");
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  mkldnn_check_low_precision(
      self.scalar_type(), "mkldnn_reorder_conv_transpose_weight");
  int64_t pdim = self.dim() - 2;
  const auto padding_expanded =
      expand_param_if_needed(padding, "padding", pdim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", pdim);
  const auto dilation_expanded =
      expand_param_if_needed(dilation, "dilation", pdim);
  const auto output_padding_expanded =
      expand_param_if_needed(output_padding, "output_padding", pdim);

  ideep::dims src_dims = ideep::dims();
  bool is_channels_last = false;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (input_size.has_value()) {
    src_dims = input_size.value().vec();
    // If an input size is given, always use channels last.
    is_channels_last = true;
    memory_format = self.dim() == 4 ? at::MemoryFormat::ChannelsLast
                                    : at::MemoryFormat::ChannelsLast3d;
  }

  auto self_ = self.contiguous(memory_format);
  ideep::tensor w = itensor_from_tensor(self_);

  auto expected_desc = get_conv_transpose_expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_r(padding_expanded, output_padding_expanded),
      dilation_expanded,
      groups,
      is_channels_last,
      ideep::algorithm::deconvolution_direct,
      w.get_data_type(),
      src_dims);

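  // ConvTranspose weights swap the roles of the input- and output-channel dims
  // relative to a regular convolution, so the expected descriptor and (below)
  // the source weight are both transposed so their logical dims line up before
  // feed_from copies the data into the blocked layout.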
  if (groups > 1) {
    expected_desc = expected_desc.transpose(1, 2);
  } else {
    expected_desc = expected_desc.transpose(0, 1);
  }

  ideep::tensor result;
  result.init(expected_desc);
  w.transpose_(0, 1);
  result.feed_from(w, /*is_deconv_weights*/true);

  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}

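// Query oneDNN's LSTM forward-inference primitive for the expected (possibly
// blocked) descriptors of weight_ih / weight_hh and reorder both weights into
// those formats. The src/dst/bias tensors constructed below exist to describe
// shapes and data types for that format query.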
static std::tuple<ideep::tensor, ideep::tensor> get_lstm_packed_weights(
    const at::Tensor& weight_ih,
    const at::Tensor& weight_hh,
    const at::Tensor& weight2,
    const at::Tensor& weight3,
    int64_t layer_feature_size,
    int64_t hidden_size,
    bool has_biases,
    int64_t num_layers,
    bool bidirectional,
    int64_t time_step,
    int64_t batch_size,
    bool reverse) {

  ideep::tensor cached_weight_ih, cached_weight_hh;

  int64_t num_gates = 4;
  int64_t num_bias_gates = 4;
  std::vector<int64_t> output_sizes = {time_step, batch_size, hidden_size};

  auto dtype = get_mkldnn_dtype(weight_ih.scalar_type());
  ideep::tensor::desc src_layer_desc({time_step, batch_size, layer_feature_size}, dtype, ideep::format_tag::tnc);
  ideep::tensor::desc src_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
  ideep::tensor::desc src_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
  ideep::tensor::desc bias_desc({1, 1, num_bias_gates, hidden_size}, dtype, ideep::format_tag::ldgo);

  ideep::tensor::desc dst_layer_desc({time_step, batch_size, hidden_size}, dtype, ideep::format_tag::tnc);
  ideep::tensor::desc dst_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
  ideep::tensor::desc dst_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);

  ideep::tensor src_layer(src_layer_desc);
  ideep::tensor src_iter(src_iter_desc);
  ideep::tensor src_iter_c(src_iter_c_desc);
  ideep::tensor bias(bias_desc);

  auto w1 = itensor_view_from_dense(
      weight_ih,
      {{1, 1, layer_feature_size, num_gates, hidden_size},
        get_mkldnn_dtype(weight_ih.scalar_type()),
        ideep::format_tag::ldgoi});

  auto w2 = itensor_view_from_dense(
      weight_hh,
      {{1, 1, hidden_size, num_gates, hidden_size},
        get_mkldnn_dtype(weight_hh.scalar_type()),
        ideep::format_tag::ldgoi});

  auto [packed_desc_ih, packed_desc_hh] =
      ideep::lstm_forward_inference::expected_weights_desc(
          output_sizes,
          src_layer,
          src_iter,
          src_iter_c,
          w1,
          w2,
          bias,
          reverse);

  cached_weight_ih.init(packed_desc_ih);
  cached_weight_hh.init(packed_desc_hh);

  cached_weight_ih.feed_from(w1);
  cached_weight_hh.feed_from(w2);

  return std::make_tuple(cached_weight_ih, cached_weight_hh);
}

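// Helper for the RNN weight reorder below: report whether the descriptor
// oneDNN expects is a plain one (or opaque / rnn-packed, depending on the
// ideep version), in which case the caller keeps the original dense weights
// instead of repacking them into MKLDNN tensors.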
static bool should_use_plain_format(ideep::tensor w) {
#if defined(IDEEP_VERSION_MAJOR) && IDEEP_VERSION_MAJOR >= 3
  return w.get_desc().is_opaque() || w.get_desc().is_plain();
#else
  return w.get_desc().is_rnn_packed() || w.get_desc().is_plain();
#endif
}

static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
    Tensor weight0,
    Tensor weight1,
    int64_t hidden_size,
    bool reverse,
    bool has_biases,
    bool batch_first,
    c10::OptionalArrayRef<int64_t> input_size) {

  std::vector<int64_t> input_size_value;
  int64_t time_step, batch_size;
  if (input_size.has_value()) {
    input_size_value = input_size.value().vec();
    int64_t time_index = batch_first ? 1 : 0;
    int64_t batch_size_index = batch_first ? 0 : 1;

    time_step = input_size_value[time_index];
    batch_size = input_size_value[batch_size_index];
  } else {
    // No input size was provided; fall back to placeholder values.
    time_step = 5;
    batch_size = 10;
  }

  at::Tensor packed_w1, packed_w2;

  int64_t feature_size = weight0.size(-1);

  auto [w1_, w2_] = get_lstm_packed_weights(
    weight0,
    weight1,
    at::zeros(
      weight0.sizes(),
      weight0.options()),
    at::zeros(
      weight1.sizes(),
      weight1.options()),
    feature_size,
    hidden_size,
    has_biases, // has_biases
    1, // num_layers
    false, // bidirectional
    time_step,
    batch_size,
    reverse);

  if (should_use_plain_format(w1_)) {
    packed_w1 = weight0;
  } else {
    packed_w1 = new_with_itensor_mkldnn(std::move(w1_), optTypeMetaToScalarType(weight0.options().dtype_opt()), weight0.options().device_opt());
  }

  if (should_use_plain_format(w2_)) {
    packed_w2 = weight1;
  } else {
    packed_w2 = new_with_itensor_mkldnn(std::move(w2_), optTypeMetaToScalarType(weight1.options().dtype_opt()), weight1.options().device_opt());
  }
  return {packed_w1, packed_w2};
}

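// Serialize the (possibly blocked) oneDNN memory descriptor of a prepacked
// MKLDNN tensor into a plain uint8 CPU tensor, so the chosen weight format can
// be recorded alongside the weight. Only supported with a recent enough ideep;
// see the IDEEP_PREREQ guard below.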
static Tensor get_mkldnn_serialized_md(const Tensor& self) {
  const ideep::tensor packed_w = itensor_from_tensor(self);
  auto packed_w_desc = packed_w.get_desc();
  std::vector<uint8_t> serialized_wei_desc;

#if IDEEP_PREREQ(3, 4, 1, 2)
  serialized_wei_desc = packed_w_desc.get_blob();
#else
  TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
#endif
  Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
  auto res = at::empty_like(serialized_md);
  // serialized_md shares its buffer with serialized_wei_desc, a local vector
  // that is destroyed when this function returns, which would leave the buffer
  // dangling. Copy into res so the result owns its own storage.
  res.copy_(serialized_md);
  return res;
}

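// Kernel registrations. The matching op schemas are declared elsewhere via
// TORCH_LIBRARY / TORCH_LIBRARY_FRAGMENT for the `mkldnn` namespace; once
// registered, the ops are reachable from Python as torch.ops.mkldnn.<name>
// (e.g. torch.ops.mkldnn._reorder_convolution_weight).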
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
      TORCH_FN(mkldnn_reorder_conv_transpose_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_linear_weight"),
      TORCH_FN(mkldnn_reorder_linear_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_weight"),
      TORCH_FN(mkldnn_reorder_conv_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_mkldnn_rnn_layer_weight"),
      TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"),
      TORCH_FN(get_mkldnn_serialized_md));
}

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {
  TORCH_CHECK(false, "MKL-DNN build is disabled");
}

Tensor dense_to_mkldnn(const Tensor& cpu_tensor, std::optional<ScalarType> dtype) {
  TORCH_CHECK(false, "MKL-DNN build is disabled");
}

Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
}

Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}

#endif // AT_MKLDNN_ENABLED()

#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()
#include <mkl.h>

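// Prepack an fp32 linear weight with MKL's cblas_sgemm_pack so that a matching
// MKL linear kernel (registered elsewhere) can presumably run
// cblas_sgemm_compute against an already-packed B matrix and skip the per-call
// packing. The packed buffer is opaque, so it is stored inside an MKLDNN
// tensor sized via cblas_sgemm_pack_get_size.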
static Tensor mkl_reorder_linear_weight(
    const Tensor& weight,
    const int64_t batch_size) {
  TORCH_CHECK(
      weight.scalar_type() == ScalarType::Float,
      "reorder_linear_weight: weight's dtype should be float");
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  auto M = batch_size;
  auto N = weight.size(0);
  auto K = weight.size(1);
  int64_t pack_size =
      (int64_t)(cblas_sgemm_pack_get_size(CblasBMatrix, M, N, K) / sizeof(float) + 1);
  auto packed_weight = empty_mkldnn(
      {pack_size, 1},
      weight.scalar_type(),
      weight.options().layout_opt(),
      weight.options().device_opt(),
      weight.options().pinned_memory_opt());
  ideep::tensor& mkl_weight = itensor_from_mkldnn(packed_weight);
  auto weight_ = weight.contiguous();
  const ideep::tensor orig_w = itensor_view_from_dense(weight_);
  cblas_sgemm_pack(
      CblasRowMajor,
      CblasBMatrix,
      CblasTrans,
      M,
      N,
      K,
      1.0f,
      (float*)(orig_w.get_data_handle()),
      K,
      (float*)(mkl_weight.get_data_handle()));
  return packed_weight;
}

TORCH_LIBRARY_IMPL(mkl, CPU, m) {
  m.impl(
    TORCH_SELECTIVE_NAME("mkl::_mkl_reorder_linear_weight"),
    TORCH_FN(mkl_reorder_linear_weight));
}

#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED

}}