/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/onednn_env_vars.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::inner_product_forward;
using dnnl::primitive_attr;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {
static Eigen::internal::CacheSizes cache_sizes = Eigen::internal::CacheSizes();

typedef Eigen::ThreadPoolDevice CPUDevice;

inline bool ExecuteSingleThreadedGemm(int m, int n, int k, int bytes) {
  // Ideally we would determine blocking and then derive a heuristic from it,
  // but the targets here are very small models whose total size is below a
  // small multiple of L2. So we use this simple calculation to decide whether
  // the matrix multiplication should run on a single thread.
  // TODO(Intel-tf): this needs to be vastly improved, perhaps at a lower level
  // than the integration.
  ptrdiff_t l2_size = cache_sizes.m_l2;
  constexpr float kHeuristicMultiplier = 1.01;
  const float mul_size = bytes * (m * n + k * (m + n));
  const float l2_heur = l2_size * kHeuristicMultiplier;
  return mul_size < l2_heur;
}
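
// For example (illustrative numbers, not from the source): a float32 GEMM
// with m = n = k = 64 touches roughly
//   bytes * (m*n + k*(m+n)) = 4 * (4096 + 8192) = 48 KiB,
// which fits in a typical 256 KiB+ L2, so it would run single-threaded.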

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::format_tag src_format;
  memory::format_tag weight_format;
  memory::format_tag dst_format;
  string dtypes = string("");
  bool const_weight;
#ifdef DNNL_AARCH64_USE_ACL
  uint64 weight_hash;
#endif
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(
      memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
      memory::dims dst_dims,
      memory::format_tag src_format = memory::format_tag::any,
      memory::format_tag weight_format = memory::format_tag::any,
      memory::format_tag dst_format = memory::format_tag::any,
      bool const_weight = false)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        src_format(src_format),
        weight_format(weight_format),
        dst_format(dst_format),
        const_weight(const_weight) {}
};
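
// A minimal usage sketch (illustrative dims and names; "relu" encodes
// {scale, alpha, beta} as consumed by Setup() below):
//
//   MklDnnMatMulFwdParams params({batch, channel}, {units, channel}, {units},
//                                {batch, units});
//   params.post_op_params.push_back({"relu", {1.0f, 0.0f, 0.0f}});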

// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create the matmul primitive.
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  dnnl::memory::desc GetScratchPadDesc() {
    return context_.fwd_pd->scratchpad_desc();
  }

  // Inner-product forward execute with bias:
  //  - src_data: input data buffer of src
  //  - weight_data: input data buffer of weight
  //  - bias_data: input data buffer of bias
  //  - dst_data: output data buffer of dst
  //  - sp_data: scratchpad data
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data, void* sp_data,
               std::shared_ptr<stream> fwd_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);
    context_.sp_mem->set_data_handle(sp_data, *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
    context_.sp_mem->set_data_handle(sp_data);
#endif  // !ENABLE_ONEDNN_OPENMP

    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    // After execution, set data handles back to dummy data.
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
    context_.sp_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for inner-product Fwd op.
  struct MklDnnMatMulFwdContext {
    // oneDNN memory.
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> weight_mem;
    std::shared_ptr<dnnl::memory> bias_mem;
    std::shared_ptr<dnnl::memory> dst_mem;
    std::shared_ptr<dnnl::memory> sp_mem;

    // Descriptor and primitive-descriptor for forward inner-product.
    std::shared_ptr<dnnl::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<dnnl::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors.
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> weight_md;
    std::shared_ptr<dnnl::memory::desc> bias_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // Inner-product primitive.
    std::shared_ptr<dnnl::primitive> matmul_fwd;
    std::vector<dnnl::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    MklDnnMatMulFwdContext()
        : src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          sp_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          matmul_fwd(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data without a specified
    // format.
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           matmul_fwd_params.src_format));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                              MklDnnType<Tweight>(),
                                              matmul_fwd_params.weight_format));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           matmul_fwd_params.dst_format));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format_tag::any));
    // Create an inner-product descriptor and primitive descriptor.
    context_.fwd_desc.reset(new inner_product_forward::desc(
        matmul_fwd_params.const_weight ? prop_kind::forward_inference
                                       : prop_kind::forward_training,
        *context_.src_md, *context_.weight_md, *context_.bias_md,
        *context_.dst_md));
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // Check if there is any fusion as post-ops.
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    dnnl::primitive_attr post_ops_attr;
    post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    dnnl::post_ops post_ops;
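    // Each eltwise post-op arrives as {name, {scale, alpha, beta}};
    // "output_scale" and "sum" carry a single float. For example
    // (illustrative values only):
    //   {"leakyrelu", {1.0f, 0.2f, 0.0f}} -> eltwise_relu with alpha = 0.2.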
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "relu6") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale,
                                  dnnl::algorithm::eltwise_bounded_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "elu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_elu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "gelu_approximate") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_gelu_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "gelu_exact") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_gelu_erf,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "tanh") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "logistic") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_logistic,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);
        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "relu6") ||
                 (post_op_param.name == "elu") ||
                 (post_op_param.name == "gelu_approximate") ||
                 (post_op_param.name == "gelu_exact") ||
                 (post_op_param.name == "tanh") ||
                 (post_op_param.name == "logistic") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "leakyrelu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
    }
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, post_ops_attr, cpu_engine_));

    // Create memory primitives based on dummy data.
    context_.src_mem.reset(
        new memory(context_.fwd_pd->src_desc(), cpu_engine_, DummyData));
    context_.weight_mem.reset(
        new memory(context_.fwd_pd->weights_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd->dst_desc(), cpu_engine_, DummyData));
    context_.bias_mem.reset(new memory({{matmul_fwd_params.bias_dims},
                                        MklDnnType<Tbias>(),
                                        memory::format_tag::x},
                                       cpu_engine_, DummyData));
    auto scratchpad_md = context_.fwd_pd->scratchpad_desc();
    context_.sp_mem.reset(
        new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));

    // Create the inner-product primitive and its argument map.
    context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
    context_.net_args.push_back({{DNNL_ARG_SRC, *context_.src_mem},
                                 {DNNL_ARG_WEIGHTS, *context_.weight_mem},
                                 {DNNL_ARG_BIAS, *context_.bias_mem},
                                 {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
                                 {DNNL_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
  }

  struct MklDnnMatMulFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  // Guards Execute().
  mutex primitive_execution_mu_;
#endif
};

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // Try to find a suitable one in the pool.
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);
#ifdef DNNL_AARCH64_USE_ACL
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_hash);
#endif

    // Generate keys for post-ops.
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
          post_op_param.name == "logistic" ||
          post_op_param.name == "leakyrelu" ||
          post_op_param.name == "gelu_approximate" ||
          post_op_param.name == "gelu_exact") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "sum" ||
                 post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};
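
// A minimal end-to-end sketch of the factory + primitive (illustrative types
// and variable names; the real callers live in the Mkl* matmul kernels):
//
//   MklDnnMatMulFwdPrimitive<float, float, float, float, float>* prim =
//       MklDnnMatMulFwdPrimitiveFactory<float, float, float, float,
//                                       float>::Get(params,
//                                                   /*do_not_cache=*/false);
//   std::shared_ptr<stream> cpu_stream(
//       CreateStream(&eigen_tp, prim->GetEngine()));
//   prim->Execute(src, weights, bias, dst, scratch_pad.Get(), cpu_stream);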

template <class Tweight, class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
 public:
  explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(OpKernelContext* context) override = 0;

  // Allocate the output tensor.
  virtual void AllocateOutputTensor(
      OpKernelContext* context,
      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
      const memory::dims& output_dims_mkl_order,
      MklTensorFormat output_tf_format, Tensor** output_tensor,
      bool native_format = false) {
    DCHECK(output_tensor);
    auto dst_pd = mkldnn_matmul_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));

    if (native_format) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    // Allocate the output tensor.
    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
                              output_tf_shape, output_mkl_shape, native_format);
  }

  // The TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
  // be acquired before entering the function, since it is acquired
  // inside the function.
  inline bool IsWeightCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    return (weight_oi_.NumElements() == 0);
  }

  // Cache the converted weight in a tensor.
  // Only one thread can execute this method at any given time.
  void CacheWeight(
      OpKernelContext* context,
      const std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
          matmul_fwd_pd,
      Tweight* weight_data, const Tensor& weight_tensor,
      MklDnnData<Tweight>& weight, const memory::desc& weight_md)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;

    // If the weights are already cached, there's nothing to do.
    if (weight_t.NumElements() > 0) {
      return;
    }

    // Reorder and cache the weight.
    weight.SetUsrMem(weight_md, &weight_tensor);
    weight.CheckReorderToOpMem(matmul_fwd_pd->weights_desc(), cpu_engine_,
                               context);
    weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());

    size_t weight_size = matmul_fwd_pd->weights_desc().get_size();
    TensorShape weight_tf_shape;
    weight_tf_shape.AddDim(weight_size / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_tf_shape, &weight_oi_));

    void* weight_oi_t_data = weight.GetTensorBuffer(&weight_oi_);
    memcpy(weight_oi_t_data, weight_data, weight_size);

    // Cache the memory descriptor of the reordered weight.
    auto expected_md = matmul_fwd_pd->weights_desc();
    TensorShape weight_mkl_format;
    weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_mkl_format, &weight_oi_md_));
    *reinterpret_cast<memory::desc*>(weight_oi_md_.flat<Tweight>().data()) =
        expected_md;
  }

  Tweight* GetCachedWeight(OpKernelContext* context,
                           const memory::desc& expected_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;
    const Tensor& weight_md_t = weight_oi_md_;

    // Check if the memory descriptor of the cached weight matches
    // expected_md. If so, return the cached memory; otherwise return nullptr.
    if (weight_md_t.flat<Tweight>().size()) {
      const memory::desc& stored_md =
          *(static_cast<memory::desc*>(weight_md_t.data()));
      if (stored_md == expected_md) {
        return const_cast<Tweight*>(weight_t.flat<Tweight>().data());
      }
    }
    return nullptr;
  }

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 protected:
  // Guards the weight cache below.
  mutex mu_;
  // Tensors that save the reordered weight and its memory descriptor.
  Tensor weight_oi_ TF_GUARDED_BY(mu_);
  Tensor weight_oi_md_ TF_GUARDED_BY(mu_);

  bool is_weight_const_;

  const int kInputIndexSrc = 0;
  const int kInputIndexWeight = 1;
  const int kInputIndexBias = 2;
  const int kOutputIndexDst = 0;
};
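
// Typical weight-caching flow for a kernel derived from MklDnnMatMulOpBase
// (a sketch; the exact call sites are in the concrete matmul kernels):
//   1. If is_weight_const_ and IsWeightCacheEmpty(), call CacheWeight() once
//      to reorder the weight into the primitive's expected layout.
//   2. On later invocations, GetCachedWeight(expected_md) returns the cached
//      buffer, or nullptr if the layout changed and a fresh reorder is needed.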

using dnnl::matmul;

namespace {

struct MklMatMulParams {
  string prefix;
  memory::dims a_dims;
  memory::dims b_dims;
  memory::dims c_dims;
  memory::dims a_strides;
  memory::dims b_strides;
  memory::dims c_strides;
#ifdef DNNL_AARCH64_USE_ACL
  int aarch64_counter;
#endif
  struct PostOpParam {
    string name;
    std::vector<float> param;
    memory::dims dims;
    memory::data_type data_type;
    memory::format_tag format_tag;
  };
  std::vector<PostOpParam> post_op_params;

  MklMatMulParams(string prefix, memory::dims a_dims, memory::dims b_dims,
                  memory::dims c_dims, memory::dims a_strides,
                  memory::dims b_strides, memory::dims c_strides)
      : prefix(prefix),
        a_dims(a_dims),
        b_dims(b_dims),
        c_dims(c_dims),
        a_strides(a_strides),
        b_strides(b_strides),
        c_strides(c_strides) {}
};
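
// Binary post-ops carry a memory descriptor alongside the name; for example
// (illustrative values), an elementwise multiply fused after the matmul:
//
//   MklMatMulParams::PostOpParam mul_op;
//   mul_op.name = "mul";
//   mul_op.dims = {1, 1, n};
//   mul_op.data_type = memory::data_type::f32;
//   mul_op.format_tag = memory::format_tag::abc;
//   params.post_op_params.push_back(mul_op);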

template <typename Tlhs, typename Trhs, typename Toutput>
class MklMatMulPrimitive : public MklPrimitive {
 public:
  explicit MklMatMulPrimitive(const MklMatMulParams& params)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create the matmul primitive.
    Setup(params);
  }

  ~MklMatMulPrimitive() {}

  dnnl::memory::desc GetScratchPadDesc() {
    return context_.prim_desc->scratchpad_desc();
  }

  void Execute(const std::shared_ptr<stream>& stream, const Tlhs* a_data,
               const Trhs* b_data, const Toutput* c_data, void* sp_data,
               void* mul_data = nullptr, void* add_data = nullptr) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.a_mem->set_data_handle(
        static_cast<void*>(const_cast<Tlhs*>(a_data)), *stream);
    context_.b_mem->set_data_handle(
        static_cast<void*>(const_cast<Trhs*>(b_data)), *stream);
    context_.c_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(c_data)), *stream);
    context_.sp_mem->set_data_handle(sp_data, *stream);

    if (mul_data != nullptr)
      context_.mul_mem->set_data_handle(mul_data, *stream);
    if (add_data != nullptr)
      context_.add_mem->set_data_handle(add_data, *stream);
#else
    context_.a_mem->set_data_handle(
        static_cast<void*>(const_cast<Tlhs*>(a_data)));
    context_.b_mem->set_data_handle(
        static_cast<void*>(const_cast<Trhs*>(b_data)));
    context_.c_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(c_data)));
    context_.sp_mem->set_data_handle(sp_data);
    if (mul_data != nullptr) context_.mul_mem->set_data_handle(mul_data);
    if (add_data != nullptr) context_.add_mem->set_data_handle(add_data);
#endif  // !ENABLE_ONEDNN_OPENMP
    execute_primitives(context_.matmul_primitives, stream, context_.net_args);

    // After execution, set data handles back to dummy data.
    context_.a_mem->set_data_handle(DummyData);
    context_.b_mem->set_data_handle(DummyData);
    context_.c_mem->set_data_handle(DummyData);
    context_.sp_mem->set_data_handle(DummyData);
    if (mul_data != nullptr) context_.mul_mem->set_data_handle(DummyData);
    if (add_data != nullptr) context_.add_mem->set_data_handle(DummyData);
  }

 private:
  // Primitive reuse context for the MatMul op.
  struct MklMatMulContext {
    // oneDNN memory.
    std::shared_ptr<dnnl::memory> a_mem;
    std::shared_ptr<dnnl::memory> b_mem;
    std::shared_ptr<dnnl::memory> c_mem;
    std::shared_ptr<dnnl::memory> mul_mem;
    std::shared_ptr<dnnl::memory> add_mem;
    std::shared_ptr<dnnl::memory> sp_mem;

    // Descriptor and primitive-descriptor for MatMul.
    std::shared_ptr<matmul::desc> desc;
    std::shared_ptr<matmul::primitive_desc> prim_desc;

    // Memory descriptors.
    std::shared_ptr<dnnl::memory::desc> a_md;
    std::shared_ptr<dnnl::memory::desc> b_md;
    std::shared_ptr<dnnl::memory::desc> c_md;
    std::shared_ptr<dnnl::memory::desc> mul_md;
    std::shared_ptr<dnnl::memory::desc> add_md;

    // MatMul primitive.
    std::vector<dnnl::primitive> matmul_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    MklMatMulContext()
        : a_mem(nullptr),
          b_mem(nullptr),
          c_mem(nullptr),
          mul_mem(nullptr),
          add_mem(nullptr),
          sp_mem(nullptr),
          desc(nullptr),
          prim_desc(nullptr),
          a_md(nullptr),
          b_md(nullptr),
          c_md(nullptr),
          mul_md(nullptr),
          add_md(nullptr) {}
  };

  void Setup(const MklMatMulParams& params) {
    std::shared_ptr<dnnl::primitive> matmul_primitive = nullptr;

    // Create MatMul descriptor and primitive descriptor.
    context_.a_md.reset(new memory::desc({params.a_dims}, MklDnnType<Tlhs>(),
                                         params.a_strides));

    context_.b_md.reset(new memory::desc({params.b_dims}, MklDnnType<Trhs>(),
                                         params.b_strides));

    context_.c_md.reset(new memory::desc({params.c_dims}, MklDnnType<Toutput>(),
                                         params.c_strides));

    // Create the matmul descriptor.
    context_.desc.reset(
        new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));

    // Check if there is any fusion as post-ops.
    auto const& post_op_params = params.post_op_params;
    dnnl::primitive_attr post_ops_attr;
    dnnl::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "mul") {
          context_.mul_md.reset(new memory::desc({post_op_param.dims},
                                                 post_op_param.data_type,
                                                 post_op_param.format_tag));
          post_ops.append_binary(dnnl::algorithm::binary_mul, *context_.mul_md);
        } else if (post_op_param.name == "add") {
          context_.add_md.reset(new memory::desc({post_op_param.dims},
                                                 post_op_param.data_type,
                                                 post_op_param.format_tag));
          post_ops.append_binary(dnnl::algorithm::binary_add, *context_.add_md);
        } else {
          DCHECK((post_op_param.name == "output_scale") ||
                 (post_op_param.name == "mul") ||
                 (post_op_param.name == "add"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
    }
    post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    context_.prim_desc.reset(
        new matmul::primitive_desc(*context_.desc, post_ops_attr, cpu_engine_));

    // Create memory primitives based on dummy data.
    context_.a_mem.reset(
        new dnnl::memory(*context_.a_md, cpu_engine_, DummyData));
    context_.b_mem.reset(
        new dnnl::memory(*context_.b_md, cpu_engine_, DummyData));
    context_.c_mem.reset(
        new dnnl::memory(*context_.c_md, cpu_engine_, DummyData));
    auto scratchpad_md = context_.prim_desc->scratchpad_desc();
    context_.sp_mem.reset(
        new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));

    // Create the matmul primitive and its argument map.
    matmul_primitive.reset(new dnnl::matmul(*context_.prim_desc));
    context_.net_args.push_back({{DNNL_ARG_SRC, *context_.a_mem},
                                 {DNNL_ARG_WEIGHTS, *context_.b_mem},
                                 {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
                                 {DNNL_ARG_DST, *context_.c_mem}});
    if (!post_op_params.empty()) {
      int count = 0;
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "mul") {
          context_.mul_mem.reset(
              new dnnl::memory(*context_.mul_md, cpu_engine_, DummyData));
          context_.net_args[0].insert(
              {DNNL_ARG_ATTR_MULTIPLE_POST_OP(count) | DNNL_ARG_SRC_1,
               *context_.mul_mem});
          count++;
        } else if (post_op_param.name == "add") {
          context_.add_mem.reset(
              new dnnl::memory(*context_.add_md, cpu_engine_, DummyData));
          context_.net_args[0].insert(
              {DNNL_ARG_ATTR_MULTIPLE_POST_OP(count) | DNNL_ARG_SRC_1,
               *context_.add_mem});
          count++;
        }
      }
    }

    context_.matmul_primitives.push_back(*matmul_primitive);
  }

  struct MklMatMulContext context_;
#ifdef DNNL_AARCH64_USE_ACL
  // Guards Execute().
  mutex primitive_execution_mu_;
#endif
};

template <typename T, typename Tlhs, typename Trhs, typename Toutput>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklMatMulPrimitive<Tlhs, Trhs, Toutput>* Get(
      const MklMatMulParams& params, bool do_not_cache) {
    MklMatMulPrimitive<Tlhs, Trhs, Toutput>* matmul_prim = nullptr;

    if (do_not_cache) {
      // Always create a new primitive.
      matmul_prim = new MklMatMulPrimitive<Tlhs, Trhs, Toutput>(params);
    } else {
      // Try to find a suitable one in the pool.
      matmul_prim = dynamic_cast<MklMatMulPrimitive<Tlhs, Trhs, Toutput>*>(
          MklMatMulPrimitiveFactory<T, Tlhs, Trhs, Toutput>::GetInstance()
              .GetMklMatMul(params));
      if (matmul_prim == nullptr) {
        matmul_prim = new MklMatMulPrimitive<Tlhs, Trhs, Toutput>(params);
        MklMatMulPrimitiveFactory<T, Tlhs, Trhs, Toutput>::GetInstance()
            .SetMklMatMul(params, matmul_prim);
      }
    }

    return matmul_prim;
  }

 private:
  MklMatMulPrimitiveFactory() {}
  ~MklMatMulPrimitiveFactory() {}

  static MklMatMulPrimitiveFactory& GetInstance() {
    static MklMatMulPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklMatMulParams& params) {
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(params.prefix);
    key_creator.AddAsKey(params.a_dims);
    key_creator.AddAsKey(params.b_dims);
    key_creator.AddAsKey(params.c_dims);
    key_creator.AddAsKey(params.a_strides);
    key_creator.AddAsKey(params.b_strides);
    key_creator.AddAsKey(params.c_strides);
    key_creator.AddAsKey(typeid(T).name());
#ifdef DNNL_AARCH64_USE_ACL
    key_creator.AddAsKey(params.aarch64_counter);
#endif
    key_creator.AddAsKey(typeid(Tlhs).name());
    key_creator.AddAsKey(typeid(Trhs).name());
    key_creator.AddAsKey(typeid(Toutput).name());

    // Generate keys for post-ops.
    for (auto const& post_op_param : params.post_op_params) {
      if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "mul" || post_op_param.name == "add") {
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.dims);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
    string key = CreateKey(params);
    return this->GetOp(key);
  }

  void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
    string key = CreateKey(params);
    this->SetOp(key, op);
  }
};

template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
               float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
  using dims = dnnl::memory::dims;

  // Prepare strides based on the transa and transb flags: transposed
  // matrices have their strides swapped.
  dims a_dims = dims{m, k};
  dims b_dims = dims{k, n};
  dims c_dims = dims{m, n};
  dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
  dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
  dims c_strides = dims{ldc, 1};

  // MklMatMul uses constant alpha and beta; guarantee here that they are
  // never anything else.
  DCHECK_EQ(alpha, 1.0f);
  DCHECK_EQ(beta, 0.f);

  MklMatMulParams params("dnnl_gemm", a_dims, b_dims, c_dims, a_strides,
                         b_strides, c_strides);
  MklMatMulPrimitive<T, T, T>* matmul_prim =
      MklMatMulPrimitiveFactory<T, T, T, T>::Get(params,
                                                 /*do_not_cache=*/false);

  UserScratchPad<unsigned char> scratch_pad;
  scratch_pad.AllocateSPTensor(matmul_prim, ctx);
  // Execute the matmul primitive, on a single thread if the problem is small
  // enough to fit in L2 (see ExecuteSingleThreadedGemm above).
  auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T));
  std::shared_ptr<stream> cpu_stream;
  MklDnnThreadPool eigen_tp(ctx, st ? 1 : -1);
  cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
  matmul_prim->Execute(cpu_stream, a, b, c, scratch_pad.Get());
}
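
// Example call (a sketch with illustrative buffers): compute C = A * B for
// row-major float matrices where A is 8x16 and B is 16x4:
//
//   std::vector<float> a(8 * 16), b(16 * 4), c(8 * 4);
//   dnnl_gemm<float>('n', 'n', 8, 4, 16, 1.0f, a.data(), 16, b.data(), 4,
//                    0.0f, c.data(), 4, ctx);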

}  // anonymous namespace

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_