#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at { namespace native {

#if AT_MKLDNN_ENABLED()

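// Converts an opaque MKLDNN (ideep) tensor back into a strided CPU tensor,
// optionally casting to one of the supported dtypes. The masked_grad
// argument is accepted for interface compatibility with _to_dense and is
// not used here.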
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {
  TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float ||
              mkldnn_tensor.scalar_type() == ScalarType::BFloat16 ||
              mkldnn_tensor.scalar_type() == ScalarType::Half ||
              mkldnn_tensor.scalar_type() == ScalarType::Byte ||
              mkldnn_tensor.scalar_type() == ScalarType::Char,
              "mkldnn_to_dense expects float, bfloat16, half, uint8, int8 tensor input");
  ideep::tensor& stensor = itensor_from_mkldnn(mkldnn_tensor);
  auto dims = stensor.get_dims();
  auto data_type = dtype.has_value() ? dtype.value() : mkldnn_tensor.scalar_type();
  TORCH_CHECK(data_type == ScalarType::Float ||
              data_type == ScalarType::BFloat16 ||
              data_type == ScalarType::Half ||
              data_type == ScalarType::Byte ||
              data_type == ScalarType::Char,
              "mkldnn tensor can only be converted to a float, bfloat16, half, uint8, or int8 cpu tensor");
  if (mkldnn_tensor.scalar_type() == ScalarType::Byte || mkldnn_tensor.scalar_type() == ScalarType::Char) {
    // For int8 and uint8 input, the data type must not change.
    TORCH_CHECK(mkldnn_tensor.scalar_type() == data_type,
                "For int8 and uint8 mkldnn_tensor input, the data type must not change.");
  }
  // NOTE: dims from ideep::tensor are int32_t, but sizes need int64_t.
  Tensor cpu_tensor = at::empty(
      std::vector<int64_t>(dims.begin(), dims.end()),
      mkldnn_tensor.options().layout(c10::kStrided).dtype(data_type));
  if (stensor.is_empty()) return cpu_tensor;
  auto pub_tensor =
      data_type == ScalarType::Float
      ? stensor.to_public(cpu_tensor.template data_ptr<float>(),
                          ideep::tensor::data_type::f32)
      : (data_type == ScalarType::BFloat16
         ? stensor.to_public(cpu_tensor.template data_ptr<BFloat16>(),
                             ideep::tensor::data_type::bf16)
         : (data_type == ScalarType::Half
            ? stensor.to_public(cpu_tensor.template data_ptr<Half>(),
                                ideep::tensor::data_type::f16)
            : (data_type == ScalarType::Byte
               ? stensor.to_public(cpu_tensor.template data_ptr<uint8_t>(),
                                   ideep::tensor::data_type::u8)
               : stensor.to_public(cpu_tensor.template data_ptr<int8_t>(),
                                   ideep::tensor::data_type::s8))));
  cpu_tensor.as_strided_(dims, pub_tensor.get_strides());
  // Make sure that NC11 strides follow the formula of a contiguous tensor.
  return cpu_tensor.contiguous().resize_(dims, c10::MemoryFormat::Contiguous);
}

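// The inverse of mkldnn_to_dense: wraps a contiguous CPU tensor (up to 5-d)
// as an MKLDNN tensor in oneDNN's public (plain) format. Pre-packing into
// blocked weight formats is done separately by the reorder functions below.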
Tensor dense_to_mkldnn(const Tensor& cpu_tensor, std::optional<ScalarType> dtype) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(),
              "dense_to_mkldnn expects CPU tensor input");
  TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
              "dense_to_mkldnn expects strided tensor input");
  TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float ||
              cpu_tensor.scalar_type() == ScalarType::BFloat16 ||
              cpu_tensor.scalar_type() == ScalarType::Half ||
              cpu_tensor.scalar_type() == ScalarType::Byte ||
              cpu_tensor.scalar_type() == ScalarType::Char,
              "dense_to_mkldnn expects float, bfloat16, half, uint8, int8 tensor input");
  TORCH_CHECK(cpu_tensor.dim() <= 5,
              "Can't convert a cpu tensor with more than 5 dimensions");
  // NOTE: forbid direct conversion from a non-contiguous (or channels-last) tensor to `ideep::tensor`.
  auto cpu_tensor_cont = cpu_tensor.contiguous();
  auto data_type = dtype.has_value() ? dtype.value() : cpu_tensor.scalar_type();
  if (cpu_tensor.scalar_type() == ScalarType::Byte || cpu_tensor.scalar_type() == ScalarType::Char) {
    // For int8 and uint8 input, the data type must not change.
    TORCH_CHECK(cpu_tensor.scalar_type() == data_type,
                "For int8 and uint8 cpu_tensor input, the data type must not change.");
  }
  TORCH_CHECK(data_type == ScalarType::Float ||
              data_type == ScalarType::BFloat16 ||
              data_type == ScalarType::Half ||
              data_type == ScalarType::Byte ||
              data_type == ScalarType::Char,
              "cpu tensor can only be converted to a float, bfloat16, half, uint8, or int8 mkldnn tensor");
  Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), data_type,
                                      cpu_tensor_cont.options().layout_opt(), cpu_tensor_cont.options().device_opt(),
                                      cpu_tensor_cont.options().pinned_memory_opt());
  ideep::tensor& dtensor = itensor_from_mkldnn(mkldnn_tensor);
  if (cpu_tensor.scalar_type() == ScalarType::Float) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::f32,
                      (cpu_tensor_cont.template data_ptr<float>()));
  } else if (cpu_tensor.scalar_type() == ScalarType::BFloat16) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::bf16,
                      cpu_tensor_cont.template data_ptr<BFloat16>());
  } else if (cpu_tensor.scalar_type() == ScalarType::Half) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::f16,
                      cpu_tensor_cont.template data_ptr<Half>());
  } else if (cpu_tensor.scalar_type() == ScalarType::Byte) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::u8,
                      cpu_tensor_cont.template data_ptr<uint8_t>());
  } else {
    TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Char,
                "Expected int8 input for cpu_tensor");
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::s8,
                      cpu_tensor_cont.template data_ptr<int8_t>());
  }
  return mkldnn_tensor;
}

// An MKLDNN tensor uses a special, non-public format for conv2d weights
// (dense_to_mkldnn only converts a dense tensor to an MKLDNN tensor in the
// public format). The ideep conv kernel will do an implicit reorder if the
// weight is not already in this optimized format. At the time of writing,
// this on-the-fly reorder costs roughly 20% in performance.
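// mkldnn_reorder_conv2d_weight therefore pre-packs a conv2d weight into the
// blocked layout that ideep::convolution_forward expects for the given
// stride/padding/dilation/groups (and, optionally, input size), so the
// reorder cost is paid once up front rather than on every forward call.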
Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_conv2d_weight");
  const auto padding_expanded = expand_param_if_needed(padding, "padding", 2);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", 2);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2);

  ideep::dims src_dims = ideep::dims();
  bool is_channels_last = false;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (input_size.has_value()) {
    src_dims = input_size.value().vec();
    // If the input size is given, always use channels last.
    is_channels_last = true;
    memory_format = at::MemoryFormat::ChannelsLast;
  }

  auto self_ = self.is_mkldnn() ? self : self.contiguous(memory_format);
  auto w = itensor_from_tensor(self_);

  // A legacy mkldnn conv2d jitted module may contain a 5-d weight with an
  // extra dimension when groups > 1, i.e. [g, o/g, i, h, w] instead of
  // [o, i, h, w]. Ideally we should reorder the weight back during
  // serialization. For backward compatibility, squash the first two dims
  // (g * o/g) back to the original form.
  if (w.ndims() == 5) {
    auto wdims = w.get_dims();
    w.reshape({wdims[0] * wdims[1], wdims[2], wdims[3], wdims[4]});
  }

  auto desc = ideep::convolution_forward::expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_expanded,
      dilation_expanded,
      groups,
      ideep::algorithm::convolution_direct,
      ideep::prop_kind::forward,
      w.get_data_type(),
      src_dims,
      ideep::attr_t(),
      is_channels_last);
  ideep::tensor result;
  result.init(desc);
  result.feed_from(w);

  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}

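// Same as mkldnn_reorder_conv2d_weight, but for 5-d (conv3d) weights:
// parameters are expanded to three spatial dims and ChannelsLast3d is used
// when an input size is provided.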
Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_conv3d_weight");
  const auto padding_expanded = expand_param_if_needed(padding, "padding", 3);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", 3);
  const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 3);

  ideep::dims src_dims = ideep::dims();
  bool is_channels_last = false;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (input_size.has_value()) {
    src_dims = input_size.value().vec();
    // If the input size is given, always use channels last.
    is_channels_last = true;
    memory_format = at::MemoryFormat::ChannelsLast3d;
  }

  auto self_ = self.is_mkldnn() ? self : self.contiguous(memory_format);
  auto w = itensor_from_tensor(self_);

  auto desc = ideep::convolution_forward::expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_expanded,
      dilation_expanded,
      groups,
      ideep::algorithm::convolution_direct,
      ideep::prop_kind::forward,
      w.get_data_type(),
      src_dims,
      ideep::attr_t(),
      is_channels_last);
  ideep::tensor result;
  result.init(desc);
  result.feed_from(w);

  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}

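// Dispatches to the 2-d or 3-d weight reorder above based on the weight's
// rank (4-d -> conv2d, 5-d -> conv3d).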
static Tensor mkldnn_reorder_conv_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK((self.dim() == 4 || self.dim() == 5), "mkldnn_reorder_conv_weight only supports conv2d and conv3d");
  if (self.dim() == 4) {
    return at::native::mkldnn_reorder_conv2d_weight(self, padding, stride, dilation, groups, input_size);
  } else {
    return at::native::mkldnn_reorder_conv3d_weight(self, padding, stride, dilation, groups, input_size);
  }
}

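// Pre-packs a 2-d linear weight into the layout ideep::inner_product_forward
// expects. If a batch size hint is given, the expected layout is queried for
// that specific input shape; otherwise an empty input_size is passed.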
static Tensor mkldnn_reorder_linear_weight(
    const Tensor& self,
    std::optional<int64_t> batch_size_opt) {
  mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_linear_weight");
  auto out_features = self.size(0);
  auto in_features = self.size(1);
  auto self_ = self.contiguous();
  auto w = itensor_from_tensor(self_);
  ideep::dims input_size;
  auto dtype = w.get_data_type();
  if (batch_size_opt.has_value()) {
    input_size = {batch_size_opt.value(), in_features};
  }
  auto packed_desc = ideep::inner_product_forward::expected_weights_desc(
      {out_features, in_features},
      input_size,
      /* weight dtype */ dtype,
      /* src dtype */ dtype);
  ideep::tensor result;
  result.init(packed_desc);
  result.feed_from(w);
  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
}

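// Helper that queries the expected deconvolution weight descriptor. The
// channels-last flag is a compile-time template parameter of the ideep API,
// so both instantiations are spelled out and selected at runtime here.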
static ideep::tensor::desc get_conv_transpose_expected_weights_desc(
    const ideep::tensor::dims& weights_dims,
    ideep::tensor::data_type w_dtype,
    const ideep::tensor::dims& strides,
    const ideep::tensor::dims& padding_l,
    const ideep::tensor::dims& padding_r,
    const ideep::tensor::dims& dilates,
    int groups,
    bool channels_last,
    ideep::algorithm aalgorithm,
    ideep::data_type x_dtype,
    const ideep::dims& src_dims) {
  if (channels_last) {
    return ideep::convolution_transpose_forward::expected_weights_desc<true>(
        weights_dims,
        w_dtype,
        strides,
        padding_l,
        padding_r,
        dilates,
        groups,
        aalgorithm,
        ideep::prop_kind::forward,
        src_dims);
  } else {
    return ideep::convolution_transpose_forward::expected_weights_desc<false>(
        weights_dims,
        w_dtype,
        strides,
        padding_l,
        padding_r,
        dilates,
        groups,
        aalgorithm,
        ideep::prop_kind::forward,
        src_dims);
  }
}

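// Pre-packs a conv_transpose2d/3d weight. Deconvolution weights store the
// input/output channel dims swapped relative to convolution, which is why
// both the expected descriptor and the source weight are transposed before
// feed_from(..., /*is_deconv_weights=*/true).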
static Tensor mkldnn_reorder_conv_transpose_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(
      (self.dim() == 4 || self.dim() == 5),
      "mkldnn_reorder_conv_transpose_weight only supports conv_transpose2d and conv_transpose3d");
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  mkldnn_check_low_precision(
      self.scalar_type(), "mkldnn_reorder_conv_transpose_weight");
  int64_t pdim = self.dim() - 2;
  const auto padding_expanded =
      expand_param_if_needed(padding, "padding", pdim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", pdim);
  const auto dilation_expanded =
      expand_param_if_needed(dilation, "dilation", pdim);
  const auto output_padding_expanded =
      expand_param_if_needed(output_padding, "output_padding", pdim);

  ideep::dims src_dims = ideep::dims();
  bool is_channels_last = false;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (input_size.has_value()) {
    src_dims = input_size.value().vec();
    // If the input size is given, always use channels last.
    is_channels_last = true;
    memory_format = self.dim() == 4 ? at::MemoryFormat::ChannelsLast
                                    : at::MemoryFormat::ChannelsLast3d;
  }

  auto self_ = self.contiguous(memory_format);
  ideep::tensor w = itensor_from_tensor(self_);

  auto expected_desc = get_conv_transpose_expected_weights_desc(
      w.get_dims(),
      w.get_data_type(),
      stride_expanded,
      padding_expanded,
      padding_r(padding_expanded, output_padding_expanded),
      dilation_expanded,
      groups,
      is_channels_last,
      ideep::algorithm::deconvolution_direct,
      w.get_data_type(),
      src_dims);

  if (groups > 1) {
    expected_desc = expected_desc.transpose(1, 2);
  } else {
    expected_desc = expected_desc.transpose(0, 1);
  }

  ideep::tensor result;
  result.init(expected_desc);
  w.transpose_(0, 1);
  result.feed_from(w, /*is_deconv_weights*/true);

  return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}

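// Builds descriptor-only source/iteration/bias tensors describing a single
// LSTM layer and asks ideep::lstm_forward_inference for the weight layouts
// it expects, then reorders weight_ih / weight_hh into those layouts.
// weight2 and weight3 are accepted by the interface but not used here.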
static std::tuple<ideep::tensor, ideep::tensor> get_lstm_packed_weights(
    const at::Tensor& weight_ih,
    const at::Tensor& weight_hh,
    const at::Tensor& weight2,
    const at::Tensor& weight3,
    int64_t layer_feature_size,
    int64_t hidden_size,
    bool has_biases,
    int64_t num_layers,
    bool bidirectional,
    int64_t time_step,
    int64_t batch_size,
    bool reverse) {

  ideep::tensor cached_weight_ih, cached_weight_hh;

  int64_t num_gates = 4;
  int64_t num_bias_gates = 4;
  std::vector<int64_t> output_sizes = {time_step, batch_size, hidden_size};

  auto dtype = get_mkldnn_dtype(weight_ih.scalar_type());
  ideep::tensor::desc src_layer_desc({time_step, batch_size, layer_feature_size}, dtype, ideep::format_tag::tnc);
  ideep::tensor::desc src_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
  ideep::tensor::desc src_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
  ideep::tensor::desc bias_desc({1, 1, num_bias_gates, hidden_size}, dtype, ideep::format_tag::ldgo);

  ideep::tensor::desc dst_layer_desc({time_step, batch_size, hidden_size}, dtype, ideep::format_tag::tnc);
  ideep::tensor::desc dst_iter_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);
  ideep::tensor::desc dst_iter_c_desc({1, 1, batch_size, hidden_size}, dtype, ideep::format_tag::ldnc);

  ideep::tensor src_layer(src_layer_desc);
  ideep::tensor src_iter(src_iter_desc);
  ideep::tensor src_iter_c(src_iter_c_desc);
  ideep::tensor bias(bias_desc);

  auto w1 = itensor_view_from_dense(
      weight_ih,
      {{1, 1, layer_feature_size, num_gates, hidden_size},
       get_mkldnn_dtype(weight_ih.scalar_type()),
       ideep::format_tag::ldgoi});

  auto w2 = itensor_view_from_dense(
      weight_hh,
      {{1, 1, hidden_size, num_gates, hidden_size},
       get_mkldnn_dtype(weight_hh.scalar_type()),
       ideep::format_tag::ldgoi});

  auto [packed_desc_ih, packed_desc_hh] =
      ideep::lstm_forward_inference::expected_weights_desc(
          output_sizes,
          src_layer,
          src_iter,
          src_iter_c,
          w1,
          w2,
          bias,
          reverse);

  cached_weight_ih.init(packed_desc_ih);
  cached_weight_hh.init(packed_desc_hh);

  cached_weight_ih.feed_from(w1);
  cached_weight_hh.feed_from(w2);

  return std::make_tuple(cached_weight_ih, cached_weight_hh);
}

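// Whether the queried weight descriptor is effectively a plain layout, in
// which case the caller keeps the original dense weight instead of packing.
// The descriptor predicate differs between ideep v3+ and older versions.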
static bool should_use_plain_format(ideep::tensor w) {
#if defined(IDEEP_VERSION_MAJOR) && IDEEP_VERSION_MAJOR >= 3
  return w.get_desc().is_opaque() || w.get_desc().is_plain();
#else
  return w.get_desc().is_rnn_packed() || w.get_desc().is_plain();
#endif
}

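// Reorders one RNN layer's (weight_ih, weight_hh) pair into the packed
// layouts queried above. When no input size hint is provided, placeholder
// time-step and batch-size values are used only to build the query
// descriptors. If the queried format turns out to be plain, the original
// dense weights are returned unchanged.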
static std::vector<Tensor> mkldnn_reorder_mkldnn_rnn_layer_weight(
    Tensor weight0,
    Tensor weight1,
    int64_t hidden_size,
    bool reverse,
    bool has_biases,
    bool batch_first,
    c10::OptionalArrayRef<int64_t> input_size) {

  std::vector<int64_t> input_size_value;
  int64_t time_step, batch_size;
  if (input_size.has_value()) {
    input_size_value = input_size.value().vec();
    int64_t time_index = batch_first ? 1 : 0;
    int64_t batch_size_index = batch_first ? 0 : 1;

    time_step = input_size_value[time_index];
    batch_size = input_size_value[batch_size_index];
  } else {
    // No input size provided; use placeholder values for the query.
    time_step = 5;
    batch_size = 10;
  }

  at::Tensor packed_w1, packed_w2;

  int64_t feature_size = weight0.size(-1);

  auto [w1_, w2_] = get_lstm_packed_weights(
      weight0,
      weight1,
      at::zeros(
          weight0.sizes(),
          weight0.options()),
      at::zeros(
          weight1.sizes(),
          weight1.options()),
      feature_size,
      hidden_size,
      has_biases, // has_biases
      1, // num_layers
      false, // bidirectional
      time_step,
      batch_size,
      reverse);

  if (should_use_plain_format(w1_)) {
    packed_w1 = weight0;
  } else {
    packed_w1 = new_with_itensor_mkldnn(std::move(w1_), optTypeMetaToScalarType(weight0.options().dtype_opt()), weight0.options().device_opt());
  }

  if (should_use_plain_format(w2_)) {
    packed_w2 = weight1;
  } else {
    packed_w2 = new_with_itensor_mkldnn(std::move(w2_), optTypeMetaToScalarType(weight1.options().dtype_opt()), weight1.options().device_opt());
  }
  return {packed_w1, packed_w2};
}

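// Serializes the opaque oneDNN descriptor of a packed MKLDNN weight into a
// plain byte tensor, presumably so the blocked layout can be recorded
// alongside the weight data when exporting pre-packed weights.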
static Tensor get_mkldnn_serialized_md(const Tensor& self) {
  const ideep::tensor packed_w = itensor_from_tensor(self);
  auto packed_w_desc = packed_w.get_desc();
  std::vector<uint8_t> serialized_wei_desc;

#if IDEEP_PREREQ(3, 4, 1, 2)
  serialized_wei_desc = packed_w_desc.get_blob();
#else
  TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization.");
#endif
  Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte));
  auto res = at::empty_like(serialized_md);
  // serialized_md shares its buffer with serialized_wei_desc, which is
  // released when this function returns, invalidating that buffer. Copy into
  // res so it owns its own storage and stays valid afterwards.
  res.copy_(serialized_md);
  return res;
}

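// Operator registrations: the reorder ops below take dense CPU weights,
// while _get_mkldnn_serialized_md takes an already-packed MkldnnCPU tensor.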
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_transpose_weight"),
      TORCH_FN(mkldnn_reorder_conv_transpose_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_linear_weight"),
      TORCH_FN(mkldnn_reorder_linear_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_weight"),
      TORCH_FN(mkldnn_reorder_conv_weight));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_reorder_mkldnn_rnn_layer_weight"),
      TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_get_mkldnn_serialized_md"),
      TORCH_FN(get_mkldnn_serialized_md));
}

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, std::optional<ScalarType> dtype, std::optional<bool> masked_grad) {
  TORCH_CHECK(false, "MKL-DNN build is disabled");
}

Tensor dense_to_mkldnn(const Tensor& cpu_tensor, std::optional<ScalarType> dtype) {
  TORCH_CHECK(false, "MKL-DNN build is disabled");
}

Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
}

Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}

#endif // AT_MKLDNN_ENABLED()

#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()
#include <mkl.h>

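// Packs a float linear weight as the B matrix of an MKL SGEMM via
// cblas_sgemm_pack, storing the opaque packed buffer in a 1-column MKLDNN
// tensor of pack_size floats. The packed weight is presumably consumed later
// by a matching cblas_sgemm_compute call in the MKL linear kernel, which
// lives outside this file.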
static Tensor mkl_reorder_linear_weight(
    const Tensor& weight,
    const int64_t batch_size) {
  TORCH_CHECK(
      weight.scalar_type() == ScalarType::Float,
      "reorder_linear_weight: weight's dtype should be float");
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  auto M = batch_size;
  auto N = weight.size(0);
  auto K = weight.size(1);
  int64_t pack_size =
      (int64_t)(cblas_sgemm_pack_get_size(CblasBMatrix, M, N, K) / sizeof(float) + 1);
  auto packed_weight = empty_mkldnn(
      {pack_size, 1},
      weight.scalar_type(),
      weight.options().layout_opt(),
      weight.options().device_opt(),
      weight.options().pinned_memory_opt());
  ideep::tensor& mkl_weight = itensor_from_mkldnn(packed_weight);
  auto weight_ = weight.contiguous();
  const ideep::tensor orig_w = itensor_view_from_dense(weight_);
  cblas_sgemm_pack(
      CblasRowMajor,
      CblasBMatrix,
      CblasTrans,
      M,
      N,
      K,
      1.0f,
      (float*)(orig_w.get_data_handle()),
      K,
      (float*)(mkl_weight.get_data_handle()));
  return packed_weight;
}

TORCH_LIBRARY_IMPL(mkl, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkl::_mkl_reorder_linear_weight"),
      TORCH_FN(mkl_reorder_linear_weight));
}

#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED

}} // namespace at::native