/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/onednn_env_vars.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::inner_product_forward;
using dnnl::primitive_attr;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {
static Eigen::internal::CacheSizes cache_sizes = Eigen::internal::CacheSizes();

typedef Eigen::ThreadPoolDevice CPUDevice;
inline bool ExecuteSingleThreadedGemm(int m, int n, int k, int bytes) {
  // Ideally we would determine the blocking and derive a heuristic from it,
  // but the targets here are very small models whose total size is < x * L2.
  // So we use this simple calculation to decide whether the matrix
  // multiplication should run on a single thread.
  // TODO(Intel-tf): this needs to be vastly improved, perhaps at a lower level
  // than the integration.
  ptrdiff_t l2_size = cache_sizes.m_l2;
  constexpr float kHeuristicMultiplier = 1.01;
  const float mul_size = bytes * (m * n + k * (m + n));
  const float l2_heur = l2_size * kHeuristicMultiplier;
  return mul_size < l2_heur;
}
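
// A rough worked example of the heuristic above (illustrative numbers only):
// for m = 32, n = 64, k = 128 and 4-byte floats, the estimated working set is
//   4 * (32 * 64 + 128 * (32 + 64)) = 57,344 bytes,
// which fits in a typical 1 MiB L2 cache, so the GEMM would run on a single
// thread; a 1024 x 1024 x 1024 problem would not.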

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
  memory::dims src_dims;
  memory::dims weight_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::format_tag src_format;
  memory::format_tag weight_format;
  memory::format_tag dst_format;
  string dtypes = string("");
  bool const_weight;
#ifdef DNNL_AARCH64_USE_ACL
  uint64 weight_hash;
#endif
  struct PostOpParam {
    string name;
    std::vector<float> param;
  };
  std::vector<PostOpParam> post_op_params;

  MklDnnMatMulFwdParams(
      memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
      memory::dims dst_dims,
      memory::format_tag src_format = memory::format_tag::any,
      memory::format_tag weight_format = memory::format_tag::any,
      memory::format_tag dst_format = memory::format_tag::any,
      bool const_weight = false)
      : src_dims(src_dims),
        weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        src_format(src_format),
        weight_format(weight_format),
        dst_format(dst_format),
        const_weight(const_weight) {}
};
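
// A minimal construction sketch (the shapes are illustrative, not required by
// this header): a [batch, channels] x [channels, units] fully connected layer
// with bias would typically be described as
//   MklDnnMatMulFwdParams params({batch, channels},  // src_dims
//                                {units, channels},  // weight_dims (oc, ic)
//                                {units},            // bias_dims
//                                {batch, units});    // dst_dims
// optionally followed by
//   params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
// to request a fused activation (scale, alpha, beta), as handled in Setup().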

// With quantization, input, weight, bias, and output can have different types.
// So we use different template parameters for each type.
// TODO(intel-tf): The template type "T" is currently used to match the
// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h).
// In the future, with the removal of "T" from MklPrimitiveFactory, this class
// needs to drop "T".
template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 public:
  explicit MklDnnMatMulFwdPrimitive(
      const MklDnnMatMulFwdParams& matmulFwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    if (context_.matmul_fwd == nullptr) {
      Setup(matmulFwdParams);
    }
  }

  ~MklDnnMatMulFwdPrimitive() {}

  dnnl::memory::desc GetScratchPadDesc() {
    return context_.fwd_pd->scratchpad_desc();
  }

  // Inner-product forward execute with bias:
  //  - src_data: input data buffer of src
  //  - weight_data: input data buffer of weight
  //  - bias_data: input data buffer of bias
  //  - dst_data: output data buffer of dst
  //  - sp_data: scratchpad data
  void Execute(const Tinput* src_data, const Tweight* weight_data,
               const Tbias* bias_data, Toutput* dst_data, void* sp_data,
               std::shared_ptr<stream> fwd_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)), *fwd_stream);
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);
    context_.sp_mem->set_data_handle(sp_data, *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.weight_mem->set_data_handle(
        static_cast<void*>(const_cast<Tweight*>(weight_data)));
    context_.bias_mem->set_data_handle(
        static_cast<void*>(const_cast<Tbias*>(bias_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
    context_.sp_mem->set_data_handle(sp_data);
#endif  // !ENABLE_ONEDNN_OPENMP

    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    // After execution, set data handles back to dummy data.
    context_.src_mem->set_data_handle(DummyData);
    context_.weight_mem->set_data_handle(DummyData);
    context_.bias_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
  GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for inner-product Fwd op
  struct MklDnnMatMulFwdContext {
    // oneDNN memory.
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> weight_mem;
    std::shared_ptr<dnnl::memory> bias_mem;
    std::shared_ptr<dnnl::memory> dst_mem;
    std::shared_ptr<dnnl::memory> sp_mem;

    // Descriptor and primitive-descriptor for forward inner-product.
    std::shared_ptr<dnnl::inner_product_forward::desc> fwd_desc;
    std::shared_ptr<dnnl::inner_product_forward::primitive_desc> fwd_pd;

    // Memory descriptors.
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> weight_md;
    std::shared_ptr<dnnl::memory::desc> bias_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // Inner-product primitive.
    std::shared_ptr<dnnl::primitive> matmul_fwd;
    std::vector<dnnl::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    MklDnnMatMulFwdContext()
        : src_mem(nullptr),
          weight_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          sp_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          weight_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          matmul_fwd(nullptr) {}
  };

  void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) {
    // Create memory descriptors for inner-product data without specified
    // format.
    context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims},
                                           MklDnnType<Tinput>(),
                                           matmul_fwd_params.src_format));

    context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims},
                                               MklDnnType<Tweight>(),
                                               matmul_fwd_params.weight_format));

    context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                           MklDnnType<Toutput>(),
                                           matmul_fwd_params.dst_format));

    context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                            MklDnnType<Tbias>(),
                                            memory::format_tag::any));
    // Create an inner-product descriptor.
    context_.fwd_desc.reset(new inner_product_forward::desc(
        matmul_fwd_params.const_weight ? prop_kind::forward_inference
                                       : prop_kind::forward_training,
        *context_.src_md, *context_.weight_md, *context_.bias_md,
        *context_.dst_md));
    context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
        *context_.fwd_desc, cpu_engine_));

    // Check if there is any fusion as post-ops.
    auto const& post_op_params = matmul_fwd_params.post_op_params;
    dnnl::primitive_attr post_ops_attr;
    post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    dnnl::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "relu" ||
            post_op_param.name == "leakyrelu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "relu6") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale,
                                  dnnl::algorithm::eltwise_bounded_relu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "elu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_elu,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "gelu_approximate") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_gelu_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "gelu_exact") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_gelu_erf,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "tanh") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_tanh,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "logistic") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, dnnl::algorithm::eltwise_logistic,
                                  op_alpha, op_beta);
        } else if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);

        } else {
          DCHECK((post_op_param.name == "relu") ||
                 (post_op_param.name == "relu6") ||
                 (post_op_param.name == "elu") ||
                 (post_op_param.name == "tanh") ||
                 (post_op_param.name == "logistic") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "leakyrelu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    } else {
      context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
          *context_.fwd_desc, post_ops_attr, cpu_engine_));
    }

    // Create memory primitives based on dummy data.
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
    context_.weight_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
                                         cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
    context_.bias_mem.reset(new memory({{matmul_fwd_params.bias_dims},
                                        MklDnnType<Tbias>(),
                                        memory::format_tag::x},
                                       cpu_engine_, DummyData));
    auto scratchpad_md = context_.fwd_pd->scratchpad_desc();
    context_.sp_mem.reset(
        new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));

    // Create inner-product primitive.
    context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
    context_.net_args.push_back({{DNNL_ARG_SRC, *context_.src_mem},
                                 {DNNL_ARG_WEIGHTS, *context_.weight_mem},
                                 {DNNL_ARG_BIAS, *context_.bias_mem},
                                 {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
                                 {DNNL_ARG_DST, *context_.dst_mem}});

    context_.fwd_primitives.push_back(*context_.matmul_fwd);
    return;
  }

  struct MklDnnMatMulFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  // Guards Execute().
  mutex primitive_execution_mu_;
#endif
};
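
// A rough usage sketch of the primitive above (a hypothetical caller, not part
// of this header): build the params, fetch a possibly cached primitive from
// the factory below, allocate a scratchpad, and execute on a stream:
//   MklDnnMatMulFwdParams params(src_dims, weight_dims, bias_dims, dst_dims);
//   auto* matmul_fwd =
//       MklDnnMatMulFwdPrimitiveFactory<float, float, float, float,
//                                       float>::Get(params, false);
//   UserScratchPad<unsigned char> scratch_pad;
//   scratch_pad.AllocateSPTensor(matmul_fwd, ctx);
//   matmul_fwd->Execute(src, weights, bias, dst, scratch_pad.Get(),
//                       cpu_stream);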

template <typename T, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput>
class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* Get(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) {
    MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>* matmul_fwd =
        nullptr;

    if (do_not_cache) {
      // Always create new primitive
      matmul_fwd =
          new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
              mkldnn_matmul_fwd_dims);
    } else {
      // Try to find a suitable one in pool
      matmul_fwd = dynamic_cast<
          MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>*>(
          MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                          Toutput>::GetInstance()
              .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims));
      if (matmul_fwd == nullptr) {
        matmul_fwd =
            new MklDnnMatMulFwdPrimitive<T, Tinput, Tweight, Tbias, Toutput>(
                mkldnn_matmul_fwd_dims);
        MklDnnMatMulFwdPrimitiveFactory<T, Tinput, Tweight, Tbias,
                                        Toutput>::GetInstance()
            .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd);
      }
    }
    return matmul_fwd;
  }

 private:
  MklDnnMatMulFwdPrimitiveFactory() {}
  ~MklDnnMatMulFwdPrimitiveFactory() {}

  static MklDnnMatMulFwdPrimitiveFactory& GetInstance() {
    static MklDnnMatMulFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string prefix = "matmul_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);
#ifdef DNNL_AARCH64_USE_ACL
    key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_hash);
#endif

    // Generate keys for post-ops
    for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
      if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
          post_op_param.name == "logistic" ||
          post_op_param.name == "leakyrelu" ||
          post_op_param.name == "gelu_approximate" ||
          post_op_param.name == "gelu_exact") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
        key_creator.AddAsKey(post_op_param.param[1]);
        key_creator.AddAsKey(post_op_param.param[2]);
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklDnnMatMulFwd(
      const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    return this->GetOp(key);
  }

  void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims,
                          MklPrimitive* op) {
    string key = CreateKey(mkldnn_matmul_fwd_dims);
    this->SetOp(key, op);
  }
};

template <class Tweight, class Toutput>
class MklDnnMatMulOpBase : public OpKernel {
 public:
  explicit MklDnnMatMulOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(OpKernelContext* context) override = 0;

  // Allocate output tensor.
  virtual void AllocateOutputTensor(
      OpKernelContext* context,
      const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
      const memory::dims& output_dims_mkl_order,
      MklTensorFormat output_tf_format, Tensor** output_tensor,
      bool native_format = false) {
    DCHECK(output_tensor);
    auto dst_pd = mkldnn_matmul_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));

    if (native_format) {
      output_tf_shape = output_mkl_shape.GetTfShape();
    }
    // Allocate output tensor.
    AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor,
                              output_tf_shape, output_mkl_shape,
                              native_format);
  }

  // The TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
  // be acquired before entering the function, since it is acquired
  // inside the function.
  inline bool IsWeightCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    return (weight_oi_.NumElements() == 0);
  }

  // Cache the converted weight in a tensor.
  // Only one thread can execute this method at any given time.
  void CacheWeight(
      OpKernelContext* context,
      const std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
          matmul_fwd_pd,
      Tweight* weight_data, const Tensor& weight_tensor,
      MklDnnData<Tweight>& weight, const memory::desc& weight_md)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;

    // If the weights are already cached, there's nothing to do.
    if (weight_t.NumElements() > 0) {
      return;
    }

    // Reorder and cache the weight.
    weight.SetUsrMem(weight_md, &weight_tensor);
    weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(), cpu_engine_,
                               context);
    weight_data = static_cast<Tweight*>(weight.GetOpMem().get_data_handle());

    size_t weight_size = matmul_fwd_pd.get()->weights_desc().get_size();
    TensorShape weight_tf_shape;
    weight_tf_shape.AddDim(weight_size / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_tf_shape, &weight_oi_));

    void* weight_oi_t_data = weight.GetTensorBuffer(&weight_oi_);
    memcpy(weight_oi_t_data, weight_data, weight_size);

    // Cache the memory descriptor.
    auto expected_md = matmul_fwd_pd->weights_desc();
    TensorShape weight_mkl_format;
    weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight));

    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<Tweight>::value,
                                          weight_mkl_format, &weight_oi_md_));
    *reinterpret_cast<memory::desc*>(weight_oi_md_.flat<Tweight>().data()) =
        expected_md;
  }

  Tweight* GetCachedWeight(OpKernelContext* context,
                           const memory::desc& expected_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& weight_t = weight_oi_;
    const Tensor& weight_md_t = weight_oi_md_;

    // Check if the memory descriptor of the cached weight is the same as
    // expected_md. If so, return the cached buffer; otherwise return nullptr.
    if (weight_md_t.flat<Tweight>().size()) {
      const memory::desc& stored_md =
          *(static_cast<memory::desc*>(weight_md_t.data()));
      if (stored_md == expected_md) {
        return static_cast<Tweight*>(
            const_cast<Tweight*>(weight_t.flat<Tweight>().data()));
      }
    }
    return nullptr;
  }

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 protected:
  // Mutex and tensors used to cache the reordered weight.
  mutex mu_;
  Tensor weight_oi_ TF_GUARDED_BY(mu_);
  Tensor weight_oi_md_ TF_GUARDED_BY(mu_);

  bool is_weight_const_;

  const int kInputIndexSrc = 0;
  const int kInputIndexWeight = 1;
  const int kInputIndexBias = 2;
  const int kOutputIndexDst = 0;
};
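
// A rough sketch of how a derived kernel might use the weight cache above
// (hypothetical caller, not part of this header): on the first invocation with
// constant weights, reorder and store them; afterwards, reuse the cached
// blocked layout whenever the primitive still expects the same descriptor:
//   if (this->is_weight_const_ && this->IsWeightCacheEmpty(ctx)) {
//     this->CacheWeight(ctx, matmul_fwd->GetPrimitiveDesc(), weight_data,
//                       weight_tensor, weight, weight_md);
//   }
//   Tweight* cached = this->GetCachedWeight(
//       ctx, matmul_fwd->GetPrimitiveDesc()->weights_desc());
//   if (cached != nullptr) weight_data = cached;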

using dnnl::matmul;

namespace {

struct MklMatMulParams {
  string prefix;
  memory::dims a_dims;
  memory::dims b_dims;
  memory::dims c_dims;
  memory::dims a_strides;
  memory::dims b_strides;
  memory::dims c_strides;
#ifdef DNNL_AARCH64_USE_ACL
  int aarch64_counter;
#endif
  struct PostOpParam {
    string name;
    std::vector<float> param;
    memory::dims dims;
    memory::data_type data_type;
    memory::format_tag format_tag;
  };
  std::vector<PostOpParam> post_op_params;

  MklMatMulParams(string prefix, memory::dims a_dims, memory::dims b_dims,
                  memory::dims c_dims, memory::dims a_strides,
                  memory::dims b_strides, memory::dims c_strides)
      : prefix(prefix),
        a_dims(a_dims),
        b_dims(b_dims),
        c_dims(c_dims),
        a_strides(a_strides),
        b_strides(b_strides),
        c_strides(c_strides) {}
};
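
// A minimal sketch of describing a plain row-major [m, k] x [k, n] product
// with these params (values are illustrative only). Strides are expressed in
// elements, innermost dimension last:
//   MklMatMulParams params("matmul", {m, k}, {k, n}, {m, n},
//                          {k, 1}, {n, 1}, {n, 1});
// A transposed A operand stored with leading dimension m would instead use
// a_strides = {1, m}, mirroring the transa/transb handling in dnnl_gemm below.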

template <typename Tlhs, typename Trhs, typename Toutput>
class MklMatMulPrimitive : public MklPrimitive {
 public:
  explicit MklMatMulPrimitive(const MklMatMulParams& params)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create matmul primitive
    Setup(params);
  }

  ~MklMatMulPrimitive() {}

  dnnl::memory::desc GetScratchPadDesc() {
    return context_.prim_desc->scratchpad_desc();
  }

  void Execute(const std::shared_ptr<stream>& stream, const Tlhs* a_data,
               const Trhs* b_data, const Toutput* c_data, void* sp_data,
               void* mul_data = nullptr, void* add_data = nullptr) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.a_mem->set_data_handle(
        static_cast<void*>(const_cast<Tlhs*>(a_data)), *stream);
    context_.b_mem->set_data_handle(
        static_cast<void*>(const_cast<Trhs*>(b_data)), *stream);
    context_.c_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(c_data)), *stream);
    context_.sp_mem->set_data_handle(sp_data, *stream);

    if (mul_data != nullptr)
      context_.mul_mem->set_data_handle(mul_data, *stream);
    if (add_data != nullptr)
      context_.add_mem->set_data_handle(add_data, *stream);
#else
    context_.a_mem->set_data_handle(
        static_cast<void*>(const_cast<Tlhs*>(a_data)));
    context_.b_mem->set_data_handle(
        static_cast<void*>(const_cast<Trhs*>(b_data)));
    context_.c_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(c_data)));
    context_.sp_mem->set_data_handle(sp_data);
    if (mul_data != nullptr) context_.mul_mem->set_data_handle(mul_data);
    if (add_data != nullptr) context_.add_mem->set_data_handle(add_data);
#endif  // !ENABLE_ONEDNN_OPENMP
    execute_primitives(context_.matmul_primitives, stream, context_.net_args);

    // After execution, set data handles back to dummy data.
    context_.a_mem->set_data_handle(DummyData);
    context_.b_mem->set_data_handle(DummyData);
    context_.c_mem->set_data_handle(DummyData);
    context_.sp_mem->set_data_handle(DummyData);
    if (mul_data != nullptr) context_.mul_mem->set_data_handle(DummyData);
    if (add_data != nullptr) context_.add_mem->set_data_handle(DummyData);
  }

 private:
  // Primitive reuse context for MatMul op
  struct MklMatMulContext {
    // oneDNN memory.
    std::shared_ptr<dnnl::memory> a_mem;
    std::shared_ptr<dnnl::memory> b_mem;
    std::shared_ptr<dnnl::memory> c_mem;
    std::shared_ptr<dnnl::memory> mul_mem;
    std::shared_ptr<dnnl::memory> add_mem;
    std::shared_ptr<dnnl::memory> sp_mem;

    // Descriptor and primitive-descriptor for MatMul.
    std::shared_ptr<matmul::desc> desc;
    std::shared_ptr<matmul::primitive_desc> prim_desc;

    // Memory descriptors.
    std::shared_ptr<dnnl::memory::desc> a_md;
    std::shared_ptr<dnnl::memory::desc> b_md;
    std::shared_ptr<dnnl::memory::desc> c_md;
    std::shared_ptr<dnnl::memory::desc> mul_md;
    std::shared_ptr<dnnl::memory::desc> add_md;

    // MatMul primitive.
    std::vector<dnnl::primitive> matmul_primitives;
    std::vector<std::unordered_map<int, memory>> net_args;

    MklMatMulContext()
        : a_mem(nullptr),
          b_mem(nullptr),
          c_mem(nullptr),
          mul_mem(nullptr),
          add_mem(nullptr),
          sp_mem(nullptr),
          desc(nullptr),
          prim_desc(nullptr),
          a_md(nullptr),
          b_md(nullptr),
          c_md(nullptr),
          mul_md(nullptr),
          add_md(nullptr) {}
  };

  void Setup(const MklMatMulParams& params) {
    std::shared_ptr<dnnl::primitive> matmul_primitive = nullptr;

    // Create MatMul descriptor and primitive descriptor.
    context_.a_md.reset(new memory::desc({params.a_dims}, MklDnnType<Tlhs>(),
                                         params.a_strides));

    context_.b_md.reset(new memory::desc({params.b_dims}, MklDnnType<Trhs>(),
                                         params.b_strides));

    context_.c_md.reset(new memory::desc({params.c_dims},
                                         MklDnnType<Toutput>(),
                                         params.c_strides));

    // Create matmul descriptor.
    context_.desc.reset(
        new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md));

    // Check if there is any fusion as post-ops.
    auto const& post_op_params = params.post_op_params;
    dnnl::primitive_attr post_ops_attr;
    dnnl::post_ops post_ops;
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "output_scale") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          std::vector<float> scales;
          scales.push_back(post_op_param.param[0]);
          post_ops_attr.set_output_scales(0, scales);
        } else if (post_op_param.name == "mul") {
          context_.mul_md.reset(new memory::desc({post_op_param.dims},
                                                 post_op_param.data_type,
                                                 post_op_param.format_tag));
          post_ops.append_binary(dnnl::algorithm::binary_mul,
                                 *context_.mul_md);
        } else if (post_op_param.name == "add") {
          context_.add_md.reset(new memory::desc({post_op_param.dims},
                                                 post_op_param.data_type,
                                                 post_op_param.format_tag));
          post_ops.append_binary(dnnl::algorithm::binary_add,
                                 *context_.add_md);
        } else {
          DCHECK((post_op_param.name == "output_scale"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
    }
    post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    context_.prim_desc.reset(
        new matmul::primitive_desc(*context_.desc, post_ops_attr, cpu_engine_));

    // Create memory primitives based on dummy data.
    context_.a_mem.reset(
        new dnnl::memory(*context_.a_md, cpu_engine_, DummyData));
    context_.b_mem.reset(
        new dnnl::memory(*context_.b_md, cpu_engine_, DummyData));
    context_.c_mem.reset(
        new dnnl::memory(*context_.c_md, cpu_engine_, DummyData));
    auto scratchpad_md = context_.prim_desc->scratchpad_desc();
    context_.sp_mem.reset(
        new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));

    // Create matmul primitive.
    matmul_primitive.reset(new dnnl::matmul(*context_.prim_desc));
    context_.net_args.push_back({{DNNL_ARG_SRC, *context_.a_mem},
                                 {DNNL_ARG_WEIGHTS, *context_.b_mem},
                                 {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
                                 {DNNL_ARG_DST, *context_.c_mem}});
    if (!post_op_params.empty()) {
      int count = 0;
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "mul") {
          context_.mul_mem.reset(
              new dnnl::memory(*context_.mul_md, cpu_engine_, DummyData));
          context_.net_args[0].insert(
              {DNNL_ARG_ATTR_MULTIPLE_POST_OP(count) | DNNL_ARG_SRC_1,
               *context_.mul_mem});
          count++;
        } else if (post_op_param.name == "add") {
          context_.add_mem.reset(
              new dnnl::memory(*context_.add_md, cpu_engine_, DummyData));
          context_.net_args[0].insert(
              {DNNL_ARG_ATTR_MULTIPLE_POST_OP(count) | DNNL_ARG_SRC_1,
               *context_.add_mem});
          count++;
        }
      }
    }

    context_.matmul_primitives.push_back(*matmul_primitive);
    return;
  }

  struct MklMatMulContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  // Guards Execute().
  mutex primitive_execution_mu_;
#endif
};

template <typename T, typename Tlhs, typename Trhs, typename Toutput>
class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklMatMulPrimitive<Tlhs, Trhs, Toutput>* Get(
      const MklMatMulParams& params, bool do_not_cache) {
    MklMatMulPrimitive<Tlhs, Trhs, Toutput>* matmul_prim = nullptr;

    if (do_not_cache) {
      // Always create new primitive
      matmul_prim = new MklMatMulPrimitive<Tlhs, Trhs, Toutput>(params);
    } else {
      // Try to find a suitable one in pool
      matmul_prim = dynamic_cast<MklMatMulPrimitive<Tlhs, Trhs, Toutput>*>(
          MklMatMulPrimitiveFactory<T, Tlhs, Trhs, Toutput>::GetInstance()
              .GetMklMatMul(params));
      if (matmul_prim == nullptr) {
        matmul_prim = new MklMatMulPrimitive<Tlhs, Trhs, Toutput>(params);
        MklMatMulPrimitiveFactory<T, Tlhs, Trhs, Toutput>::GetInstance()
            .SetMklMatMul(params, matmul_prim);
      }
    }

    return matmul_prim;
  }

 private:
  MklMatMulPrimitiveFactory() {}
  ~MklMatMulPrimitiveFactory() {}

  static MklMatMulPrimitiveFactory& GetInstance() {
    static MklMatMulPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklMatMulParams& params) {
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(params.prefix);
    key_creator.AddAsKey(params.a_dims);
    key_creator.AddAsKey(params.b_dims);
    key_creator.AddAsKey(params.c_dims);
    key_creator.AddAsKey(params.a_strides);
    key_creator.AddAsKey(params.b_strides);
    key_creator.AddAsKey(params.c_strides);
    key_creator.AddAsKey(typeid(T).name());
#ifdef DNNL_AARCH64_USE_ACL
    key_creator.AddAsKey(params.aarch64_counter);
#endif
    key_creator.AddAsKey(typeid(Tlhs).name());
    key_creator.AddAsKey(typeid(Trhs).name());
    key_creator.AddAsKey(typeid(Toutput).name());

    // Generate keys for post-ops
    for (auto const& post_op_param : params.post_op_params) {
      if (post_op_param.name == "output_scale") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
      } else if (post_op_param.name == "mul" || post_op_param.name == "add") {
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.dims);
      } else {
        return string("not_a_key");
      }
    }
    return key_creator.GetKey();
  }

  MklPrimitive* GetMklMatMul(const MklMatMulParams& params) {
    string key = CreateKey(params);
    return this->GetOp(key);
  }

  void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) {
    string key = CreateKey(params);
    this->SetOp(key, op);
  }
};

template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
               float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
               float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) {
  using dims = dnnl::memory::dims;

  // Prepare strides based on the transa and transb flags: transposed
  // matrices have their strides swapped.
  dims a_dims = dims{m, k};
  dims b_dims = dims{k, n};
  dims c_dims = dims{m, n};
  dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda};
  dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb};
  dims c_strides = dims{ldc, 1};

  // MklMatMul uses constant alpha and beta; check here that the caller only
  // requests the supported values (no scaling, no accumulation).
  DCHECK_EQ(alpha, 1.0f);
  DCHECK_EQ(beta, 0.f);

  MklMatMulParams params("dnnl_gemm", a_dims, b_dims, c_dims, a_strides,
                         b_strides, c_strides);
  MklMatMulPrimitive<T, T, T>* matmul_prim =
      MklMatMulPrimitiveFactory<T, T, T, T>::Get(params, false);

  UserScratchPad<unsigned char> scratch_pad;
  scratch_pad.AllocateSPTensor(matmul_prim, ctx);
  // Execute matmul primitive.
  auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T));
  std::shared_ptr<stream> cpu_stream;
  MklDnnThreadPool eigen_tp(ctx, st ? 1 : -1);
  cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine()));
  matmul_prim->Execute(cpu_stream, a, b, c, scratch_pad.Get());
}
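
// A rough call sketch for the helper above (a hypothetical caller, not part of
// this header): computing C = A * B for row-major float matrices whose leading
// dimensions equal their column counts. alpha must be 1 and beta must be 0, as
// enforced by the DCHECKs above:
//   dnnl_gemm<float>('n', 'n', m, n, k, 1.0f, a_ptr, k, b_ptr, n,
//                    0.0f, c_ptr, n, op_kernel_context);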

}  // anonymous namespace

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_