// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/Matmul.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/native/mkldnn/Matmul.h>

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

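// ATen was built without MKLDNN (oneDNN) support: everything below is a stub.
// The use_mkldnn_* predicates report that the MKLDNN path is unavailable, the
// *_gemm helpers return false so callers fall back to their reference gemm,
// and mkldnn_matmul / mkldnn_matmul_i8i8i32 fail loudly if reached.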
void mkldnn_matmul(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result,
    float beta,
    float alpha) {
  TORCH_CHECK(false, "mkldnn_matmul: ATen not compiled with MKLDNN support");
}

bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result_opt){
  return false;
}

bool use_mkldnn_fp16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result_opt){
  return false;
}

bool mkldnn_bf16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a, int64_t lda,
    const c10::BFloat16 *b, int64_t ldb,
    float beta,
    c10::BFloat16 *c, int64_t ldc) {
  return false;
}

bool mkldnn_fp16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::Half *a, int64_t lda,
    const c10::Half *b, int64_t ldb,
    float beta,
    c10::Half *c, int64_t ldc) {
  return false;
}

bool mkldnn_bf32_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc) {
  return false;
}

bool use_mkldnn_bf32_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
  return false;
}

bool use_mkldnn_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
  return false;
}

void mkldnn_matmul_i8i8i32(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result) {
  TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at {
namespace native {

static bool use_mkldnn_bf16_matmul() {
  return at::globalContext().userEnabledMkldnn() && mkldnn_bf16_device_check();
}

static bool use_mkldnn_fp16_matmul() {
  return at::globalContext().userEnabledMkldnn() && mkldnn_fp16_device_check();
}

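// "bf32" is not a separate dtype: it is the fp32 path executed with bf16
// fpmath (oneDNN's fpmath mode), i.e. fp32 inputs and outputs with bf16
// compute. It is only taken when the user has opted into reduced-precision
// fp32 matmul, e.g. via torch.set_float32_matmul_precision("medium").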
static bool use_mkldnn_bf32_matmul() {
  return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
}

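// Single MKLDNN/oneDNN gemm for bf16, fp16, or fp32-with-bf16-math (bf32)
// operands, using a BLAS-like (column-major) calling convention. Returns false
// without touching `c` when the MKLDNN path does not apply (dtype not usable on
// this machine, tiny problem size, or alpha == 0); the caller is then expected
// to fall back to its reference gemm. A nonzero beta is handled by fusing a sum
// post-op, i.e. C = alpha * op(A) * op(B) + beta * C.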
template<typename scalar_t>
inline typename std::enable_if_t<
    std::is_same_v<scalar_t, float> ||
    std::is_same_v<scalar_t, c10::Half> ||
    std::is_same_v<scalar_t, c10::BFloat16>,
    bool>
mkldnn_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const scalar_t *a_data, int64_t lda,
    const scalar_t *b_data, int64_t ldb,
    float beta,
    scalar_t *c_data, int64_t ldc) {
  bool bf16_usable = std::is_same_v<scalar_t, c10::BFloat16> && use_mkldnn_bf16_matmul();
  bool fp16_usable = std::is_same_v<scalar_t, c10::Half> && use_mkldnn_fp16_matmul();
  bool bf32_usable = std::is_same_v<scalar_t, float> && use_mkldnn_bf32_matmul();
  if ( !(bf16_usable || fp16_usable || bf32_usable) ||
      (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) {
    return false;
  }

  ideep::attr_t op_attr;
  // Use mkldnn post ops to perform the add.
  if (beta != 0.0f) {
    op_attr = ideep::attr_t::fuse_sum();
  }
  if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path

  // NOTE: View as c-contiguous to avoid extra reordering in mkldnn
  // Use identity: C = AB <=> C^T = B^T A^T
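  // A column-major m-by-k matrix with leading dimension lda occupies the same
  // memory as a row-major k-by-m matrix with strides {lda, 1}. So A, B and C
  // are described below as their transposes and passed to oneDNN with the
  // operand order swapped (b, a, c), which keeps everything "c-contiguous".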
  ideep::tensor::dims a_strides{{lda, 1}}, b_strides{{ldb, 1}}, c_strides{{ldc, 1}};
  if (transa != TransposeType::NoTranspose) {
    std::swap(a_strides[0], a_strides[1]);
  }
  if (transb != TransposeType::NoTranspose) {
    std::swap(b_strides[0], b_strides[1]);
  }

  auto idtype = ideep::tensor::data_type::bf16;
  if constexpr (std::is_same_v<scalar_t, c10::Half>) {
    idtype = ideep::tensor::data_type::f16;
  }
  if constexpr (std::is_same_v<scalar_t, float>) {
    idtype = ideep::tensor::data_type::f32;
  }

  ideep::tensor a({
      /*sizes=*/{k, m},
      idtype,
      /*strides=*/a_strides},
    const_cast<scalar_t*>(a_data));
  ideep::tensor b({
      /*sizes=*/{n, k},
      idtype,
      /*strides=*/b_strides},
    const_cast<scalar_t*>(b_data));
  ideep::tensor c({
      /*sizes=*/{n, m},
      idtype,
      /*strides=*/c_strides},
    c_data);

  ideep::matmul_forward::compute(
      b, a, c, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);

  if (c.get_data_handle() != c_data){
    // ideep queries oneDNN for the expected output format; if the given output
    // buffer does not match, ideep re-initializes its own output buffer.
    // In that case, copy the re-initialized buffer back into the given buffer.
    ideep::tensor real_output({
        /*sizes=*/{n, m},
        idtype,
        /*strides=*/c_strides},
      c_data);
    c.reorder_to(real_output);
  }

  return true;
}

bool mkldnn_bf16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::BFloat16 *a, int64_t lda,
    const c10::BFloat16 *b, int64_t ldb,
    float beta,
    c10::BFloat16 *c, int64_t ldc) {
  return mkldnn_gemm<c10::BFloat16>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

bool mkldnn_fp16_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const c10::Half *a, int64_t lda,
    const c10::Half *b, int64_t ldb,
    float beta,
    c10::Half *c, int64_t ldc) {
  return mkldnn_gemm<c10::Half>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

bool mkldnn_bf32_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc) {
  return mkldnn_gemm<float>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

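// In-place matmul into a preallocated `result`:
//   result = beta * result + alpha * (mat1 @ mat2)
// Covers aten::addmm / aten::mm (2-D x 2-D), aten::bmm / aten::baddbmm
// (3-D x 3-D), aten::mv (2-D x 1-D) and aten::dot (1-D x 1-D); 1-D operands
// are temporarily unsqueezed to 2-D. beta is applied via a fused sum post-op.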
void mkldnn_matmul(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result,
    float beta,
    float alpha) {
  TORCH_CHECK((mat1.dim() == 2 && mat2.dim() == 2) || // aten::addmm
              (mat1.dim() == 3 && mat2.dim() == 3) || // aten::bmm, aten::baddbmm
              (mat1.dim() == 2 && mat2.dim() == 1) || // aten::mv
              (mat1.dim() == 1 && mat2.dim() == 1),  // aten::dot
              "mkldnn_matmul:  unsupported dims for mat1 and mat2");

#if defined(__aarch64__)
  // oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
  // fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
  TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
              ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
              "mkldnn_matmul:  only enabled for fp32 and bf16 path");
  // device needs to support bf16 if the inputs are of bf16 type
  if (mat1.scalar_type() == at::kBFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check_arm(),
                "mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
  }
#else
  TORCH_CHECK(
      (mat1.scalar_type() == at::kBFloat16 ||
       mat1.scalar_type() == at::kHalf ||
       mat1.scalar_type() == at::kFloat) &&
          mat2.scalar_type() == mat1.scalar_type() &&
          result.scalar_type() == mat1.scalar_type(),
      "mkldnn_matmul:  only enabled for bf16, fp16 and bf32 path");
  if (mat1.scalar_type() == at::kBFloat16 || mat1.scalar_type() == at::kFloat) {
    TORCH_CHECK(
        mkldnn_bf16_device_check(),
        "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu to support avx_ne_convert, or avx512bw, avx512vl and avx512dq, or AWS Graviton3");
  } else {
    TORCH_INTERNAL_ASSERT(mat1.scalar_type() == at::kHalf);
    TORCH_CHECK(
        mkldnn_fp16_device_check(),
        "mkldnn_matmul: mkldnn_matmul fp16 path needs the cpu to support avx_ne_convert or avx512_fp16");
  }
#endif

  auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
  auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
  auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result;
  bool bf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_bf32_matmul();

  ideep::attr_t op_attr;
  // "addmm", "addbmm" and "baddbmm" in pytorch allow the bias to be a 2-D or 3-D tensor,
  // but the mkldnn matmul primitive only supports a 1-D bias.
  // To bridge the difference, we use an mkldnn post op to perform a fused "add" after the matrix multiplication.
  if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum();
  if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path
  // If alpha == 0, there is no need to actually run the gemm computation
  if (alpha == 0)
    return;

  auto is_mkldnn_optimized_format = [&](const Tensor& t) {
    if (t.is_contiguous()) return true;
    const auto sizes = t.sizes();
    const auto strides = t.strides();
    if (t.dim() == 2){
      return strides[0] == 1 && strides[1] == sizes[0];
    } else {
      // dim = 3
      return strides[0] == sizes[1] * sizes[2] && strides[1] == 1 && strides[2] == sizes[1];
    }
  };

  // MKLDNN is currently only optimized for contiguous or transposed
  // (last two dims transposed for a 3-D tensor) formats.
  // Remove this "contiguous" fallback once MKLDNN fully supports arbitrary strides.
  Tensor mat1_ = is_mkldnn_optimized_format(mat1_unsqueezed) ? mat1_unsqueezed : mat1_unsqueezed.contiguous();
  Tensor mat2_ = is_mkldnn_optimized_format(mat2_unsqueezed) ? mat2_unsqueezed : mat2_unsqueezed.contiguous();
  // Make sure mat1 and mat2 have default contiguous strides if they are contiguous tensors, for better performance.
  mat1_ = may_convert_to_default_contiguous_strides(mat1_);
  mat2_ = may_convert_to_default_contiguous_strides(mat2_);

  // mkldnn_matmul only handles CPU tensors
  const ideep::tensor x = itensor_view_from_dense(mat1_);
  const ideep::tensor w = itensor_view_from_dense(mat2_);
  ideep::tensor y = itensor_view_from_dense(result_unsqueezed);
  ideep::matmul_forward::compute(x, w, y, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);
  if (y.get_data_handle() != result.data_ptr()){
    // ideep queries oneDNN for the expected output format; if the given output
    // buffer does not match, ideep re-initializes its own output buffer.
    // In that case, copy the re-initialized buffer back into the given buffer.
    ideep::tensor public_y = itensor_view_from_dense(result);
    y.reorder_to(public_y);
  }

  if (mat1.dim() == 1 && mat2.dim() == 1){
    // aten::dot
    result.squeeze_();
  }

}
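// Illustrative caller sketch (hypothetical, not part of this file), addmm-style
// with a preallocated bf16 output:
//
//   at::Tensor m1  = at::randn({64, 64}, at::kBFloat16);
//   at::Tensor m2  = at::randn({64, 64}, at::kBFloat16);
//   at::Tensor out = at::zeros({64, 64}, at::kBFloat16);
//   at::native::mkldnn_matmul(m1, m2, out, /*beta=*/0.f, /*alpha=*/1.f);  // out = m1 @ m2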

inline bool checksize(const Tensor& mat1, const Tensor& mat2){
  // if dim = 2, mat1's size = (m * n), mat2's size = (n * k)
  // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
  // else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
  // Only when m * n * k (* b, if present) is large enough do we benefit from the mkldnn-optimized gemm kernel.
  static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
  if (mat1.dim() == 1 && mat2.dim() == 1) {
    // aten::dot
    return mat1.size(0) > mkldnn_gemm_min_size;
  } else if (mat1.dim() == 2 && mat2.dim() == 1) {
    // aten::mv
    return mat1.size(0) * mat1.size(1) > mkldnn_gemm_min_size;
  } else if (mat1.dim() == 2 && mat2.dim() == 2) {
    // aten::addmm
    return mat1.size(0) * mat1.size(1) * mat2.size(1) > mkldnn_gemm_min_size;
  } else {
    // aten::bmm, aten::baddbmm
    return mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) > mkldnn_gemm_min_size;
  }
}
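// Example with the threshold above: mkldnn_gemm_min_size = 16 * 16 * 16 = 4096.
// Two 16x16 matrices (aten::addmm case) give 16 * 16 * 16 = 4096, which is not
// strictly greater than 4096, so the mkldnn path is skipped; two 32x32 matrices
// give 32 * 32 * 32 = 32768 > 4096, so the size check passes.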

bool use_mkldnn_bf16_matmul(
    const Tensor& mat1,
    const Tensor& mat2,
    const Tensor& result) {
#if defined(__aarch64__)
  if (mkldnn_bf16_device_check_arm()) {
     // onednn fastmath mode can leverage bf16 HW even for fp32 input, e.g. on Arm Neoverse V1,
     // so don't restrict mkldnn_matmul to bf16 inputs only; allow float as well.
     return (
        use_mkldnn_bf16_matmul() &&
        (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
        mat1.numel() != 0 &&
        mat2.numel() != 0 &&
        checksize(mat1, mat2));
  } else
#endif
  {
     return (
        use_mkldnn_bf16_matmul() &&
        mat1.scalar_type() == kBFloat16 &&
        mat2.scalar_type() == kBFloat16 &&
        (!result.defined() || result.scalar_type() == kBFloat16) &&
        mat1.numel() != 0 &&
        mat2.numel() != 0 &&
        checksize(mat1, mat2));
  }
}
375 
use_mkldnn_fp16_matmul(const Tensor & mat1,const Tensor & mat2,const Tensor & result)376 bool use_mkldnn_fp16_matmul(
377     const Tensor& mat1,
378     const Tensor& mat2,
379     const Tensor& result) {
380 
381     return (
382       use_mkldnn_fp16_matmul() &&
383       mat1.scalar_type() == kHalf &&
384       mat2.scalar_type() == kHalf &&
385       (!result.defined() || result.scalar_type() == kHalf) &&
386       mat1.numel() != 0 &&
387       mat2.numel() != 0 &&
388       checksize(mat1, mat2));
389 }
390 
use_mkldnn_bf32_matmul(const Tensor & mat1,const Tensor & mat2,const Tensor & result)391 bool use_mkldnn_bf32_matmul(
392     const Tensor& mat1,
393     const Tensor& mat2,
394     const Tensor& result) {
395 
396     return (
397       use_mkldnn_bf32_matmul() &&
398       mat1.scalar_type() == kFloat &&
399       mat2.scalar_type() == kFloat &&
400       (!result.defined() || result.scalar_type() == kFloat) &&
401       mat1.numel() != 0 &&
402       mat2.numel() != 0 &&
403       checksize(mat1, mat2));
404 }
405 
use_mkldnn_matmul(const Tensor & mat1,const Tensor & mat2,const Tensor & result)406 bool use_mkldnn_matmul(
407     const Tensor& mat1,
408     const Tensor& mat2,
409     const Tensor& result) {
410   return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
411 }
412 
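// s8 x s8 -> s32 matmul through a oneDNN matmul primitive. This path handles
// arbitrary (non-contiguous) strides: the weight tensor is reordered into the
// primitive's preferred layout if it differs from the given one.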
static void _mkldnn_matmul_i8i8i32_with_primitive(
    const Tensor &mat1,
    const Tensor &mat2,
    const Tensor &result) {
  // Create ideep tensors for oneDNN computation
  auto src = ideep::tensor(
      {mat1.sizes().vec(),
       ideep::tensor::data_type::s8,
       mat1.strides().vec()},
      mat1.data_ptr());
  auto wei = ideep::tensor(
      {mat2.sizes().vec(),
       ideep::tensor::data_type::s8,
       mat2.strides().vec()},
      mat2.data_ptr());
  auto dst = ideep::tensor(
      {result.sizes().vec(),
       ideep::tensor::data_type::s32,
       result.strides().vec()},
      result.data_ptr());
  // Create primitive desc
  auto engine = ideep::engine::cpu_engine();
  ideep::attr_t op_attr;
  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
  auto src_desc = src.get_desc();
  auto wei_desc = wei.get_desc();
  auto dst_desc = dst.get_desc();
  auto prim_desc = dnnl::matmul::primitive_desc(
      engine, src_desc, wei_desc, dst_desc, op_attr);
  // Reorder mat2 if needed
  auto expected_weight = wei.reorder_if_differ_in(prim_desc.weights_desc());
  // Prepare args for primitive
  ideep::tensor scratchpad(prim_desc.scratchpad_desc());
  ideep::exec_args args;
  args.insert({DNNL_ARG_SRC, src});
  args.insert({DNNL_ARG_WEIGHTS, expected_weight});
  args.insert({DNNL_ARG_DST, dst});
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
  // Create primitive and execute
  auto primitive = dnnl::matmul(prim_desc);
  primitive.execute(ideep::stream::default_stream(), args);
}

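// s8 x s8 -> s32 matmul through the oneDNN BLAS-style API dnnl::gemm_s8s8s32,
// computing result = 1.0 * self * mat2 + 0.0 * result with all zero points
// (ao, bo, co) set to zero. Requires each input to be contiguous along one of
// its two dimensions so it can be described by a transa/transb flag plus a
// leading dimension.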
static void _mkldnn_gemm_i8i8i32_with_blas(
  const Tensor& self,
  const Tensor& mat2,
  const Tensor& result) {
    const int m = result.size(0);
    const int n = result.size(1);
    const int k = self.size(1);

    const char transa = self.strides()[1] == 1 ? 'N' : 'T';
    const char transb = mat2.strides()[1] == 1 ? 'N' : 'T';
    const char offsetc = 'F';

    const int lda = transa == 'T' ? self.stride(1) : self.stride(0);
    const int ldb = transb == 'T' ? mat2.stride(1) : mat2.stride(0);
    const int ldc = n;

    const float alpha = 1;
    const float beta = 0;

    int8_t ao = 0;
    int8_t bo = 0;
    int32_t co = 0;

    dnnl::gemm_s8s8s32(
        transa,
        transb,
        offsetc,
        m,
        n,
        k,
        alpha,
        (int8_t*)self.data_ptr(),
        lda,
        ao,
        (int8_t*)mat2.data_ptr(),
        ldb,
        bo,
        beta,
        (int32_t*)result.data_ptr(),
        ldc,
        &co);
  }

498 
mkldnn_matmul_i8i8i32(const Tensor & mat1,const Tensor & mat2,const Tensor & result)499 void mkldnn_matmul_i8i8i32(
500     const Tensor &mat1,
501     const Tensor &mat2,
502     const Tensor &result) {
503   // x:s8 * w:s8 -> y:s32
504   // both inputs should be 2d
505   // In most cases, using DNNL blas API is faster but it requires a/b contiguous along one dimentsion
506   bool a_is_contigous = (mat1.stride(0) == 1 || mat1.stride(1) == 1);
507   bool b_is_contigous = (mat2.stride(0) == 1 || mat2.stride(1) == 1);
508 
509   if (a_is_contigous && b_is_contigous) {
510     _mkldnn_gemm_i8i8i32_with_blas(mat1, mat2, result);
511   } else {
512     _mkldnn_matmul_i8i8i32_with_primitive(mat1, mat2, result);
513   }
514 }
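// Illustrative caller sketch (hypothetical, not part of this file): both inputs
// are 2-D int8 tensors and the s32 output is preallocated by the caller:
//
//   at::Tensor a   = at::randint(-128, 127, {64, 32}, at::kChar);
//   at::Tensor b   = at::randint(-128, 127, {32, 16}, at::kChar);
//   at::Tensor out = at::empty({64, 16}, at::kInt);
//   at::native::mkldnn_matmul_i8i8i32(a, b, out);  // out = a @ b, accumulated in int32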

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED