#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/native/mkldnn/Matmul.h>

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

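// Fallback stubs for builds without MKLDNN support: the matmul entry points
// raise an error and the use_mkldnn_* predicates return false so callers fall
// back to other backends.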
void mkldnn_matmul(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result,
    float beta,
    float alpha) {
  TORCH_CHECK(false, "mkldnn_matmul: ATen not compiled with MKLDNN support");
}

bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result_opt){
  return false;
}

bool use_mkldnn_fp16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result_opt){
  return false;
}

bool mkldnn_bf16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a, int64_t lda,
    const c10::BFloat16 *b, int64_t ldb,
    float beta,
    c10::BFloat16 *c, int64_t ldc) {
  return false;
}

bool mkldnn_fp16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::Half *a, int64_t lda,
    const c10::Half *b, int64_t ldb,
    float beta,
    c10::Half *c, int64_t ldc) {
  return false;
}

bool mkldnn_bf32_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc){
  return false;
}

bool use_mkldnn_bf32_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
  return false;
}

bool use_mkldnn_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
  return false;
}

void mkldnn_matmul_i8i8i32(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result) {
  TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at {
namespace native {

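// Global gating helpers: oneDNN must be enabled in the global context and the
// CPU must pass the corresponding bf16/fp16 capability check. The implicit
// bf32 path additionally requires the float32 matmul precision to be set to
// Float32MatmulPrecision::MEDIUM.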
static bool use_mkldnn_bf16_matmul() {
  return at::globalContext().userEnabledMkldnn() && mkldnn_bf16_device_check();
}

static bool use_mkldnn_fp16_matmul() {
  return at::globalContext().userEnabledMkldnn() && mkldnn_fp16_device_check();
}

static bool use_mkldnn_bf32_matmul() {
  return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
}

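// Shared gemm wrapper: computes C = alpha * op(A) * op(B) + beta * C via the
// oneDNN matmul primitive. Returns false when the dtype/device checks fail,
// when alpha is zero, or when the problem is too small to benefit, so that
// the caller can fall back to another BLAS backend.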
template<typename scalar_t>
inline typename std::enable_if_t<
    std::is_same_v<scalar_t, float> ||
    std::is_same_v<scalar_t, c10::Half> ||
    std::is_same_v<scalar_t, c10::BFloat16>,
    bool>
mkldnn_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const scalar_t *a_data, int64_t lda,
    const scalar_t *b_data, int64_t ldb,
    float beta,
    scalar_t *c_data, int64_t ldc) {
  bool bf16_usable = std::is_same_v<scalar_t, c10::BFloat16> && use_mkldnn_bf16_matmul();
  bool fp16_usable = std::is_same_v<scalar_t, c10::Half> && use_mkldnn_fp16_matmul();
  bool bf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_bf32_matmul();
  if ( !(bf16_usable || fp16_usable || bf32_usable) ||
      (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
    return false;
  }

  ideep::attr_t op_attr;
  // Use mkldnn post ops to perform the add.
  if (beta != 0.0f) {
    op_attr = ideep::attr_t::fuse_sum();
  }
  if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path

  // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
  // Use identity: C = AB <=> C^T = B^T A^T
  ideep::tensor::dims a_strides{{lda, 1}}, b_strides{{ldb, 1}}, c_strides{{ldc, 1}};
  if (transa != TransposeType::NoTranspose) {
    std::swap(a_strides[0], a_strides[1]);
  }
  if (transb != TransposeType::NoTranspose) {
    std::swap(b_strides[0], b_strides[1]);
  }

  auto idtype = ideep::tensor::data_type::bf16;
  if constexpr (std::is_same_v<scalar_t, c10::Half>) {
    idtype = ideep::tensor::data_type::f16;
  }
  if constexpr (std::is_same_v<scalar_t, float>) {
    idtype = ideep::tensor::data_type::f32;
  }

  ideep::tensor a({
      /*sizes=*/{k, m},
      idtype,
      /*strides=*/a_strides},
    const_cast<scalar_t*>(a_data));
  ideep::tensor b({
      /*sizes=*/{n, k},
      idtype,
      /*strides=*/b_strides},
    const_cast<scalar_t*>(b_data));
  ideep::tensor c({
      /*sizes=*/{n, m},
      idtype,
      /*strides=*/c_strides},
    c_data);

  ideep::matmul_forward::compute(
      b, a, c, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);

  if (c.get_data_handle() != c_data){
    // ideep queries oneDNN for the expected output format; if the provided
    // output buffer does not match that format, ideep re-initializes a new
    // output buffer. In that case, copy the result back into the caller's buffer.
    ideep::tensor real_output({
        /*sizes=*/{n, m},
        idtype,
        /*strides=*/c_strides},
      c_data);
    c.reorder_to(real_output);
  }

  return true;
}

bool mkldnn_bf16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a, int64_t lda,
    const c10::BFloat16 *b, int64_t ldb,
    float beta,
    c10::BFloat16 *c, int64_t ldc) {
  return mkldnn_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

bool mkldnn_fp16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::Half *a, int64_t lda,
    const c10::Half *b, int64_t ldb,
    float beta,
    c10::Half *c, int64_t ldc) {
  return mkldnn_gemm<c10::Half>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

bool mkldnn_bf32_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc){
  return mkldnn_gemm<float>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

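// Tensor-level entry point used by the addmm/bmm/baddbmm/mv/dot paths:
// computes result = beta * result + alpha * (mat1 @ mat2) on CPU through the
// oneDNN matmul primitive, with the fused-sum post-op handling the beta term.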
void mkldnn_matmul(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result,
    float beta,
    float alpha) {
  TORCH_CHECK((mat1.dim() == 2 && mat2.dim() == 2) || // aten::addmm
              (mat1.dim() == 3 && mat2.dim() == 3) || // aten::bmm, aten::baddbmm
              (mat1.dim() == 2 && mat2.dim() == 1) || // aten::mv
              (mat1.dim() == 1 && mat2.dim() == 1),   // aten::dot
              "mkldnn_matmul: unsupported dims for mat and mat2");

#if defined(__aarch64__)
  // oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
  // fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
  TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
              ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
              "mkldnn_matmul: only enabled for fp32 and bf16 path");
  // device needs to support bf16 if the inputs are of bf16 type
  if (mat1.scalar_type() == at::kBFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check_arm(),
                "mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
  }
#else
  TORCH_CHECK(
      (mat1.scalar_type() == at::kBFloat16 ||
       mat1.scalar_type() == at::kHalf ||
       mat1.scalar_type() == at::kFloat) &&
          mat2.scalar_type() == mat1.scalar_type() &&
          result.scalar_type() == mat1.scalar_type(),
      "mkldnn_matmul: only enabled for bf16, fp16 and bf32 path");
  if (mat1.scalar_type() == at::kBFloat16 || mat1.scalar_type() == at::kFloat) {
    TORCH_CHECK(
        mkldnn_bf16_device_check(),
        "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx_ne_convert or avx512bw, avx512vl and avx512dq, or AWS Graviton3");
  } else {
    TORCH_INTERNAL_ASSERT(mat1.scalar_type() == at::kHalf);
    TORCH_CHECK(
        mkldnn_fp16_device_check(),
        "mkldnn_matmul: mkldnn_matmul fp16 path needs the cpu support avx_ne_convert or avx512_fp16");
  }
#endif

  auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
  auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
  auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result;
  bool bf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_bf32_matmul();

  ideep::attr_t op_attr;
  // "addmm", "addbmm" and "baddbmm" in PyTorch allow the bias to be a 2-D or 3-D tensor,
  // but the mkldnn matmul primitive only supports a 1-D bias.
  // To bridge that difference, we use an mkldnn post op to perform a fused "add" after the matrix multiplication.
  if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();
  if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path
  // If alpha = 0, there is no need to actually do the gemm computation.
  if (alpha == 0)
    return;

  auto is_mkldnn_optimized_format = [&](const Tensor& t) {
    if (t.is_contiguous()) return true;
    const auto sizes = t.sizes();
    const auto strides = t.strides();
    if (t.dim() == 2){
      return strides[0] == 1 && strides[1] == sizes[0];
    } else {
      // dim = 3
      return strides[0] == sizes[1] * sizes[2] && strides[1] == 1 && strides[2] == sizes[1];
    }
  };

  // oneDNN is currently only optimized for contiguous or transposed inputs
  // (transposed in the last two dims for 3-D tensors), so fall back to a
  // contiguous copy otherwise. This can be removed once oneDNN fully supports
  // arbitrary strides.
  Tensor mat1_ = is_mkldnn_optimized_format(mat1_unsqueezed) ? mat1_unsqueezed : mat1_unsqueezed.contiguous();
  Tensor mat2_ = is_mkldnn_optimized_format(mat2_unsqueezed) ? mat2_unsqueezed : mat2_unsqueezed.contiguous();
  // Make sure mat1 and mat2 have default contiguous strides if they are contiguous tensors, for better performance.
  mat1_ = may_convert_to_default_contiguous_strides(mat1_);
  mat2_ = may_convert_to_default_contiguous_strides(mat2_);

  // mkldnn_matmul only operates on CPU tensors.
  const ideep::tensor x = itensor_view_from_dense(mat1_);
  const ideep::tensor w = itensor_view_from_dense(mat2_);
  ideep::tensor y = itensor_view_from_dense(result_unsqueezed);
  ideep::matmul_forward::compute(x, w, y, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);
  if (y.get_data_handle() != result.data_ptr()){
    // ideep queries oneDNN for the expected output format; if the provided
    // output buffer does not match that format, ideep re-initializes a new
    // output buffer. In that case, copy the result back into the caller's buffer.
    ideep::tensor public_y = itensor_view_from_dense(result);
    y.reorder_to(public_y);
  }

  if (mat1.dim() == 1 && mat2.dim() == 1){
    // aten::dot
    result.squeeze_();
  }

}

inline bool checksize(const Tensor& mat1, const Tensor& mat2){
  // if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
  // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
  // else called from aten::mv, mat1's size = (m * n), mat2's size = (n)
  // The oneDNN gemm kernel is only worthwhile when m * n * k (times the batch size, if any) is large enough.
  static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
  if (mat1.dim() == 1 && mat2.dim() == 1) {
    // aten::dot
    return mat1.size(0) > mkldnn_gemm_min_size;
  } else if (mat1.dim() == 2 && mat2.dim() == 1) {
    // aten::mv
    return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
  } else if (mat1.dim() == 2 && mat2.dim() == 2) {
    // aten::addmm
    return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
  } else {
    // aten::bmm, aten::baddbmm
    return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
  }
}

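// Dtype and size gating for the tensor overloads below: dispatch to oneDNN
// only when the input dtypes match the target path, the tensors are
// non-empty, and checksize() says the problem is large enough to pay off.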
bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
#if defined(__aarch64__)
  if (mkldnn_bf16_device_check_arm()) {
    // oneDNN fastmath mode can leverage bf16 HW even for fp32 input, e.g. on Arm Neoverse V1,
    // so don't restrict mkldnn_matmul to bf16 inputs only; allow float as well.
    return (
        use_mkldnn_bf16_matmul() &&
        (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
        mat1.numel() != 0 &&
        mat2.numel() != 0 &&
        checksize(mat1, mat2));
  } else
#endif
  {
    return (
        use_mkldnn_bf16_matmul() &&
        mat1.scalar_type() == kBFloat16 &&
        mat2.scalar_type() == kBFloat16 &&
        (!result.defined() || result.scalar_type() == kBFloat16) &&
        mat1.numel() != 0 &&
        mat2.numel() != 0 &&
        checksize(mat1, mat2));
  }
}

bool use_mkldnn_fp16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {

  return (
      use_mkldnn_fp16_matmul() &&
      mat1.scalar_type() == kHalf &&
      mat2.scalar_type() == kHalf &&
      (!result.defined() || result.scalar_type() == kHalf) &&
      mat1.numel() != 0 &&
      mat2.numel() != 0 &&
      checksize(mat1, mat2));
}

bool use_mkldnn_bf32_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {

  return (
      use_mkldnn_bf32_matmul() &&
      mat1.scalar_type() == kFloat &&
      mat2.scalar_type() == kFloat &&
      (!result.defined() || result.scalar_type() == kFloat) &&
      mat1.numel() != 0 &&
      mat2.numel() != 0 &&
      checksize(mat1, mat2));
}

bool use_mkldnn_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
  return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
}

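// s8 * s8 -> s32 matmul via the generic oneDNN matmul primitive. Handles
// arbitrary (non-contiguous) strides; mat2 is reordered to the primitive's
// expected weight layout if needed.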
static void _mkldnn_matmul_i8i8i32_with_primitive(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result) {
  // Create ideep tensors for oneDNN computation
  auto src = ideep::tensor(
      {mat1.sizes().vec(),
       ideep::tensor::data_type::s8,
       mat1.strides().vec()},
      mat1.data_ptr());
  auto wei = ideep::tensor(
      {mat2.sizes().vec(),
       ideep::tensor::data_type::s8,
       mat2.strides().vec()},
      mat2.data_ptr());
  auto dst = ideep::tensor(
      {result.sizes().vec(),
       ideep::tensor::data_type::s32,
       result.strides().vec()},
      result.data_ptr());
  // Create primitive desc
  auto engine = ideep::engine::cpu_engine();
  ideep::attr_t op_attr;
  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
  auto src_desc = src.get_desc();
  auto wei_desc = wei.get_desc();
  auto dst_desc = dst.get_desc();
  auto prim_desc = dnnl::matmul::primitive_desc(
      engine, src_desc, wei_desc, dst_desc, op_attr);
  // Reorder mat2 if needed
  auto expected_weight = wei.reorder_if_differ_in(prim_desc.weights_desc());
  // Prepare args for primitive
  ideep::tensor scratchpad(prim_desc.scratchpad_desc());
  ideep::exec_args args;
  args.insert({DNNL_ARG_SRC, src});
  args.insert({DNNL_ARG_WEIGHTS, expected_weight});
  args.insert({DNNL_ARG_DST, dst});
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
  // Create primitive and execute
  auto primitive = dnnl::matmul(prim_desc);
  primitive.execute(ideep::stream::default_stream(), args);
}

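// s8 * s8 -> s32 gemm via the oneDNN BLAS-style gemm_s8s8s32 API; requires
// each input to be contiguous along one of its two dimensions.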
static void _mkldnn_gemm_i8i8i32_with_blas(
    const Tensor& self,
    const Tensor& mat2,
    const Tensor& result) {
  const int m = result.size(0);
  const int n = result.size(1);
  const int k = self.size(1);

  const char transa = self.strides()[1] == 1 ? 'N' : 'T';
  const char transb = mat2.strides()[1] == 1 ? 'N' : 'T';
  const char offsetc = 'F';

  const int lda = transa == 'T' ? self.stride(1) : self.stride(0);
  const int ldb = transb == 'T' ? mat2.stride(1) : mat2.stride(0);
  const int ldc = n;

  const float alpha = 1;
  const float beta = 0;

  int8_t ao = 0;
  int8_t bo = 0;
  int32_t co = 0;

  dnnl::gemm_s8s8s32(
      transa,
      transb,
      offsetc,
      m,
      n,
      k,
      alpha,
      (int8_t*)self.data_ptr(),
      lda,
      ao,
      (int8_t*)mat2.data_ptr(),
      ldb,
      bo,
      beta,
      (int32_t*)result.data_ptr(),
      ldc,
      &co);
}

void mkldnn_matmul_i8i8i32(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result) {
  // x:s8 * w:s8 -> y:s32
  // Both inputs must be 2-D.
  // In most cases the DNNL BLAS API is faster, but it requires mat1/mat2 to be contiguous along one dimension.
  bool a_is_contiguous = (mat1.stride(0) == 1 || mat1.stride(1) == 1);
  bool b_is_contiguous = (mat2.stride(0) == 1 || mat2.stride(1) == 1);

  if (a_is_contiguous && b_is_contiguous) {
    _mkldnn_gemm_i8i8i32_with_blas(mat1, mat2, result);
  } else {
    _mkldnn_matmul_i8i8i32_with_primitive(mat1, mat2, result);
  }
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED