xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_conv_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 #ifdef INTEL_MKL
18 
19 #include "tensorflow/core/kernels/mkl/mkl_conv_ops.h"
20 
21 #include <algorithm>
22 #include <map>
23 #include <string>
24 #include <unordered_map>
25 
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h"
28 #include "tensorflow/core/kernels/no_op.h"
29 #ifdef DNNL_AARCH64_USE_ACL
30 #include "tensorflow/core/platform/hash.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #endif
33 
34 using dnnl::convolution_forward;
35 using dnnl::prop_kind;
36 using dnnl::stream;
37 using ConvFwdPd = dnnl::convolution_forward::primitive_desc;
38 using ReorderPd = dnnl::reorder::primitive_desc;
39 
40 namespace tensorflow {
41 // This structure aggregates multiple inputs to Conv2DFwd* methods.
42 struct MklConvFwdParams {
43   memory::dims src_dims;
44   memory::dims filter_dims;
45   memory::dims bias_dims;
46   memory::dims dst_dims;
47   memory::dims strides;
48   memory::dims dilations;
49   memory::dims padding_left;
50   memory::dims padding_right;
51   memory::dims fuse_bn_dims;
52   MklTensorFormat tf_fmt;
53   bool native_format;
54   string dtypes = string("");
55 #ifdef DNNL_AARCH64_USE_ACL
56   uint64 filter_hash;
57 #endif
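  // Describes one fused post-op: `name` selects the kind ("activation", "sum",
  // "output_scale", or "fuse_bn"), `alg` carries the eltwise algorithm for
  // activations, `param` holds scales/alpha/beta, and `partial_key` feeds the
  // primitive-cache key for output scales.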
58   struct PostOpParam {
59     string name;
60     dnnl::algorithm alg;
61     std::vector<float> param;
62     std::string partial_key;
63   };
64   std::vector<PostOpParam> post_op_params;
65 
66   MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
67                    memory::dims bias_dims, memory::dims dst_dims,
68                    memory::dims strides, memory::dims dilations,
69                    memory::dims padding_left, memory::dims padding_right,
70                    memory::dims fuse_bn_dims, MklTensorFormat tf_fmt,
71                    bool native_format)
72       : src_dims(src_dims),
73         filter_dims(filter_dims),
74         bias_dims(bias_dims),
75         dst_dims(dst_dims),
76         strides(strides),
77         dilations(dilations),
78         padding_left(padding_left),
79         padding_right(padding_right),
80         fuse_bn_dims(fuse_bn_dims),
81         tf_fmt(tf_fmt),
82         native_format(native_format) {}
83 };
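// Illustrative sketch (not part of the op itself): MklConvOp::Compute() below
// builds this struct from the kernel inputs and asks the primitive factory for
// a (possibly cached) primitive, roughly:
//   MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims, dst_dims,
//                                strides, dilations, padding_left,
//                                padding_right, fuse_bn_dims, tf_fmt,
//                                native_format);
//   auto* conv_fwd =
//       MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
//           convFwdDims, do_not_cache);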
84 
85 // With quantization, input, filter, and output can have different types,
86 // so we use a different template parameter for each type.
87 template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
88 class MklConvFwdPrimitive : public MklPrimitive {
89  public:
90   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
91       : MklPrimitive(engine(engine::kind::cpu, 0)) {
92     // Create convolution primitive
93     if (context_.conv_fwd == nullptr) {
94       Setup(convFwdDims);
95     }
96   }
97   ~MklConvFwdPrimitive() {}
98 
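  // Returns the scratchpad descriptor of the cached primitive so that callers
  // can allocate a user-managed scratchpad (Setup() selects
  // dnnl::scratchpad_mode::user).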
99   dnnl::memory::desc GetScratchPadDesc() {
100     return context_.fwd_pd->scratchpad_desc();
101   }
102 
103   // Convolution forward execute with bias
104   //   src_data:    input data buffer of src
105   //   filter_data: input data buffer of filter (weights)
106   //   bias_data:   input data buffer of bias
107   //   dst_data:    output data buffer of dst
108   void Execute(const Tinput* src_data, const Tfilter* filter_data,
109                const Tbias* bias_data, const Toutput* dst_data,
110                std::shared_ptr<stream> fwd_stream, void* sp_data = nullptr) {
111     Execute(src_data, filter_data, bias_data, dst_data, nullptr, nullptr,
112             nullptr, nullptr, fwd_stream, sp_data);
113   }
114 
115   void Execute(const Tinput* src_data, const Tfilter* filter_data,
116                const Tbias* bias_data, const Toutput* dst_data,
117                const Tinput* bn_scale_data, const Tinput* bn_mean_data,
118                const Tinput* bn_offset_data, const Tinput* bn_rsqrt_data,
119                std::shared_ptr<stream> fwd_stream, void* sp_data) {
120 #ifdef DNNL_AARCH64_USE_ACL
121     // With a single global primitive cache, multiple threads may end up
122     // executing the same cached primitive concurrently, so execution must
123     // happen under the lock.
124     mutex_lock lock(primitive_execution_mu_);
125 #endif
126 #ifndef ENABLE_ONEDNN_OPENMP
127     // TODO(intel-tf): Create a common function and avoid the duplicate code
128     context_.src_mem->set_data_handle(
129         static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
130     context_.filter_mem->set_data_handle(
131         static_cast<void*>(const_cast<Tfilter*>(filter_data)), *fwd_stream);
132     if (bias_data != nullptr) {
133       context_.bias_mem->set_data_handle(
134           static_cast<void*>(const_cast<Tbias*>(bias_data)), *fwd_stream);
135     }
136     if (bn_scale_data != nullptr) {
137       context_.bn_scale_mem->set_data_handle(
138           static_cast<void*>(const_cast<Tinput*>(bn_scale_data)), *fwd_stream);
139       context_.bn_mean_mem->set_data_handle(
140           static_cast<void*>(const_cast<Tinput*>(bn_mean_data)), *fwd_stream);
141       context_.bn_rsqrt_mem->set_data_handle(
142           static_cast<void*>(const_cast<Tinput*>(bn_rsqrt_data)), *fwd_stream);
143       context_.bn_offset_mem->set_data_handle(
144           static_cast<void*>(const_cast<Tinput*>(bn_offset_data)), *fwd_stream);
145     }
146     context_.dst_mem->set_data_handle(
147         static_cast<void*>(const_cast<Toutput*>(dst_data)), *fwd_stream);
148 #else
149     context_.src_mem->set_data_handle(
150         static_cast<void*>(const_cast<Tinput*>(src_data)));
151     context_.filter_mem->set_data_handle(
152         static_cast<void*>(const_cast<Tfilter*>(filter_data)));
153     if (bias_data != nullptr) {
154       context_.bias_mem->set_data_handle(
155           static_cast<void*>(const_cast<Tbias*>(bias_data)));
156     }
157     if (bn_scale_data != nullptr) {
158       context_.bn_scale_mem->set_data_handle(
159           static_cast<void*>(const_cast<Tinput*>(bn_scale_data)));
160       context_.bn_mean_mem->set_data_handle(
161           static_cast<void*>(const_cast<Tinput*>(bn_mean_data)));
162       context_.bn_rsqrt_mem->set_data_handle(
163           static_cast<void*>(const_cast<Tinput*>(bn_rsqrt_data)));
164       context_.bn_offset_mem->set_data_handle(
165           static_cast<void*>(const_cast<Tinput*>(bn_offset_data)));
166     }
167     context_.dst_mem->set_data_handle(
168         static_cast<void*>(const_cast<Toutput*>(dst_data)));
169 #endif  // !ENABLE_ONEDNN_OPENMP
170     if (sp_data) {
171       context_.sp_mem->set_data_handle(static_cast<void*>(sp_data),
172                                        *fwd_stream);
173     }
174 
175     DCHECK_EQ(context_.fwd_primitives.size(),
176               context_.fwd_primitives_args.size());
177     for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
178       context_.fwd_primitives.at(i).execute(*fwd_stream,
179                                             context_.fwd_primitives_args.at(i));
180     }
181 
182     // After execution, set data handle back
183     context_.src_mem->set_data_handle(DummyData);
184     context_.filter_mem->set_data_handle(DummyData);
185     if (bias_data != nullptr) {
186       context_.bias_mem->set_data_handle(DummyData);
187     }
188     if (bn_scale_data != nullptr) {
189       context_.bn_scale_mem->set_data_handle(DummyData);
190       context_.bn_mean_mem->set_data_handle(DummyData);
191       context_.bn_rsqrt_mem->set_data_handle(DummyData);
192       context_.bn_offset_mem->set_data_handle(DummyData);
193     }
194     context_.dst_mem->set_data_handle(DummyData);
195     if (sp_data) {
196       context_.sp_mem->set_data_handle(DummyData);
197     }
198   }
199 
200   // Convolution forward execute without bias
201   //   src_data:    input data buffer of src
202   //   filter_data: input data buffer of filter (weights)
203   //   dst_data:    output data buffer of dst
204   void Execute(const Tinput* src_data, const Tfilter* filter_data,
205                const Toutput* dst_data, std::shared_ptr<stream> fwd_stream,
206                void* sp_data) {
207     Execute(src_data, filter_data, nullptr, dst_data, nullptr, nullptr, nullptr,
208             nullptr, fwd_stream, sp_data);
209   }
210 
211   std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
212     return context_.fwd_pd;
213   }
214 
215  private:
216   // Primitive reuse context for Conv2D Fwd op
217   struct ConvFwdContext {
218     // MKL-DNN memory
219     std::shared_ptr<dnnl::memory> src_mem;
220     std::shared_ptr<dnnl::memory> filter_mem;
221     std::shared_ptr<dnnl::memory> bias_mem;
222     std::shared_ptr<dnnl::memory> dst_mem;
223     std::shared_ptr<dnnl::memory> sp_mem;
224 
225     // FusedBatchNorm related memory
226     std::shared_ptr<dnnl::memory> bn_scale_mem;
227     std::shared_ptr<dnnl::memory> bn_mean_mem;
228     std::shared_ptr<dnnl::memory> bn_rsqrt_mem;
229     std::shared_ptr<dnnl::memory> bn_offset_mem;
230 
231     // Desc & primitive desc
232     std::shared_ptr<dnnl::convolution_forward::desc> fwd_desc;
233 
234     // Memory desc
235     std::shared_ptr<dnnl::memory::desc> src_md;
236     std::shared_ptr<dnnl::memory::desc> filter_md;
237     std::shared_ptr<dnnl::memory::desc> bias_md;
238     std::shared_ptr<dnnl::memory::desc> dst_md;
239 
240     // TODO(intel-tf): Only need one? FusedBatchNorm related.
241     std::shared_ptr<dnnl::memory::desc> bn_scale_md;
242     std::shared_ptr<dnnl::memory::desc> bn_mean_md;
243     std::shared_ptr<dnnl::memory::desc> bn_rsqrt_md;
244     std::shared_ptr<dnnl::memory::desc> bn_offset_md;
245 
246     // Convolution primitive
247     std::shared_ptr<ConvFwdPd> fwd_pd;
248     std::shared_ptr<dnnl::primitive> conv_fwd;
249 
250     std::vector<dnnl::primitive> fwd_primitives;
251     std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
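    // Note: `fwd_primitives` and `fwd_primitives_args` are parallel vectors;
    // Execute() runs primitive i with argument map i.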
252 
253     ConvFwdContext()
254         : src_mem(nullptr),
255           filter_mem(nullptr),
256           bias_mem(nullptr),
257           dst_mem(nullptr),
258           sp_mem(nullptr),
259           bn_scale_mem(nullptr),
260           bn_mean_mem(nullptr),
261           bn_rsqrt_mem(nullptr),
262           bn_offset_mem(nullptr),
263           fwd_desc(nullptr),
264           src_md(nullptr),
265           filter_md(nullptr),
266           bias_md(nullptr),
267           dst_md(nullptr),
268           bn_scale_md(nullptr),
269           bn_mean_md(nullptr),
270           bn_rsqrt_md(nullptr),
271           bn_offset_md(nullptr),
272           fwd_pd(nullptr),
273           conv_fwd(nullptr) {}
274   };
275 
276   void Setup(const MklConvFwdParams& convFwdDims) {
277     memory::format_tag user_data_fmt;
278     if (convFwdDims.native_format) {
279       user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
280     } else {
281       // Create memory descriptors for convolution data w/ no specified format
282       user_data_fmt = memory::format_tag::any;
283     }
284     context_.src_md.reset(new memory::desc(
285         {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));
286 
287     context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
288                                               MklDnnType<Tfilter>(),
289                                               memory::format_tag::any));
290 
291     context_.dst_md.reset(new memory::desc(
292         {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));
293 
294     if (!convFwdDims.bias_dims.empty()) {
295       context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
296                                               MklDnnType<Tbias>(),
297                                               memory::format_tag::any));
298       // Create a convolution descriptor
299       context_.fwd_desc.reset(new convolution_forward::desc(
300           prop_kind::forward, dnnl::algorithm::convolution_direct,
301           *context_.src_md, *context_.filter_md, *context_.bias_md,
302           *context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
303           convFwdDims.padding_left, convFwdDims.padding_right));
304     } else {
305       context_.fwd_desc.reset(new convolution_forward::desc(
306           prop_kind::forward, dnnl::algorithm::convolution_direct,
307           *context_.src_md, *context_.filter_md, *context_.dst_md,
308           convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
309           convFwdDims.padding_right));
310     }
311 
312     if (!convFwdDims.fuse_bn_dims.empty()) {
313       const memory::format_tag fused_bn_arg_fmt =
314           convFwdDims.native_format
315               ? user_data_fmt
316               : MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
317 
318       context_.bn_scale_md.reset(new memory::desc(
319           {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
320       context_.bn_mean_md.reset(new memory::desc(
321           {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
322       context_.bn_rsqrt_md.reset(new memory::desc(
323           {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
324       context_.bn_offset_md.reset(new memory::desc(
325           {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
326     }
327 
328     // Check if there are any fusions to apply as post-ops
329     auto const& post_op_params = convFwdDims.post_op_params;
330     dnnl::primitive_attr post_ops_attr;
331     dnnl::post_ops post_ops;
332     post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
333     if (!post_op_params.empty()) {
334       for (auto const& post_op_param : post_op_params) {
335         if (post_op_param.name == "activation") {
336           DCHECK_EQ(post_op_param.param.size(), 3);
337           float op_scale = post_op_param.param[0];
338           float op_alpha = post_op_param.param[1];
339           float op_beta = post_op_param.param[2];
340           post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
341                                   op_beta);
342         } else if (post_op_param.name == "sum") {
343           DCHECK_EQ(post_op_param.param.size(), 1);
344           float op_scale = post_op_param.param[0];
345           post_ops.append_sum(op_scale);
346         } else if (post_op_param.name == "output_scale") {
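          // Mask 0 applies a single common output scale; mask 2 sets
          // per-output-channel scales (bit 1 corresponds to the channel
          // dimension of dst).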
347           if (post_op_param.param.size() == 1) {
348             post_ops_attr.set_output_scales(0, post_op_param.param);
349           } else {
350             post_ops_attr.set_output_scales(2, post_op_param.param);
351           }
352         } else if (post_op_param.name == "fuse_bn") {
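          // Fold BatchNorm into the conv via binary post-ops:
          // ((conv_out - mean) * bn_rsqrt) * scale + offset, where bn_rsqrt is
          // precomputed by ComputeBNScale().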
353           post_ops.append_binary(dnnl::algorithm::binary_sub,
354                                  *context_.bn_mean_md);
355           post_ops.append_binary(dnnl::algorithm::binary_mul,
356                                  *context_.bn_rsqrt_md);
357           post_ops.append_binary(dnnl::algorithm::binary_mul,
358                                  *context_.bn_scale_md);
359           post_ops.append_binary(dnnl::algorithm::binary_add,
360                                  *context_.bn_offset_md);
361         } else {
362           DCHECK((post_op_param.name == "activation") ||
363                  (post_op_param.name == "sum") ||
364                  (post_op_param.name == "output_scale") ||
365                  (post_op_param.name == "fuse_bn"));
366         }
367       }
368       post_ops_attr.set_post_ops(post_ops);
369     }
370     context_.fwd_pd.reset(
371         new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_));
372 
373     // Create memory primitive based on dummy data
374     context_.src_mem.reset(
375         new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
376     context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
377                                          cpu_engine_, DummyData));
378     context_.dst_mem.reset(
379         new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));
380 
381     context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
382     auto scratchpad_md = context_.fwd_pd->scratchpad_desc();
383     context_.sp_mem.reset(
384         new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));
385 
386     // Create convolution primitive and add it to net
387     if (!convFwdDims.bias_dims.empty()) {
388       context_.bias_mem.reset(new memory(
389           {{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
390           cpu_engine_, DummyData));
391       context_.fwd_primitives_args.push_back(
392           {{DNNL_ARG_SRC, *context_.src_mem},
393            {DNNL_ARG_WEIGHTS, *context_.filter_mem},
394            {DNNL_ARG_BIAS, *context_.bias_mem},
395            {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
396            {DNNL_ARG_DST, *context_.dst_mem}});
397     } else if (!convFwdDims.fuse_bn_dims.empty()) {
398       context_.bn_scale_mem.reset(
399           new memory(*context_.bn_scale_md, cpu_engine_, DummyData));
400       context_.bn_mean_mem.reset(
401           new memory(*context_.bn_mean_md, cpu_engine_, DummyData));
402       context_.bn_offset_mem.reset(
403           new memory(*context_.bn_offset_md, cpu_engine_, DummyData));
404       context_.bn_rsqrt_mem.reset(
405           new memory(*context_.bn_rsqrt_md, cpu_engine_, DummyData));
406 
407       context_.fwd_primitives_args.push_back(
408           {{DNNL_ARG_SRC, *context_.src_mem},
409            {DNNL_ARG_WEIGHTS, *context_.filter_mem},
410            {DNNL_ARG_DST, *context_.dst_mem},
411            {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
412            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
413             *context_.bn_mean_mem},
414            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1,
415             *context_.bn_rsqrt_mem},
416            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1,
417             *context_.bn_scale_mem},
418            {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1,
419             *context_.bn_offset_mem}});
420     } else {
421       context_.fwd_primitives_args.push_back(
422           {{DNNL_ARG_SRC, *context_.src_mem},
423            {DNNL_ARG_WEIGHTS, *context_.filter_mem},
424            {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
425            {DNNL_ARG_DST, *context_.dst_mem}});
426     }
427     context_.fwd_primitives.push_back(*context_.conv_fwd);
428   }
429 
430   struct ConvFwdContext context_;
431 
432 #ifdef DNNL_AARCH64_USE_ACL
433   // Guards Execute()
434   mutex primitive_execution_mu_;
435 #endif
436 };
437 
438 // TODO(intel-tf): We should not require passing a type to MklPrimitiveFactory.
439 // But removing the need for type in MklPrimitiveFactory is going to require
440 // change to every MKL op. So not doing it now. Instead passing float.
441 template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
442 class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<float> {
443  public:
444   static MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* Get(
445       const MklConvFwdParams& convFwdDims, bool do_not_cache) {
446     MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* conv_fwd = nullptr;
447 
448     if (do_not_cache) {
449       // Always create a new primitive
450       conv_fwd =
451           new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(convFwdDims);
452     } else {
453       // Try to find a suitable one in pool
454       conv_fwd =
455           dynamic_cast<MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>*>(
456               MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
457                                          Toutput>::GetInstance()
458                   .GetConvFwd(convFwdDims));
459       if (conv_fwd == nullptr) {
460         conv_fwd = new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(
461             convFwdDims);
462         MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
463                                    Toutput>::GetInstance()
464             .SetConvFwd(convFwdDims, conv_fwd);
465       }
466     }
467 
468     return conv_fwd;
469   }
470 
471  private:
472   MklConvFwdPrimitiveFactory() {}
473   ~MklConvFwdPrimitiveFactory() {}
474 
475   static const int kDilationH = 0, kDilationW = 1;
476 
477   static MklConvFwdPrimitiveFactory& GetInstance() {
478     static MklConvFwdPrimitiveFactory instance_;
479     return instance_;
480   }
481 
482   static string CreateKey(const MklConvFwdParams& convFwdDims) {
483     string prefix = "conv_fwd_";
484     FactoryKeyCreator key_creator;
485     key_creator.AddAsKey(prefix);
486     key_creator.AddAsKey(convFwdDims.src_dims);
487     key_creator.AddAsKey(convFwdDims.filter_dims);
488 #ifdef DNNL_AARCH64_USE_ACL
489     key_creator.AddAsKey(convFwdDims.filter_hash);
490 #endif
491     key_creator.AddAsKey(convFwdDims.bias_dims);
492     key_creator.AddAsKey(convFwdDims.dst_dims);
493     key_creator.AddAsKey(convFwdDims.strides);
494     key_creator.AddAsKey(convFwdDims.dilations);
495     key_creator.AddAsKey(convFwdDims.padding_left);
496     key_creator.AddAsKey(convFwdDims.padding_right);
497     key_creator.AddAsKey(convFwdDims.dtypes);
498     if (convFwdDims.native_format) {
499       key_creator.AddAsKey(convFwdDims.tf_fmt);
500     }
501 
502     // Generate keys for post-ops
503     for (auto const& post_op_param : convFwdDims.post_op_params) {
504       key_creator.AddAsKey(post_op_param.name);
505       if (post_op_param.name == "activation") {
506         DCHECK_EQ(post_op_param.param.size(), 3);
507         for (auto& param : post_op_param.param) {
508           key_creator.AddAsKey(param);
509         }
510       } else if (post_op_param.name == "sum") {
511         DCHECK_EQ(post_op_param.param.size(), 1);
512         for (auto& param : post_op_param.param) {
513           key_creator.AddAsKey(param);
514         }
515       } else if (post_op_param.name == "output_scale") {
516         key_creator.AddAsKey(post_op_param.partial_key);
517       } else if (post_op_param.name == "fuse_bn") {
518         key_creator.AddAsKey(post_op_param.name);
519         key_creator.AddAsKey(convFwdDims.fuse_bn_dims);
520       } else {
521         return string("not_a_key");
522       }
523     }
524 
525     return key_creator.GetKey();
526   }
527 
528   MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
529     string key = CreateKey(convFwdDims);
530     return this->GetOp(key);
531   }
532 
533   void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
534     string key = CreateKey(convFwdDims);
535     this->SetOp(key, op);
536   }
537 };
538 
539 // Base class for convolution forward operations
540 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
541           typename Toutput, typename Ttemp_output, typename Tpadding,
542           bool bias_enabled, bool pad_enabled, bool is_depthwise,
543           bool native_format>
544 class MklConvOp : public OpKernel {
545  public:
546   ~MklConvOp() {}
547 
548   explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
549     OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
550 
551     // Conv and QuantizedConv ops have different padding attributes
552     // (`padding_list` versus `explicit_paddings`), and at most one of these
553     // attributes is expected.
554     OP_REQUIRES(
555         context,
556         !(context->HasAttr("padding_list") &&
557           context->HasAttr("explicit_paddings")),
558         errors::InvalidArgument("Can only have 1 `padding` list at most"));
559     if (context->HasAttr("padding_list")) {
560       OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_));
561     }
562     if (context->HasAttr("explicit_paddings")) {
563       OP_REQUIRES_OK(context,
564                      context->GetAttr("explicit_paddings", &padding_list_));
565     }
566 
567     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
568     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str_));
569     OP_REQUIRES(context, FormatFromString(data_format_str_, &data_format_),
570                 errors::InvalidArgument("Invalid data format"));
571     OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
572                 errors::InvalidArgument("Sliding window strides field must "
573                                         "specify 4 or 5 dimensions"));
574 
575     const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
576     const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
577     OP_REQUIRES(
578         context, stride_n == 1 && stride_c == 1,
579         errors::Unimplemented("Current implementation does not yet support "
580                               "strides in the batch and depth dimensions."));
581 
582     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
583     is_filter_const_ = false;
584     if (AreWeightsFrozen()) {
585       is_filter_const_ = true;
586     } else if (context->HasAttr("is_filter_const")) {
587       OP_REQUIRES_OK(context,
588                      context->GetAttr("is_filter_const", &is_filter_const_));
589     }
590 
591     if (strides_.size() == 4) {
592       OP_REQUIRES(context, dilations_.size() == 4,
593                   errors::InvalidArgument("Sliding window dilations field must "
594                                           "specify 4 dimensions"));
595       const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
596       const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
597       const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
598       const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
599       OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
600                   errors::InvalidArgument(
601                       "Current implementation does not yet support "
602                       "dilations in the batch and depth dimensions."));
603       OP_REQUIRES(
604           context, dilation_h > 0 && dilation_w > 0,
605           errors::InvalidArgument("Dilated rates should be larger than 0."));
606     } else if (strides_.size() == 5) {
607       OP_REQUIRES(context, dilations_.size() == 5,
608                   errors::InvalidArgument("Dilation rates field must "
609                                           "specify 5 dimensions"));
610       OP_REQUIRES(context,
611                   (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
612                    GetTensorDim(dilations_, data_format_, 'C') == 1),
613                   errors::InvalidArgument(
614                       "Current implementation does not yet support "
615                       "dilations rates in the batch and depth dimensions."));
616       OP_REQUIRES(
617           context,
618           (GetTensorDim(dilations_, data_format_, '0') > 0 &&
619            GetTensorDim(dilations_, data_format_, '1') > 0 &&
620            GetTensorDim(dilations_, data_format_, '2') > 0),
621           errors::InvalidArgument("Dilated rates should be larger than 0."));
622     }
623   }
624 
625   void Compute(OpKernelContext* context) override {
626     try {
627       // Input tensors
628       const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
629       const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
630 
631       OP_REQUIRES(
632           context, filter_tensor.NumElements() > 0,
633           errors::InvalidArgument("filter must not have zero elements "
634                                   "(i.e. all dimensions must be non-zero)"));
635 
636       MklDnnShape src_mkl_shape, filter_mkl_shape;
637       GetMklShape(context, kInputIndex_Src, &src_mkl_shape, native_format);
638       GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape,
639                   native_format);
640 
641       OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(),
642                   errors::InvalidArgument("Filter should not be in "
643                                           "Mkl Layout"));
644 
645       MklDnnData<Tinput> src(&cpu_engine_);
646       MklDnnData<Tfilter> filter(&cpu_engine_);
647 
648       memory::dims src_dims, filter_dims, padding_left, padding_right,
649           dilations, strides;
650       memory::dims dst_dims_tf_order, dst_dims_mkl_order;
651 
652       // For any Conv with `EXPLICIT` padding, get padding from `padding_list`
653       // attribute. Otherwise, get it from one of the inputs.
654       bool pad_attr_enabled = false;
655       for (auto const& padding_val : padding_list_) {
656         if (padding_val) {
657           pad_attr_enabled = true;
658 
659           break;
660         }
661       }
662 
663       if (fuse_pad_ || pad_attr_enabled) {
664         PadWithConvFusion(context, padding_left, padding_right,
665                           pad_attr_enabled, data_format_str_);
666       }
667 
668       // Get shapes of input tensors in MKL-DNN order
669       MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
670                               dilations_);
671       auto src_tf_shape = GetTfShape(context, kInputIndex_Src, native_format);
672       auto filter_tf_shape =
673           GetTfShape(context, kInputIndex_Filter, native_format);
674       bool is_grouped_convolution = false;
675       conv_utl.GetConvFwdSizesInMklOrder(
676           src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
677           &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
678           &padding_right, &is_grouped_convolution,
679           (fuse_pad_ || pad_attr_enabled), is_depthwise);
680 
681       if (!context->status().ok()) return;
682 
683       // Check for corner case - if there is nothing to compute, return.
684       TensorShape dst_tf_shape = MklDnnDimsToTFShape(dst_dims_tf_order);
685 
686       // Corner cases: output with 0 elements and 0 batch size.
687       Tensor* dst_tensor = nullptr;
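      // The converted (blocked-layout) filter is emitted as a 2nd output only
      // on the float/bfloat16 non-native-format path.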
688       bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
689                                  typeid(Tinput) == typeid(Toutput) &&
690                                  (typeid(Tinput) == typeid(float) ||
691                                   typeid(Tinput) == typeid(bfloat16))) &&
692                                 !native_format;
693       if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
694         MklDnnShape dst_mkl_shape;
695         dst_mkl_shape.SetMklTensor(false);
696         AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
697                                   src_tf_shape, dst_mkl_shape, native_format);
698 
699         // MklConv2D/3D also outputs the converted filter as a 2nd output.
700         filter_mkl_shape.SetMklTensor(false);
701         Tensor* output_filter_tensor = nullptr;
702         if (emit_filter_output) {
703           filter_mkl_shape.SetMklTensor(false);
704           AllocateOutputSetMklShape(context, kOutputIndex_Filter,
705                                     &output_filter_tensor, filter_tf_shape,
706                                     filter_mkl_shape);
707         }
708         return;
709       }
710 
711       bool is_conv2d = (strides_.size() == 4);
712       bool is_conv3d = (strides_.size() == 5);
713 
714       if (!is_conv2d && !is_conv3d) {
715         OP_REQUIRES(
716             context, !pad_enabled,
717             errors::InvalidArgument("Pad + Conv fusion only works for 2D/3D"));
718         OP_REQUIRES(
719             context, !fuse_pad_,
720             errors::InvalidArgument("Pad+Conv fusion only works for 2D/3D"));
721       }
722 
723       // TODO(intel-tf): 3-D support for depthwise is not available yet.
724       if (is_depthwise) {
725         OP_REQUIRES(context, is_conv2d,
726                     errors::InvalidArgument(
727                         "Only 2D convolution is supported for depthwise."));
728       }
729 
730       // Create memory for user data.
731       // Describe what the inputs and outputs of the convolution look like.
732       // Also specify buffers containing the actual input and output data.
733       auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
734                               : TFDataFormatToMklDnn3DDataFormat(data_format_);
735 
736       auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
737       // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for an unsupported data format.
738       OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
739                   errors::InvalidArgument("Invalid data format"));
740 
741       // If input is in MKL layout, then simply grab the layout; otherwise,
742       // construct TF layout for input.
743       // For constructing TF layout for input, although input shape (src_dims)
744       // is required to be in MKL-DNN order, the input layout is actually in
745       // TF layout depending on the data format:
746       //     Conv2D: NHWC or NCHW
747       //     Conv3D: NDHWC or NCDHW
748       auto src_md =
749           src_mkl_shape.IsMklTensor()
750               ? src_mkl_shape.GetMklLayout()
751               : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
752       src.SetUsrMem(src_md, &src_tensor);
753 
754       // Although the filter shape (filter_dims) is required in MKL-DNN order,
755       // the layout is TensorFlow's layout: HWIO, or HWIGO for
756       // depthwise/grouped convolutions.
757       auto filter_format = is_conv2d ? ((is_depthwise || is_grouped_convolution)
758                                             ? memory::format_tag::hwigo
759                                             : memory::format_tag::hwio)
760                                      : memory::format_tag::dhwio;
761 
762       DCHECK(!filter_mkl_shape.IsMklTensor());
763       auto filter_md =
764           filter_mkl_shape.IsMklTensor()
765               ? filter_mkl_shape.GetMklLayout()
766               : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
767       filter.SetUsrMem(filter_md, &filter_tensor);
768 
769       // MKL-DNN dilations start from 0.
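      // (A TF dilation rate of 1, i.e. no dilation, maps to 0 in oneDNN.)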
770       for (int i = 0; i < dilations.size(); ++i) --dilations[i];
771 
772       // In some cases, the primitive descriptor may hold large internal
773       // buffers. As a result, we don't cache such primitives if the
774       // environment variable `TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE` is set to True.
775       // MKL-DNN allocates buffers in the following cases:
776       //   1. Legacy CPU without AVX512/AVX2, or
777       //   2. 1x1 convolution with strides != 1
778       bool do_not_cache =
779           MklPrimitiveFactory<Tinput>::IsPrimitiveMemOptEnabled() &&
780           (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
781           (MklPrimitiveFactory<Tinput>::IsLegacyPlatform() ||
782            IsConv1x1StrideNot1(filter_dims, strides));
783 
784       // Get a conv2d fwd from primitive pool
785       MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Ttemp_output>* conv_fwd =
786           nullptr;
787       memory::dims bias_dims = {};
788       if (fuse_biasadd_) {
789         conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
790       }
791       memory::dims fuse_bn_dims = {};
792       TensorShape fuse_bn_shape;
793       if (fuse_bn_) {
794         // Inputs to FusedBatchNorm have the same 1D shape
795         fuse_bn_shape = MklGetInput(context, kInputIndex_BN_Mean).shape();
796         OP_REQUIRES(context, fuse_bn_shape.dims() == 1,
797                     errors::InvalidArgument("FusedBatchNorm must be 1D, not: ",
798                                             fuse_bn_shape.DebugString()));
799 
800         // Note - MKL-DNN expects {1, C, 1, 1} for binary post-op even for NHWC
801         fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1};
802       }
803 
804       MklConvFwdParams convFwdDims(
805           src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS,
806           dst_dims_mkl_order, strides, dilations, padding_left, padding_right,
807           fuse_bn_dims, tf_fmt, native_format);
808 
809       // TODO(intel-tf): Extend the basic parameters for data types and fusions
810       this->ExtendConvFwdParams(context, convFwdDims);
811 #ifdef DNNL_AARCH64_USE_ACL
812       // TODO(milpuz01): Remove once Arm Compute Library provides support for
813       // in-place updates
814       convFwdDims.filter_hash = Hash64(
815           filter_tensor.tensor_data().data(),
816           std::min(kFilterTensorHashLength,
817                    static_cast<int>(filter_tensor.tensor_data().size())));
818 #endif
819 
820       conv_fwd =
821           MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
822               convFwdDims, do_not_cache);
823       // Allocate output tensors `dst_tensor` and `filter_out_tensor`
824       MklDnnShape output_mkl_shape;
825       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
826       AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
827                            &output_mkl_shape, &dst_tensor);
828 
829       Tensor* filter_out_tensor = nullptr;
830       if (emit_filter_output) {
831         AllocateFilterOutputTensor(context, *conv_fwd_pd,
832                                    TFShapeToMklDnnDims(filter_tf_shape),
833                                    &filter_out_tensor);
834       }
835 
836       Ttemp_output* dst_data =
837           reinterpret_cast<Ttemp_output*>(dst_tensor->flat<Toutput>().data());
838 
839       // Check whether src and filter need to be reordered.
840       Tinput* src_data = nullptr;
841       if (src_md != conv_fwd_pd->src_desc()) {
842         src.SetUsrMem(src_md, &src_tensor);
843         src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
844         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
845       } else {
846         src_data = static_cast<Tinput*>(
847             const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
848       }
849 
850       Tfilter* filter_data = nullptr;
851       if (filter_md != conv_fwd_pd->weights_desc()) {
852         bool is_filter_cached = false;
853         // If filter is a constant, we can avoid the conversion of filter from
854         // Tensorflow format to MKL format by caching the filter when it is
855         // converted for the first time. This cached filter can then be reused
856         // in subsequent iterations.
857         if (is_filter_const_) {
858           if (IsFilterCacheEmpty(context)) {
859             // Cache filter if it is not already cached.
860             CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
861                         filter, filter_md, filter_mkl_shape);
862           }
863           filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
864           is_filter_cached = (filter_data != nullptr);
865         }
866         if (!is_filter_cached) {
867           filter.SetUsrMem(filter_md, &filter_tensor);
868           if (filter_out_tensor == nullptr) {
869             filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
870                                        context);
871           } else {
872             filter.CheckReorderToOpMem(
873                 conv_fwd_pd->weights_desc(),
874                 filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
875                 context);
876           }
877           filter_data =
878               static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
879         }
880       } else {
881         filter_data = static_cast<Tfilter*>(
882             const_cast<Tfilter*>(filter_tensor.flat<Tfilter>().data()));
883       }
884 
885       UserScratchPad<unsigned char> scratch_pad;
886       scratch_pad.AllocateSPTensor(conv_fwd, context);
887 
888       // Execute convolution
889       std::shared_ptr<stream> fwd_cpu_stream;
890       MklDnnThreadPool eigen_tp(context);
891       fwd_cpu_stream.reset(CreateStream(&eigen_tp, conv_fwd->GetEngine()));
892       if (fuse_biasadd_) {
893         const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
894         Tbias* bias_data =
895             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
896         conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
897                           fwd_cpu_stream, scratch_pad.Get());
898       } else if (fuse_bn_) {
899         const Tensor& bn_scale_tensor =
900             MklGetInput(context, kInputIndex_BN_Scale);
901         Tinput* bn_scale_data = static_cast<Tinput*>(
902             const_cast<Tinput*>(bn_scale_tensor.flat<Tinput>().data()));
903         const Tensor& bn_mean_tensor =
904             MklGetInput(context, kInputIndex_BN_Mean);
905         Tinput* bn_mean_data = static_cast<Tinput*>(
906             const_cast<Tinput*>(bn_mean_tensor.flat<Tinput>().data()));
907         const Tensor& bn_offset_tensor =
908             MklGetInput(context, kInputIndex_BN_Offset);
909         Tinput* bn_offset_data = static_cast<Tinput*>(
910             const_cast<Tinput*>(bn_offset_tensor.flat<Tinput>().data()));
911 
912         Tensor bn_rsqrt_tensor;
913         OP_REQUIRES_OK(context,
914                        context->allocate_temp(DataTypeToEnum<Tinput>::v(),
915                                               fuse_bn_shape, &bn_rsqrt_tensor));
916         Tinput* bn_rsqrt_data = static_cast<Tinput*>(
917             const_cast<Tinput*>(bn_rsqrt_tensor.flat<Tinput>().data()));
918         this->ComputeBNScale(context, epsilon_, kInputIndex_BN_Variance,
919                              bn_rsqrt_data);
920         conv_fwd->Execute(src_data, filter_data, nullptr, dst_data,
921                           bn_scale_data, bn_mean_data, bn_offset_data,
922                           bn_rsqrt_data, fwd_cpu_stream, scratch_pad.Get());
923       } else {
924         conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream,
925                           scratch_pad.Get());
926       }
927 
928       // Delete primitive since it is not cached.
929       if (do_not_cache) delete conv_fwd;
930 
931     } catch (dnnl::error& e) {
932       string error_msg = tensorflow::strings::StrCat(
933           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
934           __FILE__, ":", __LINE__);
935       OP_REQUIRES_OK(
936           context,
937           errors::Aborted("Operation received an exception:", error_msg));
938     }
939   }
940 
941   void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left,
942                          memory::dims& padding_right, bool pad_attr_enabled,
943                          string data_format_str_) {
944     Tpadding* paddings = nullptr;
945     if (pad_attr_enabled) {
946       paddings = padding_list_.data();
947     } else {
948       const Tensor& paddings_tf = MklGetInput(context, input_index_pad_);
949       OP_REQUIRES(context, paddings_tf.dims() == 2,
950                   errors::InvalidArgument("paddings must be 2-dimensional: ",
951                                           paddings_tf.shape().DebugString()));
952       // Flatten tensor to get individual paddings.
953       paddings = static_cast<Tpadding*>(
954           const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
955     }
956     // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf)
957     // will be zero.
958     // Example:
959     // paddings_tf = [ [0, 0] [1, 2] [3, 4] [0, 0] ],
960     // flat method = row-major, then:
961     // paddings = {0, 0, 1, 2, 3, 4, 0, 0}.
962     // Hence, the values are: top = 1, bottom = 2, left = 3, right = 4.
963     //
964     // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of
965     // paddings(_tf) will be zero.
966     // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}.
967     int64 pad_top = 0, pad_left = 0, pad_front = 0;
968     int64 pad_bottom = 0, pad_right = 0, pad_back = 0;
969     if (data_format_str_ == "NHWC") {
970       pad_top = paddings[2];
971       pad_bottom = paddings[3];
972       pad_left = paddings[4];
973       pad_right = paddings[5];
974     } else if (data_format_str_ == "NCHW") {
975       pad_top = paddings[4];
976       pad_bottom = paddings[5];
977       pad_left = paddings[6];
978       pad_right = paddings[7];
979     } else if (data_format_str_ == "NDHWC") {
980       pad_front = paddings[2];
981       pad_back = paddings[3];
982       pad_top = paddings[4];
983       pad_bottom = paddings[5];
984       pad_left = paddings[6];
985       pad_right = paddings[7];
986     } else if (data_format_str_ == "NCDHW") {
987       pad_front = paddings[4];
988       pad_back = paddings[5];
989       pad_top = paddings[6];
990       pad_bottom = paddings[7];
991       pad_left = paddings[8];
992       pad_right = paddings[9];
993     }
994     // Create padding arrays for MKL-DNN convolutions.
995     // MKL-DNN uses asymmetric padding.
996     if (data_format_str_ == "NHWC" || data_format_str_ == "NCHW") {
997       padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
998       padding_right = {static_cast<int>(pad_bottom),
999                        static_cast<int>(pad_right)};
1000     } else if (data_format_str_ == "NDHWC" || data_format_str_ == "NCDHW") {
1001       padding_left = {static_cast<int>(pad_front), static_cast<int>(pad_top),
1002                       static_cast<int>(pad_left)};
1003       padding_right = {static_cast<int>(pad_back), static_cast<int>(pad_bottom),
1004                        static_cast<int>(pad_right)};
1005     }
1006   }
1007 
1008  protected:
1009   void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }
1010   void set_fuse_activation(bool fuse_activation, dnnl::algorithm activation_alg,
1011                            float alpha_or_upbound = 0.0) {
1012     fuse_activation_ = fuse_activation;
1013     activation_alg_ = activation_alg;
1014     // This variable is used for alpha in leakyrelu or upper bound in relu6
1015     // depending on the context
1016     alpha_or_upbound_ = alpha_or_upbound;
1017   }
1018   void set_fuse_pad(bool fuse_pad) {
1019     fuse_pad_ = fuse_pad;
1020     if (fuse_bn_) {
1021       // If FusedBatchNorm is fused in PadWithFusedConv2D, pad is the 7th input
1022       input_index_pad_ = 6;
1023     } else if (fuse_add_ && fuse_biasadd_) {
1024       // If Bias and Add are fused in PadWithFusedConv2D, pad is the 5th input
1025       input_index_pad_ = 4;
1026     } else {
1027       // If only Bias is fused in PadWithFusedConv2D, pad is the 4th input
1028       input_index_pad_ = 3;
1029     }
1030   }
1031   void set_fuse_add(bool fuse_add) { fuse_add_ = fuse_add; }
1032   void set_fuse_bn(bool fuse_bn, float epsilon) {
1033     fuse_bn_ = fuse_bn;
1034     epsilon_ = epsilon;
1035   }
1036 
1037   virtual void ComputeBNScale(OpKernelContext* context, float epsilon,
1038                               int bn_variance_index, Tinput* scale_buf_ptr) {
1039     OP_REQUIRES(
1040         context, false,
1041         errors::Unimplemented("Compute BN scale not expected in base class"));
1042     return;
1043   }
1044 
1045   // This method is for the base class MklConvOp, which handles the
1046   // floating point implementation of Conv. The quantized conv implementations
1047   // will use overridden versions of this method.
1048   virtual void ExtendConvFwdParams(OpKernelContext* context,
1049                                    MklConvFwdParams& params) {
1050     // Create a string from data types of input, filter, bias, and output.
1051     params.dtypes.append(typeid(Tinput).name());
1052     params.dtypes.append(typeid(Tfilter).name());
1053     params.dtypes.append(typeid(Tbias).name());
1054     params.dtypes.append(typeid(Toutput).name());
1055 
1056     // Add fusions as post ops
1057     // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
1058     // checking `fuse_biasadd_` flag.
1059     if (fuse_add_) {
1060       params.post_op_params.push_back(
1061           {"sum", dnnl::algorithm::undef, {1.0}, ""});
1062     }
1063     // NOTE - fuse_bn post_op entry must be before fuse_activation
1064     if (fuse_bn_) {
1065       params.post_op_params.push_back(
1066           {"fuse_bn", dnnl::algorithm::undef, {1.0}, ""});
1067     }
1068     if (fuse_activation_) {
1069       params.post_op_params.push_back(
1070           {"activation", activation_alg_, {1.0, alpha_or_upbound_, 0.0}, ""});
1071     }
1072   }
1073 
1074   virtual Tbias* GetBiasHandle(OpKernelContext* context,
1075                                std::shared_ptr<ConvFwdPd>& conv2d_fwd_pd,
1076                                const Tensor& bias_tensor) {
1077     if (fuse_biasadd_) {
1078       return static_cast<Tbias*>(
1079           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1080     }
1081     return nullptr;
1082   }
1083 
1084   virtual void AllocateOutputTensor(OpKernelContext* context,
1085                                     const ConvFwdPd& conv_prim_desc,
1086                                     const memory::dims& output_dims_mkl_order,
1087                                     MklTensorFormat output_tf_format,
1088                                     MklDnnShape* output_mkl_shape,
1089                                     Tensor** output_tensor) {
1090     DCHECK(output_tensor);
1091     auto dst_md = conv_prim_desc.dst_desc();
1092 
1093     if (!std::is_same<Ttemp_output, Toutput>::value) {
1094       dst_md.data.data_type =
1095           static_cast<dnnl_data_type_t>(MklDnnType<Toutput>());
1096     }
1097 
1098     // Allocate shape of MKL tensor
1099     output_mkl_shape->SetMklTensor(true);
1100     output_mkl_shape->SetMklLayout(&dst_md);
1101     output_mkl_shape->SetElemType(MklDnnType<Toutput>());
1102     output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
1103                                   output_dims_mkl_order, output_tf_format);
1104 
1105     // Allocate shape of TF tensor
1106     TensorShape output_tf_shape;
1107     output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
1108     if (native_format) {
1109       output_tf_shape = output_mkl_shape->GetTfShape();
1110     }
1111 
1112     if (fuse_add_) {
1113       const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add);
1114       MklDnnShape add_mkl_shape;
1115       GetMklShape(context, kInputIndex_Add, &add_mkl_shape, native_format);
1116       // Forward the summand tensor to the output only if it has no other
1117       // references, otherwise make a copy of it.
1118       if (native_format && context->forward_input_to_output_with_shape(
1119                                kInputIndex_Add, kOutputIndex_Dst,
1120                                output_tf_shape, output_tensor)) {
1121         return;
1122       }
1123       // Check if reorder is needed
1124       if (!native_format && add_mkl_shape == *output_mkl_shape &&
1125           ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
1126                                               kOutputIndex_Dst, output_tensor,
1127                                               add_mkl_shape, false)) {
1128         return;
1129       } else {
1130         AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
1131                                   output_tf_shape, *output_mkl_shape,
1132                                   native_format);
1133         auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
1134             output_mkl_shape->GetTfDataFormat());
1135         OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
1136                     errors::InvalidArgument(
1137                         "MklConvOp: AddN fusion: Invalid data format"));
1138         auto add_md =
1139             add_mkl_shape.IsMklTensor()
1140                 ? add_mkl_shape.GetMklLayout()
1141                 : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
1142                                output_format_tag);
1143         void* add_buf = static_cast<void*>(
1144             const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
1145         void* dst_buf =
1146             static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
1147         if (native_format) {
1148           // We are simply deep copying the add_tensor to output_tensor without
1149           // changing memory layout, hence using same memory descriptor.
1150           add_md = dst_md =
1151               memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
1152                            dnnl::memory::format_tag::x);
1153         }
1154         fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
1155         fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
1156         auto reorder_desc =
1157             ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);
1158 
1159         CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
1160                                 this->cpu_engine_, context);
1161       }
1162     } else {
1163       AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
1164                                 output_tf_shape, *output_mkl_shape,
1165                                 native_format);
1166     }
1167   }
1168 
1169   engine cpu_engine_ = engine(engine::kind::cpu, 0);
1170 
1171  private:
1172   std::shared_ptr<dnnl::memory> fuse_add_src_;
1173   std::shared_ptr<dnnl::memory> fuse_add_dst_;
1174   std::vector<int32> strides_;
1175   std::vector<int32> dilations_;
1176   std::vector<Tpadding> padding_list_;
1177   bool is_filter_const_;
1178   mutex mu_;
1179   Padding padding_;
1180   string data_format_str_;
1181   TensorFormat data_format_;
1182   Tensor cached_filter_data_ TF_GUARDED_BY(mu_);
1183   Tensor cached_filter_md_ TF_GUARDED_BY(mu_);
1184 
1185   // Initialize to values the template is instantiated with
1186   bool fuse_biasadd_ = bias_enabled;
1187   bool fuse_activation_ = false;
1188   bool fuse_pad_ = pad_enabled;
1189   bool fuse_add_ = false;
1190   bool fuse_bn_ = false;
1191   float epsilon_ = 0.0001;
1192 
1193   // This variable is used for alpha in LeakyRelu or the upper bound in Relu6,
1194   // depending on the context.
1195   float alpha_or_upbound_ = 0.0;
1196   dnnl::algorithm activation_alg_ = dnnl::algorithm::undef;
1197 
1198   int input_index_pad_ = 2;
1199 
1200   const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
1201   const int kInputIndex_Add = 3;
1202   const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
1203   const int kDilationH = 0, kDilationW = 1;
1204 
1205   // Input indices for FusedBatchNorm
1206   const int kInputIndex_BN_Scale = 2, kInputIndex_BN_Offset = 3;
1207   const int kInputIndex_BN_Mean = 4, kInputIndex_BN_Variance = 5;
1208 #ifdef DNNL_AARCH64_USE_ACL
1209   const int kFilterTensorHashLength = 1024;
1210 #endif
1211 
1212   MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
1213                                         const ConvFwdPd& conv_prim_desc) const {
1214     DCHECK(filter_mkl_shape);
1215     return filter_mkl_shape->GetTfDataFormat();
1216   }
1217 
1218   // Allocate tensors for cached filter data and cached filter memory
1219   // descriptor (data format)
1220   void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc,
1221                       Tensor** filter_tensor,
1222                       const MklDnnShape* filter_mkl_shape)
1223       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1224     DCHECK(filter_tensor);
1225     TensorShape filter_tf_shape;
1226     filter_tf_shape.AddDim(
1227         (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
1228     OP_REQUIRES_OK(
1229         context, context->allocate_temp(DataTypeToEnum<Tfilter>::value,
1230                                         filter_tf_shape, &cached_filter_data_));
1231 
1232     *filter_tensor = &cached_filter_data_;
1233 
1234     // There is no tensor format in DNNL 1.x, so we cache the complete filter
1235     // descriptor as a flat byte array.
1236     TensorShape cached_filter_md_shape;
1237     memory::desc weights_desc = conv_prim_desc.weights_desc();
1238     // We don't use the .get_size() method of memory::desc since it returns the
1239     // size required to store the primitive's input memory, which is much larger
1240     // than the size of the memory::desc itself.
1241     cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8));
1242     OP_REQUIRES_OK(context,
1243                    context->allocate_temp(DT_UINT8, cached_filter_md_shape,
1244                                           &cached_filter_md_));
1245     *reinterpret_cast<memory::desc*>(cached_filter_md_.flat<uint8>().data()) =
1246         weights_desc;
1247   }
1248 
1249   void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc,
1250                       Tensor** filter_tensor) {
1251     AllocateTensor(context, conv_prim_desc, filter_tensor, nullptr);
1252   }
1253 
1254   void AllocateFilterOutputTensor(OpKernelContext* context,
1255                                   const ConvFwdPd& conv_prim_desc,
1256                                   const memory::dims& filter_dims_tf_order,
1257                                   Tensor** filter_tensor) {
1258     DCHECK(filter_tensor);
1259     auto filter_md = conv_prim_desc.weights_desc();
1260 
1261     // Allocate shape of MKL tensor
1262     MklDnnShape filter_mkl_shape;
1263     filter_mkl_shape.SetMklTensor(true);
1264     filter_mkl_shape.SetMklLayout(&filter_md);
1265     filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());
1266 
1267     // The format of the filter is actually OIhw8i8o, but TF doesn't support
1268     // this format. Just use format::blocked for now because the layout
1269     // is stored in the MKL data.
1270     filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
1271                                  filter_dims_tf_order,
1272                                  MklTensorFormat::FORMAT_BLOCKED);
1273 
1274     // Allocate the data space for the filter to propagate as TF tensor.
1275     TensorShape filter_tf_shape;
1276     filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter)));
1277 
1278     AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
1279                               filter_tf_shape, filter_mkl_shape);
1280   }
1281 
1282   // The TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) is not
1283   // already held when entering the function, since it is acquired
1284   // inside the function.
1285   inline bool IsFilterCacheEmpty(OpKernelContext* context)
1286       TF_LOCKS_EXCLUDED(mu_) {
1287     tf_shared_lock lock(mu_);
1288     const Tensor& cached_filter_data_tensor = cached_filter_data_;
1289     return (cached_filter_data_tensor.NumElements() == 0);
1290   }
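       // Note on the caching protocol: callers first probe IsFilterCacheEmpty()
       // under a shared lock; CacheFilter() below re-checks cached_filter_data_
       // under the exclusive lock before writing, so a thread that loses the
       // race simply returns without redoing the reorder.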
1291 
1292   // Cache the converted filter in a tensor.
1293   // Only one thread can execute this method at any given time.
1294   void CacheFilter(OpKernelContext* context,
1295                    const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1296                    Tfilter* filter_data, const Tensor& filter_tensor,
1297                    MklDnnData<Tfilter>& filter, const memory::desc& filter_md,
1298                    const MklDnnShape& filter_mkl_shape) TF_LOCKS_EXCLUDED(mu_) {
1299     mutex_lock lock(mu_);
1300     const Tensor& cached_filter_data_tensor = cached_filter_data_;
1301 
1302     // If filter is already cached, there's nothing to do.
1303     if (cached_filter_data_tensor.NumElements() > 0) {
1304       return;
1305     }
1306 
1307     // Otherwise, cache filter
1308     filter.SetUsrMem(filter_md, &filter_tensor);
1309     filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
1310                                this->cpu_engine_, context);
1311     filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
1312 
1313     Tensor* filter_tensor_ptr = nullptr;
1314     AllocateTensor(context, *conv_fwd_pd, &filter_tensor_ptr,
1315                    &filter_mkl_shape);
1316     void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
1317     size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size();
1318     memcpy(cached_filter_data, filter_data, cached_filter_data_size);
1319   }
1320 
1321   bool AreMemoryDescriptorsEqual(const memory::desc& filter_md,
1322                                  const Tensor& cached_filter_md) {
1323     auto filter_md_data = filter_md.data;
1324     const char* filter_data = reinterpret_cast<const char*>(&filter_md_data);
1325 
1326     auto cached_filter_md_data = cached_filter_md.scalar<int64_t>()();
1327     const char* cached_filter_data =
1328         reinterpret_cast<const char*>(&cached_filter_md_data);
1329 
1330     for (size_t i = 0; i < sizeof(filter_md_data); ++i) {
1331       if (*filter_data++ != *cached_filter_data++) {
1332         return false;
1333       }
1334     }
1335     return true;
1336   }
1337 
1338   Tfilter* GetCachedFilter(OpKernelContext* context,
1339                            const memory::desc& filter_md)
1340       TF_LOCKS_EXCLUDED(mu_) {
1341     tf_shared_lock lock(mu_);
1342     const Tensor& cached_filter_data = cached_filter_data_;
1343     const Tensor& cached_filter_md = cached_filter_md_;
1344 
1345     // Check if the memory descriptor of the cached weights is the same as
1346     // filter_md. If so, we can use the cached weights; otherwise
1347     // return nullptr.
1348     if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
1349       return static_cast<Tfilter*>(
1350           const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
1351     }
1352     return nullptr;
1353   }
1354 };
1355 
1356 // Base class for fused convolution forward operations
1357 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
1358           typename Toutput, typename Ttemp_output, typename Tpadding,
1359           bool pad_enabled, bool native_format>
1360 class MklFusedConvOp
1361     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1362                        Tpadding, false, false, false, native_format> {
1363  public:
1364   explicit MklFusedConvOp(OpKernelConstruction* context)
1365       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1366                   Tpadding, false, false, false, native_format>(context) {
1367     // Since we came here through the registration of _MklFusedConv2D, get
1368     // all information from 'fused_ops' and 'num_args'
1369     std::vector<string> fused_ops;
1370     OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
1371 
1372     int num_args;
1373     OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
1374     OP_REQUIRES(context, !fused_ops.empty(),
1375                 errors::InvalidArgument(
1376                     "Fused Conv2D must have at least one fused op."));
1377 
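         // For example, a rewritten node carrying fused_ops = {"BiasAdd", "Relu"}
         // and num_args = 1 supplies the bias as its single extra input and takes
         // the BiasAdd + Relu branch below.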
1378     // TODO(intel-tf): Compact the code for activation checking
1379     if (fused_ops == std::vector<string>{"BiasAdd"}) {
1380       this->set_fuse_biasadd(true);
1381       OP_REQUIRES(context, num_args == 1,
1382                   errors::InvalidArgument(
1383                       "Fused Conv2D must have one extra argument: bias."));
1384     } else if (fused_ops == std::vector<string>{"Relu"}) {
1385       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
1386     } else if (fused_ops == std::vector<string>{"Relu6"}) {
1387       this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
1388                                 6.0);
1389     } else if (fused_ops == std::vector<string>{"Elu"}) {
1390       this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
1391     } else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
1392       float leakyrelu_alpha;
1393       OP_REQUIRES_OK(context,
1394                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
1395       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu,
1396                                 leakyrelu_alpha);
1397     } else if (fused_ops == std::vector<string>{"FusedBatchNorm"}) {
1398       float epsilon;
1399       OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1400       OP_REQUIRES(
1401           context, num_args == 4,
1402           errors::InvalidArgument(
1403               "Fused Conv2D with batchnorm must have 4 extra arguments"));
1404       this->set_fuse_bn(true, epsilon);
1405     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
1406       this->set_fuse_biasadd(true);
1407       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
1408       OP_REQUIRES(context, num_args == 1,
1409                   errors::InvalidArgument(
1410                       "Fused Conv2D must have one extra argument: bias."));
1411     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
1412       this->set_fuse_biasadd(true);
1413       this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
1414                                 6.0);
1415       OP_REQUIRES(context, num_args == 1,
1416                   errors::InvalidArgument(
1417                       "Fused Conv2D must have one extra argument: bias."));
1418     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
1419       this->set_fuse_biasadd(true);
1420       this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
1421       OP_REQUIRES(context, num_args == 1,
1422                   errors::InvalidArgument(
1423                       "Fused Conv2D must have one extra argument: bias."));
1424     } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) {
1425       this->set_fuse_biasadd(true);
1426       float leakyrelu_alpha;
1427       OP_REQUIRES_OK(context,
1428                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
1429       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu,
1430                                 leakyrelu_alpha);
1431       OP_REQUIRES(context, num_args == 1,
1432                   errors::InvalidArgument(
1433                       "Fused Conv2D must have one extra argument: bias."));
1434     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) {
1435       this->set_fuse_biasadd(true);
1436       this->set_fuse_add(true);
1437       OP_REQUIRES(
1438           context, num_args == 2,
1439           errors::InvalidArgument(
1440               "Fused Conv2D must have two extra arguments: bias and add."));
1441     } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"}) {
1442       float epsilon;
1443       OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1444       OP_REQUIRES(
1445           context, num_args == 4,
1446           errors::InvalidArgument(
1447               "Fused Conv2D with batchnorm must have 4 extra arguments"));
1448       this->set_fuse_bn(true, epsilon);
1449       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
1450     } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"}) {
1451       float epsilon;
1452       OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1453       OP_REQUIRES(
1454           context, num_args == 4,
1455           errors::InvalidArgument(
1456               "Fused Conv2D with batchnorm must have 4 extra arguments"));
1457       this->set_fuse_bn(true, epsilon);
1458       this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
1459                                 6.0);
1460     } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"}) {
1461       float epsilon;
1462       OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1463       OP_REQUIRES(
1464           context, num_args == 4,
1465           errors::InvalidArgument(
1466               "Fused Conv2D with batchnorm must have 4 extra arguments"));
1467       this->set_fuse_bn(true, epsilon);
1468       this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
1469     } else if (fused_ops ==
1470                std::vector<string>{"FusedBatchNorm", "LeakyRelu"}) {
1471       float epsilon, leakyrelu_alpha;
1472       OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1473       OP_REQUIRES_OK(context,
1474                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
1475       OP_REQUIRES(
1476           context, num_args == 4,
1477           errors::InvalidArgument(
1478               "Fused Conv2D with batchnorm must have 4 extra arguments"));
1479       this->set_fuse_bn(true, epsilon);
1480       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu,
1481                                 leakyrelu_alpha);
1482     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
1483       this->set_fuse_biasadd(true);
1484       this->set_fuse_add(true);
1485       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
1486       OP_REQUIRES(
1487           context, num_args == 2,
1488           errors::InvalidArgument(
1489               "Fused Conv2D must have two extra arguments: bias and add."));
1490     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
1491       this->set_fuse_biasadd(true);
1492       this->set_fuse_add(true);
1493       this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
1494                                 6.0);
1495       OP_REQUIRES(
1496           context, num_args == 2,
1497           errors::InvalidArgument(
1498               "Fused Conv2D must have two extra arguments: bias and add."));
1499     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
1500       this->set_fuse_biasadd(true);
1501       this->set_fuse_add(true);
1502       this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
1503       OP_REQUIRES(
1504           context, num_args == 2,
1505           errors::InvalidArgument(
1506               "Fused Conv2D must have two extra arguments: bias and add."));
1507     } else if (fused_ops ==
1508                std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) {
1509       this->set_fuse_biasadd(true);
1510       this->set_fuse_add(true);
1511       float leakyrelu_alpha;
1512       OP_REQUIRES_OK(context,
1513                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
1514       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu,
1515                                 leakyrelu_alpha);
1516       OP_REQUIRES(
1517           context, num_args == 2,
1518           errors::InvalidArgument(
1519               "Fused Conv2D must have two extra arguments: bias and add."));
1520     } else {
1521       OP_REQUIRES(context, false,
1522                   errors::Unimplemented("Fusion is not implemented: [",
1523                                         absl::StrJoin(fused_ops, ","), "]"));
1524     }
1525 
1526     if (pad_enabled) {
1527       this->set_fuse_pad(true);
1528     }
1529   }
1530 
1531   void ComputeBNScale(OpKernelContext* context, float epsilon,
1532                       int bn_variance_index, Tinput* scale_buf_ptr) override {
1533     const Tensor& bn_var_tensor = MklGetInput(context, bn_variance_index);
1534 
1535     Eigen::Tensor<Tinput, 1, Eigen::RowMajor> bn_rsqrt =
1536         (bn_var_tensor.flat<Tinput>() + static_cast<Tinput>(epsilon)).rsqrt();
1537     Tinput* bn_rsqrt_data = bn_rsqrt.data();
1538     size_t num_elem = bn_var_tensor.shape().dim_size(0);
1539     for (size_t i = 0; i < num_elem; i++) {
1540       scale_buf_ptr[i] = bn_rsqrt_data[i];
1541     }
1542     return;
1543   }
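       // For intuition (numbers invented for illustration): with epsilon = 0.0001f
       // and a per-channel variance of 0.25f, the computed scale is
       //   1.0f / sqrt(0.25f + 0.0001f) ~= 1.9996f,
       // i.e. the 1/sqrt(variance + epsilon) factor of the fused batch norm.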
1544 
1545   virtual ~MklFusedConvOp() {}
1546 };
1547 
1548 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
1549           typename Toutput, typename Ttemp_output, typename Tpadding,
1550           bool pad_enabled, bool bias_enabled, bool is_depthwise,
1551           bool native_format>
1552 class MklFusedDepthwiseConvOp
1553     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1554                        Tpadding, bias_enabled, false, is_depthwise,
1555                        native_format> {
1556  public:
1557   explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context)
1558       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
1559                   Tpadding, bias_enabled, false, is_depthwise, native_format>(
1560             context) {
1561     // Since we came here through the registration of
1562     // _MklFusedDepthwiseConv2dNative, get all
1563     // information from 'fused_ops' and 'num_args'
1564     std::vector<string> fused_ops;
1565     OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
1566 
1567     int num_args;
1568     OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
1569     OP_REQUIRES(context, !fused_ops.empty(),
1570                 errors::InvalidArgument(
1571                     "Fused DepthwiseConv2D must have at least one fused op."));
1572 
1573     if (fused_ops == std::vector<string>{"BiasAdd"}) {
1574       this->set_fuse_biasadd(true);
1575     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
1576       this->set_fuse_biasadd(true);
1577       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
1578     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
1579       this->set_fuse_biasadd(true);
1580       this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
1581                                 6.0);
1582     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
1583       this->set_fuse_biasadd(true);
1584       this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
1585     } else {
1586       OP_REQUIRES(context, false,
1587                   errors::Unimplemented("Fusion is not implemented: [",
1588                                         absl::StrJoin(fused_ops, ","), "]"));
1589     }
1590 
1591     OP_REQUIRES(
1592         context, num_args == 1,
1593         errors::InvalidArgument(
1594             "Fused DepthwiseConv2D must have one extra argument: bias."));
1595 
1596     if (pad_enabled) {
1597       this->set_fuse_pad(true);
1598     }
1599   }
1600 
1601   virtual ~MklFusedDepthwiseConvOp() {}
1602 };
1603 
1604 // We create a new class for each version of Quantized Convolution and inherit
1605 // from the FP32 version of the base class.
1606 template <typename Device, typename Tinput, typename Tbias, typename Toutput,
1607           typename Ttemp_output, bool bias_enabled, bool is_depthwise,
1608           bool native_format = false>
1609 class MklQuantizedConv2DOp
1610     : public MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output,
1611                        int32, bias_enabled, false, is_depthwise,
1612                        native_format> {
1613  public:
1614   virtual ~MklQuantizedConv2DOp() {
1615     if (this->input_bias_ != nullptr) {
1616       delete this->input_bias_;
1617       input_bias_ = nullptr;
1618     }
1619 
1620     if (this->scaled_bias_ != nullptr) {
1621       delete this->scaled_bias_;
1622       scaled_bias_ = nullptr;
1623     }
1624   }
1625 
1626   explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
1627       : MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
1628                   bias_enabled, false, is_depthwise, native_format>(context) {
1629     bool is_filter_const;
1630     OP_REQUIRES_OK(context,
1631                    context->GetAttr("is_filter_const", &is_filter_const));
1632 
1633     if (bias_enabled) {
1634       OP_REQUIRES_OK(context,
1635                      context->GetAttr("is_bias_const", &is_bias_const_));
1636     }
1637 
1638     OP_REQUIRES(context, is_filter_const,
1639                 errors::InvalidArgument("Filter must be a constant"));
1640   }
1641 
1642   void Compute(OpKernelContext* context) override {
1643     // Compute int32 output tensor
1644     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
1645               bias_enabled, false, is_depthwise,
1646               native_format>::Compute(context);
1647 
1648     // Compute additional outputs: min/max scalars.
1649     int bias_index_offset;
1650     bias_index_offset = bias_enabled ? 1 : 0;
1651 
1652     const float min_input =
1653         context->input(2 + bias_index_offset).flat<float>()(0);
1654     const float max_input =
1655         context->input(3 + bias_index_offset).flat<float>()(0);
1656 
1657     MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
1658     output_min_mkl_shape.SetMklTensor(false);
1659     output_max_mkl_shape.SetMklTensor(false);
1660 
1661     Tensor* output_min = nullptr;
1662     Tensor* output_max = nullptr;
1663     if (std::is_same<Toutput, quint8>::value ||
1664         std::is_same<Toutput, qint8>::value) {
1665       AllocateOutputSetMklShape(context, 1, &output_min, {},
1666                                 output_min_mkl_shape, native_format);
1667       AllocateOutputSetMklShape(context, 2, &output_max, {},
1668                                 output_max_mkl_shape, native_format);
1669       // This is the case where the convolution and requantization are fused.
1670       output_min->flat<float>()(0) =
1671           context->input(6 + bias_index_offset).flat<float>()(0);
1672       output_max->flat<float>()(0) =
1673           context->input(7 + bias_index_offset).flat<float>()(0);
1674     } else {
1675       const Tensor& min_filter = context->input(4 + bias_index_offset);
1676       const Tensor& max_filter = context->input(5 + bias_index_offset);
1677       if (min_filter.dims() == 0) {
1678         float min_output_value;
1679         float max_output_value;
1680         MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
1681             min_input, max_input, min_filter.flat<float>()(0),
1682             max_filter.flat<float>()(0), &min_output_value, &max_output_value);
1683         AllocateOutputSetMklShape(context, 1, &output_min, {},
1684                                   output_min_mkl_shape, native_format);
1685         AllocateOutputSetMklShape(context, 2, &output_max, {},
1686                                   output_max_mkl_shape, native_format);
1687         output_min->flat<float>()(0) = min_output_value;
1688         output_max->flat<float>()(0) = max_output_value;
1689       } else {
1690         size_t depth = min_filter.NumElements();
1691         AllocateOutputSetMklShape(context, 1, &output_min,
1692                                   {static_cast<ptrdiff_t>(depth)},
1693                                   output_min_mkl_shape, native_format);
1694         AllocateOutputSetMklShape(context, 2, &output_max,
1695                                   {static_cast<ptrdiff_t>(depth)},
1696                                   output_max_mkl_shape, native_format);
1697         MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
1698             min_input, max_input, min_filter, max_filter, &output_min,
1699             &output_max);
1700       }
1701     }
1702   }
1703 
1704  protected:
1705   void ExtendConvFwdParams(OpKernelContext* context,
1706                            MklConvFwdParams& params) override {
1707     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
1708               bias_enabled, false, is_depthwise,
1709               native_format>::ExtendConvFwdParams(context, params);
1710 
1711     // When the output type is quint8 or qint8, the output data is requantized
1712     // into that type. A post_op "output_scale" is added to do the conversion.
1713     if (std::is_same<Toutput, quint8>::value ||
1714         std::is_same<Toutput, qint8>::value) {
1715       int bias_index_offset;
1716       bias_index_offset = bias_enabled ? 1 : 0;
1717 
1718       const float min_input =
1719           context->input(2 + bias_index_offset).flat<float>()(0);
1720       const float max_input =
1721           context->input(3 + bias_index_offset).flat<float>()(0);
1722       const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
1723       const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
1724 
1725       // min_freezed_output and max_freezed_output are the actual range
1726       // for the output.
1727       const float min_freezed_output =
1728           context->input(6 + bias_index_offset).flat<float>()(0);
1729       const float max_freezed_output =
1730           context->input(7 + bias_index_offset).flat<float>()(0);
1731 
1732       float int_output_limit =
1733           std::is_same<Toutput, quint8>::value ? 255.0f : 127.0f;
1734       size_t depth = min_filter_vector.NumElements();
1735       const float* min_filter = min_filter_vector.flat<float>().data();
1736       const float* max_filter = max_filter_vector.flat<float>().data();
1737       std::vector<float> scales(depth);
1738       float float_input_range =
1739           std::max(std::abs(min_input), std::abs(max_input));
1740       float float_output_range =
1741           std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
1742       const float int_const_scale_limit =
1743           (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
1744       for (size_t i = 0; i < depth; ++i) {
1745         // For simplicity and symmetry, we set the filter range to the outer
1746         // bounds of min_filter and max_filter.
1747         float float_filter_range =
1748             std::max(std::abs(min_filter[i]), std::abs(max_filter[i]));
1749         // To understand the scaling, please see mkl_requantize_ops_test.
1750         scales[i] = int_output_limit * float_input_range * float_filter_range /
1751                     (int_const_scale_limit * float_output_range);
1752       }
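           // Rough numeric sketch (hypothetical ranges, not from any real model):
           // with Tinput = Toutput = quint8, an input range of [-1.0f, 1.0f], a
           // filter range of [-0.5f, 0.5f] and a frozen output range of
           // [-6.0f, 6.0f], the scale works out to
           //   255.0f * 1.0f * 0.5f / ((255.0f * 127.0f) * 6.0f) ~= 6.56e-4,
           // the factor that maps the int32 accumulator back into quint8.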
1753       // We are creating a partial key here to use with primitive key caching to
1754       // improve key-creation performance. Instead of the actual values, we use
1755       // the pointers to min/max_filter_vector; this works because the filter
1756       // vector here is constant.
1757       FactoryKeyCreator param_key;
1758       param_key.AddAsKey<float>(min_input);
1759       param_key.AddAsKey<float>(max_input);
1760       param_key.AddAsKey<float>(min_freezed_output);
1761       param_key.AddAsKey<float>(max_freezed_output);
1762       param_key.AddAsKey<const float*>(min_filter);
1763       param_key.AddAsKey<const float*>(max_filter);
1764       params.post_op_params.push_back(
1765           {"output_scale", dnnl::algorithm::undef, scales, param_key.GetKey()});
1766     }
1767   }
1768 
1769   Tbias* GetBiasHandle(OpKernelContext* context,
1770                        std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1771                        const Tensor& bias_tensor) override {
1772     if (!bias_enabled) {
1773       return nullptr;
1774     }
1775     if (std::is_same<Tbias, qint32>::value) {
1776       return static_cast<Tbias*>(
1777           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1778     }
1779     int bias_index_offset;
1780     bias_index_offset = bias_enabled ? 1 : 0;
1781 
1782     const float min_input =
1783         context->input(2 + bias_index_offset).flat<float>()(0);
1784     const float max_input =
1785         context->input(3 + bias_index_offset).flat<float>()(0);
1786     const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
1787     const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
1788     const float* min_filter = min_filter_vector.flat<float>().data();
1789     const float* max_filter = max_filter_vector.flat<float>().data();
1790 
1791     const float int_const_scale_limit =
1792         (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
1793     // Re-scale the bias if either of the following two conditions is met:
1794     // 1. Bias is not const;
1795     // 2. Bias is const, but bias cache is empty (first iteration).
1796 
1797     size_t depth = min_filter_vector.NumElements();
1798     bool scales_are_valid = (depth == scales_.size());
1799     scales_.resize(depth);
1800     for (size_t i = 0; i < depth; ++i) {
1801       float tmp_scale =
1802           int_const_scale_limit /
1803           (std::max(std::abs(max_input), std::abs(min_input)) *
1804            std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
1805       if (scales_are_valid && std::abs(tmp_scale - scales_[i]) > 1e-6) {
1806         scales_are_valid = false;
1807       }
1808       scales_[i] = tmp_scale;
1809     }
1810     if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) {
1811       dnnl::primitive_attr bias_attr;
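           // set_output_scales(mask, ...): mask 0 applies one common scale to the
           // whole bias, while mask 1 (i.e. 1 << 0) applies scales_ element-wise
           // along the bias's single dimension, i.e. per output channel.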
1812       if (depth == 1) {
1813         bias_attr.set_output_scales(0, scales_);
1814       } else {
1815         bias_attr.set_output_scales(1, scales_);
1816       }
1817 
1818       auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
1819                                   MklDnnType<Tbias>(), memory::format_tag::x);
1820       void* bias_buf = static_cast<void*>(
1821           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
1822       if (!input_bias_) {
1823         input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
1824       } else {
1825         input_bias_->set_data_handle(bias_buf);
1826       }
1827 
1828       if (!scaled_bias_buf_)
1829         AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
1830                               conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
1831       if (!scaled_bias_) {
1832         scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
1833       } else {
1834         scaled_bias_->set_data_handle(scaled_bias_buf_);
1835       }
1836       auto reorder_desc =
1837           ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
1838                     this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
1839       CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
1840                               this->cpu_engine_, context);
1841 
1842       Tbias* bias_data =
1843           reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
1844       if (is_bias_const_)
1845         CacheBias(context, conv_fwd_pd, bias_data, scaled_bias_);
1846 
1847       return bias_data;
1848     }
1849     return GetCachedBias(context);
1850   }
1851 
1852   bool is_bias_const_;
1853   Tensor cached_bias_data_ TF_GUARDED_BY(bias_cache_mu_);
1854 
1855   memory* input_bias_ = nullptr;
1856   memory* scaled_bias_ = nullptr;
1857 
1858   Tensor scaled_bias_tensor_;
1859   void* scaled_bias_buf_ = nullptr;
1860 
1861  private:
1862   std::vector<float> scales_;
1863   mutex bias_cache_mu_;
1864   // Allocate a tensor for the cached bias data (the scaled bias that is
1865   // reused across iterations when the bias is constant).
1866   void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc,
1867                       Tensor** bias_tensor) {
1868     DCHECK(bias_tensor);
1869     TensorShape bias_tf_shape;
1870     bias_tf_shape.AddDim(
1871         (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
1872     OP_REQUIRES_OK(context,
1873                    context->allocate_temp(DataTypeToEnum<Tbias>::value,
1874                                           bias_tf_shape, &cached_bias_data_));
1875     *bias_tensor = &cached_bias_data_;
1876   }
1877 
1878   // The TF_LOCKS_EXCLUDED annotation ensures that the lock (bias_cache_mu_)
1879   // is not already held when entering the function, since it is acquired
1880   // inside the function.
1881   inline bool IsBiasCacheEmpty(OpKernelContext* context)
1882       TF_LOCKS_EXCLUDED(bias_cache_mu_) {
1883     tf_shared_lock lock(bias_cache_mu_);
1884     return (cached_bias_data_.NumElements() == 0);
1885   }
1886 
1887   // Cache the converted bias in a tensor.
1888   // Only one thread can execute this method at any given time.
1889   void CacheBias(OpKernelContext* context,
1890                  const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
1891                  Tbias* bias_data, const memory* scaled_bias)
1892       TF_LOCKS_EXCLUDED(bias_cache_mu_) {
1893     mutex_lock lock(bias_cache_mu_);
1894 
1895     // If bias is already cached, there's nothing to do.
1896     if (cached_bias_data_.NumElements() > 0) {
1897       return;
1898     }
1899 
1900     // Otherwise, cache bias
1901     Tensor* bias_tensor_ptr = nullptr;
1902     AllocateTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
1903     void* cached_bias_data = const_cast<void*>(
1904         static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
1905     size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
1906     memcpy(cached_bias_data, bias_data, cached_bias_data_size);
1907   }
1908 
1909   Tbias* GetCachedBias(OpKernelContext* context)
1910       TF_LOCKS_EXCLUDED(bias_cache_mu_) {
1911     tf_shared_lock lock(bias_cache_mu_);
1912     const Tensor& cached_bias_data = cached_bias_data_;
1913 
1914     return static_cast<Tbias*>(
1915         const_cast<Tbias*>(cached_bias_data.flat<Tbias>().data()));
1916   }
1917 };
1918 
1919 template <typename Device, typename Tinput, typename Tbias, typename Toutput,
1920           typename Ttemp_output, bool bias_enabled, bool is_depthwise,
1921           bool native_format = false>
1922 class MklQuantizedConv2DReluOp
1923     : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1924                                   bias_enabled, is_depthwise, native_format> {
1925  public:
1926   virtual ~MklQuantizedConv2DReluOp() {}
1927 
1928   explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context)
1929       : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1930                              bias_enabled, is_depthwise, native_format>(
1931             context) {}
1932 
1933  protected:
1934   void ExtendConvFwdParams(OpKernelContext* context,
1935                            MklConvFwdParams& params) override {
1936     MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1937                          bias_enabled, is_depthwise,
1938                          native_format>::ExtendConvFwdParams(context, params);
1939 
1940     params.post_op_params.push_back(
1941         {"activation", dnnl::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
1942   }
1943 };
1944 
1945 template <typename Device, typename Tinput, typename Tbias, typename Toutput,
1946           typename Ttemp_output, bool bias_enabled, bool is_depthwise,
1947           bool native_format = false>
1948 class MklQuantizedConv2DSumReluOp
1949     : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1950                                   bias_enabled, is_depthwise, native_format> {
1951  public:
1952   virtual ~MklQuantizedConv2DSumReluOp() {}
1953 
1954   explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context)
1955       : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1956                              bias_enabled, is_depthwise, native_format>(
1957             context) {}
1958 
1959  protected:
1960   void ExtendConvFwdParams(OpKernelContext* context,
1961                            MklConvFwdParams& params) override {
1962     MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
1963                          bias_enabled, is_depthwise,
1964                          native_format>::ExtendConvFwdParams(context, params);
1965     // Calculate the scale (beta in oneDNN API term) for sum
1966     if (std::is_same<Toutput, quint8>::value) {
1967       int summand_idx = native_format ? context->num_inputs() - 1 - 2
1968                                       : context->num_inputs() / 2 - 1 - 2;
1969       DataType summand_type = this->input_type(summand_idx);
1970       bool summand_condition =
1971           (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
1972       CHECK((summand_condition));
1973       int bias_index_offset = bias_enabled ? 1 : 0;
1974       const float min_freezed_output =
1975           context->input(6 + bias_index_offset).flat<float>()(0);
1976       const float max_freezed_output =
1977           context->input(7 + bias_index_offset).flat<float>()(0);
1978       const float min_freezed_summand =
1979           context->input(9 + bias_index_offset).flat<float>()(0);
1980       const float max_freezed_summand =
1981           context->input(10 + bias_index_offset).flat<float>()(0);
1982 
1983       float scale_output =
1984           std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
1985       float scale_summand = std::max(std::abs(min_freezed_summand),
1986                                      std::abs(max_freezed_summand));
1987       // If summand_type is DT_QUINT8, like the output, the 255.0f scaling
1988       // factors cancel each other out and are therefore avoided.
1989       // Otherwise it is DT_QINT8 and is rescaled by 255.0f / 127.0f.
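           // Rough numeric sketch (made-up ranges, for illustration only): with a
           // frozen output range of [-6.0f, 6.0f] and a frozen summand range of
           // [-3.0f, 3.0f], a quint8 summand gets beta = 3.0f / 6.0f = 0.5f,
           // while a qint8 summand gets
           //   255.0f * 3.0f / (6.0f * 127.0f) ~= 1.004f,
           // compensating for the narrower signed 8-bit range.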
1990       if (summand_type == DT_QUINT8) {
1991         params.post_op_params.push_back({"sum",
1992                                          dnnl::algorithm::undef,
1993                                          {scale_summand / scale_output},
1994                                          ""});
1995       } else {
1996         params.post_op_params.push_back(
1997             {"sum",
1998              dnnl::algorithm::undef,
1999              {255.0f * scale_summand / (scale_output * 127.0f)},
2000              ""});
2001       }
2002     } else {
2003       params.post_op_params.push_back(
2004           {"sum", dnnl::algorithm::undef, {1.0}, ""});
2005     }
2006     params.post_op_params.push_back(
2007         {"activation", dnnl::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
2008   }
2009 
2010   void AllocateOutputTensor(OpKernelContext* context,
2011                             const ConvFwdPd& conv_prim_desc,
2012                             const memory::dims& output_dims_mkl_order,
2013                             MklTensorFormat output_tf_format,
2014                             MklDnnShape* output_mkl_shape,
2015                             Tensor** output_tensor) override {
2016     int summand_idx = native_format ? context->num_inputs() - 1
2017                                     : context->num_inputs() / 2 - 1;
2018     if (std::is_same<Toutput, quint8>::value) {
2019       summand_idx -= 2;
2020       DataType summand_type = this->input_type(summand_idx);
2021       bool summand_condition =
2022           (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
2023       CHECK((summand_condition));
2024       Tensor& summand = const_cast<Tensor&>(MklGetInput(context, summand_idx));
2025       MklDnnShape summand_mkl_shape;
2026       GetMklShape(context, summand_idx, &summand_mkl_shape, native_format);
2027       auto dst_md = summand_mkl_shape.GetMklLayout();
2028 
2029       // TODO(intel-tf): Handle both non-MKL and MKL tensors
2030       if (summand_type == DT_QINT8) {
2031         OP_REQUIRES_OK(
2032             context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape()));
2033         dst_md.data.data_type =
2034             static_cast<dnnl_data_type_t>(MklDnnType<Toutput>());
2035         summand_mkl_shape.SetMklLayout(&dst_md);
2036         summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
2037       }
2038       // TODO(intel-tf): Support cases when summand cannot be forwarded.
2039       OP_REQUIRES(context,
2040                   native_format
2041                       ? context->forward_input_to_output_with_shape(
2042                             summand_idx, 0, summand.shape(), output_tensor)
2043                       : ForwardMklTensorInToOutWithMklShape(
2044                             context, summand_idx, 0, output_tensor,
2045                             summand_mkl_shape, false),
2046                   errors::InvalidArgument(
2047                       "Summand cannot be forwarded in the current fusion."));
2048       return;
2049     }
2050     MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
2051               bias_enabled, false, false,
2052               native_format>::AllocateOutputTensor(context, conv_prim_desc,
2053                                                    output_dims_mkl_order,
2054                                                    output_tf_format,
2055                                                    output_mkl_shape,
2056                                                    output_tensor);
2057     const Tensor& summand = MklGetInput(context, summand_idx);
2058     if (summand.dtype() != DT_FLOAT)
2059       TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
2060                          "Current fusion requires summand to be float"));
2061     MklDnnShape summand_mkl_shape;
2062     GetMklShape(context, summand_idx, &summand_mkl_shape, native_format);
2063     // We need to compute the scale for the summand.
2064     int bias_index_offset = bias_enabled ? 1 : 0;
2065     const float min_input =
2066         context->input(2 + bias_index_offset).flat<float>()(0);
2067     const float max_input =
2068         context->input(3 + bias_index_offset).flat<float>()(0);
2069     const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
2070     const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
2071     const float* min_filter = min_filter_vector.flat<float>().data();
2072     const float* max_filter = max_filter_vector.flat<float>().data();
2073 
2074     const float int_const_scale_limit =
2075         (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
2076     size_t depth = min_filter_vector.NumElements();
2077     std::vector<float> scales(depth);
2078     for (size_t i = 0; i < depth; ++i) {
2079       // TODO(intel-tf): scale factors for UINT8 (inputs) & INT8 (weights) are
2080       // computed here each time. A cleaner design that handles all of these
2081       // mappings in one function, and also supports other quantized types,
2082       // should be implemented in the future.
2083       scales[i] = int_const_scale_limit /
2084                   (std::max(std::abs(max_input), std::abs(min_input)) *
2085                    std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
2086     }
2087     dnnl::primitive_attr reorder_attr;
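         // set_output_scales(mask, ...): mask 0 applies a single common scale,
         // while mask 2 (i.e. 1 << 1) applies per-channel scales along
         // dimension 1 of the destination.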
2088     if (depth == 1) {
2089       reorder_attr.set_output_scales(0, scales);
2090     } else {
2091       reorder_attr.set_output_scales(2, scales);
2092     }
2093     auto summand_md =
2094         summand_mkl_shape.IsMklTensor()
2095             ? summand_mkl_shape.GetMklLayout()
2096             : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
2097                            memory::format_tag::nhwc);
2098     void* summand_buf =
2099         static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
2100     void* dst_buf =
2101         static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
2102     summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
2103     dst_.reset(
2104         new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
2105     auto reorder_desc =
2106         ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
2107                   conv_prim_desc.dst_desc(), reorder_attr);
2108     CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
2109                             context);
2110   }
2111 
2112   std::shared_ptr<dnnl::memory> summand_;
2113   std::shared_ptr<dnnl::memory> dst_;
2114 };
2115 
2116 // Base class for fused convolution forward operations
2117 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
2118           typename Toutput, typename Ttemp_output, typename Tpadding,
2119           bool pad_enabled, bool native_format>
2120 class MklFusedConv3DOp
2121     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
2122                        Tpadding, false, false, false, native_format> {
2123  public:
2124   explicit MklFusedConv3DOp(OpKernelConstruction* context)
2125       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
2126                   Tpadding, false, false, false, native_format>(context) {
2127     // Since we came here through the registration of _MklFusedConv3D, get
2128     // all information from 'fused_ops' and 'num_args'
2129     std::vector<string> fused_ops;
2130     OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
2131 
2132     int num_args;
2133     OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
2134 
2135     std::vector<int> padding_list;
2136     OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list));
2137     if (padding_list.empty()) {
2138       OP_REQUIRES(context, !fused_ops.empty(),
2139                   errors::InvalidArgument("Fused Conv3D must have at least one "
2140                                           "fused op when Pad is not fused."));
2141       if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") ==
2142           fused_ops.end()) {
2143         OP_REQUIRES(context, num_args == 1,
2144                     errors::InvalidArgument(
2145                         "Fused Conv3D must have one extra argument: bias."));
2146       } else if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") ==
2147                      fused_ops.end() &&
2148                  std::find(fused_ops.begin(), fused_ops.end(), "Add") ==
2149                      fused_ops.end()) {
2150         OP_REQUIRES(
2151             context, num_args == 2,
2152             errors::InvalidArgument(
2153                 "Fused Conv3D must have two extra arguments: bias and add."));
2154       }
2155     }
2156 
2157     if (fused_ops == std::vector<string>{"BiasAdd"}) {
2158       this->set_fuse_biasadd(true);
2159     } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) {
2160       this->set_fuse_biasadd(true);
2161       float leakyrelu_alpha;
2162       OP_REQUIRES_OK(context,
2163                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
2164       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu,
2165                                 leakyrelu_alpha);
2166     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
2167       this->set_fuse_biasadd(true);
2168       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
2169     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
2170       this->set_fuse_biasadd(true);
2171       this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
2172                                 6.0);
2173     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
2174       this->set_fuse_biasadd(true);
2175       this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
2176     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) {
2177       this->set_fuse_biasadd(true);
2178       this->set_fuse_add(true);
2179     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
2180       this->set_fuse_biasadd(true);
2181       this->set_fuse_add(true);
2182       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
2183     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
2184       this->set_fuse_biasadd(true);
2185       this->set_fuse_add(true);
2186       this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
2187                                 6.0);
2188     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
2189       this->set_fuse_biasadd(true);
2190       this->set_fuse_add(true);
2191       this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
2192     } else if (fused_ops ==
2193                std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) {
2194       this->set_fuse_biasadd(true);
2195       this->set_fuse_add(true);
2196       float leakyrelu_alpha;
2197       OP_REQUIRES_OK(context,
2198                      context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
2199       this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu,
2200                                 leakyrelu_alpha);
2201     } else {
2202       if (padding_list.empty()) {
2203         OP_REQUIRES(context, false,
2204                     errors::Unimplemented("Fusion is not implemented: [",
2205                                           absl::StrJoin(fused_ops, ","), "]"));
2206       }
2207     }
2208   }
2209 
2210   virtual ~MklFusedConv3DOp() {}
2211 };
2212 
2213 #define REGISTER_MKL_KERNEL(op, kernel, input_type, bias_type, output_type, \
2214                             accu_type, has_bias, is_depthwise, is_native)   \
2215   REGISTER_KERNEL_BUILDER(                                                  \
2216       Name(op)                                                              \
2217           .Device(DEVICE_CPU)                                               \
2218           .TypeConstraint<input_type>("Tinput")                             \
2219           .TypeConstraint<qint8>("Tfilter") BIAS_TYPE_CONSTRAINT(bias_type) \
2220           .TypeConstraint<output_type>("out_type") LABEL,                   \
2221       kernel TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type,   \
2222                            accu_type, has_bias, is_depthwise, is_native));
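     // Hand-expanded sketch, for orientation only (not generated code): with the
     // LABEL / TEMPLATE_ARGS / BIAS_TYPE_CONSTRAINT variants defined further
     // below for the oneDNN-labeled kernels,
     //   REGISTER_MKL_KERNEL("_MklQuantizedConv2D", MklQuantizedConv2DOp, quint8,
     //                       float, qint32, qint32, false, false, true)
     // roughly expands to
     //   REGISTER_KERNEL_BUILDER(
     //       Name("_MklQuantizedConv2D")
     //           .Device(DEVICE_CPU)
     //           .TypeConstraint<quint8>("Tinput")
     //           .TypeConstraint<qint8>("Tfilter")
     //           .TypeConstraint<qint32>("out_type")
     //           .Label(mkl_op_registry::kMklQuantizedOpLabel),
     //       MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32, qint32,
     //                            false, false, true>);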
2223 
2224 #define REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, bias_type,            \
2225                                             output_type, accu_type, has_bias, \
2226                                             is_depthwise, is_native)          \
2227   REGISTER_MKL_KERNEL(op, kernel, qint8, bias_type, output_type, accu_type,   \
2228                       has_bias, is_depthwise, is_native);                     \
2229   REGISTER_MKL_KERNEL(op, kernel, quint8, bias_type, output_type, accu_type,  \
2230                       has_bias, is_depthwise, is_native);
2231 
2232 #define REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(op, kernel, input_type,            \
2233                                            output_type, accu_type, has_bias,  \
2234                                            is_depthwise, is_native)           \
2235   REGISTER_MKL_KERNEL(op, kernel, input_type, qint32, output_type, accu_type, \
2236                       has_bias, is_depthwise, is_native);                     \
2237   REGISTER_MKL_KERNEL(op, kernel, input_type, float, output_type, accu_type,  \
2238                       has_bias, is_depthwise, is_native);
2239 
2240 #define REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES(                      \
2241     op, kernel, output_type, accu_type, has_bias, is_depthwise, is_native) \
2242   REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, qint32, output_type,     \
2243                                       accu_type, has_bias, is_depthwise,   \
2244                                       is_native);                          \
2245   REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, float, output_type,      \
2246                                       accu_type, has_bias, is_depthwise,   \
2247                                       is_native);
2248 
2249 #define LABEL
2250 #define TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \
2251                       accu_type, has_bias, is_depthwise, is_native)
2252 #define BIAS_TYPE_CONSTRAINT(bias_type)
2253 
2254 REGISTER_MKL_KERNEL("QuantizedConv2D", NoOp, quint8, float, qint32, qint32,
2255                     false, false, false);
2256 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("QuantizedConv2DWithBias", NoOp, float,
2257                                     qint32, qint32, false, false, false);
2258 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("QuantizedConv2DWithBiasAndRelu", NoOp,
2259                                     float, qint32, qint32, false, false, false);
2260 REGISTER_MKL_KERNEL("QuantizedConv2DWithBiasSumAndRelu", NoOp, quint8, float,
2261                     qint32, qint32, false, false, false);
2262 REGISTER_MKL_KERNEL("QuantizedConv2DAndRequantize", NoOp, quint8, float, qint8,
2263                     qint8, false, false, false);
2264 REGISTER_MKL_KERNEL("QuantizedConv2DPerChannel", NoOp, quint8, float, qint32,
2265                     qint32, false, false, false);
2266 REGISTER_MKL_KERNEL("QuantizedConv2DAndRelu", NoOp, quint8, float, qint32,
2267                     qint32, false, false, false);
2268 REGISTER_MKL_KERNEL("QuantizedConv2DAndReluAndRequantize", NoOp, quint8, float,
2269                     quint8, quint8, false, false, false);
2270 REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2D", NoOp, quint8, float, qint32,
2271                     qint32, false, false, false);
2272 REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2DWithBias", NoOp, quint8, float,
2273                     qint32, qint32, false, false, false);
2274 REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2DWithBiasAndRelu", NoOp, quint8,
2275                     float, qint32, qint32, false, false, false);
2276 #undef BIAS_TYPE_CONSTRAINT
2277 
2278 #define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias")
2279 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES(
2280     "QuantizedConv2DWithBiasAndRequantize", NoOp, qint8, qint8, false, false,
2281     false);
2282 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES(
2283     "QuantizedConv2DWithBiasAndReluAndRequantize", NoOp, quint8, quint8, false,
2284     false, false);
2285 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
2286     "QuantizedConv2DWithBiasSumAndReluAndRequantize", NoOp, quint8, quint8,
2287     quint8, false, false, false);
2288 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
2289     "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", NoOp, quint8,
2290     quint8, qint8, false, false, false);
2291 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
2292     "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", NoOp, quint8,
2293     quint8, quint8, false, false, false);
2294 #undef BIAS_TYPE_CONSTRAINT
2295 #undef TEMPLATE_ARGS
2296 #undef LABEL
2297 
2298 #define LABEL .Label(mkl_op_registry::kMklQuantizedOpLabel)
2299 #define TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \
2300                       accu_type, has_bias, is_depthwise, is_native)  \
2301 <CPUDevice, input_type, bias_type, output_type, accu_type, has_bias, \
2302       is_depthwise, is_native>
2303 #define BIAS_TYPE_CONSTRAINT(bias_type)
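// Here LABEL attaches the quantized-op label and TEMPLATE_ARGS forwards the
// type and flag parameters to the kernel template. Rough, illustrative
// expansion (not compiled here):
//
//   REGISTER_MKL_KERNEL("_MklQuantizedConv2DPerChannel", MklQuantizedConv2DOp,
//                       quint8, float, qint32, qint32, false, false, true);
//   // ... expands to roughly:
//   REGISTER_KERNEL_BUILDER(
//       Name("_MklQuantizedConv2DPerChannel")
//           .Device(DEVICE_CPU)
//           .TypeConstraint<quint8>("Tinput")
//           .TypeConstraint<qint8>("Tfilter")
//           .TypeConstraint<qint32>("out_type")
//           .Label(mkl_op_registry::kMklQuantizedOpLabel),
//       MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32, qint32, false,
//                            false, true>);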
2304 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2D", MklQuantizedConv2DOp,
2305                                     float, qint32, qint32, false, false, true);
2306 REGISTER_MKL_KERNEL("_MklQuantizedConv2DPerChannel", MklQuantizedConv2DOp,
2307                     quint8, float, qint32, qint32, false, false, true);
2308 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2DWithBias",
2309                                     MklQuantizedConv2DOp, float, qint32, qint32,
2310                                     true, false, true);
2311 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2DWithBiasAndRelu",
2312                                     MklQuantizedConv2DReluOp, float, qint32,
2313                                     qint32, true, false, true);
2314 REGISTER_MKL_KERNEL("_MklQuantizedConv2DWithBiasSumAndRelu",
2315                     MklQuantizedConv2DSumReluOp, quint8, float, qint32, qint32,
2316                     true, false, true);
2317 REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndRequantize", MklQuantizedConv2DOp,
2318                     quint8, float, qint8, qint8, false, false, true);
2319 REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndRelu", MklQuantizedConv2DReluOp,
2320                     quint8, float, qint32, qint32, false, false, true);
2321 REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndReluAndRequantize",
2322                     MklQuantizedConv2DReluOp, quint8, float, quint8, quint8,
2323                     false, false, true);
2324 REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2D", MklQuantizedConv2DOp,
2325                     quint8, float, qint32, qint32, false, true, true);
2326 REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2DWithBias",
2327                     MklQuantizedConv2DOp, quint8, float, qint32, qint32, true,
2328                     true, true);
2329 REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2DWithBiasAndRelu",
2330                     MklQuantizedConv2DReluOp, quint8, float, qint32, qint32,
2331                     true, true, true);
2332 #undef BIAS_TYPE_CONSTRAINT
2333 
2334 #define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias")
2335 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES(
2336     "_MklQuantizedConv2DWithBiasAndRequantize", MklQuantizedConv2DOp, qint8,
2337     qint8, true, false, true);
2338 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES(
2339     "_MklQuantizedConv2DWithBiasAndReluAndRequantize", MklQuantizedConv2DReluOp,
2340     quint8, quint8, true, false, true);
2341 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
2342     "_MklQuantizedConv2DWithBiasSumAndReluAndRequantize",
2343     MklQuantizedConv2DSumReluOp, quint8, quint8, quint8, true, false, true);
2344 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
2345     "_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
2346     MklQuantizedConv2DSumReluOp, quint8, quint8, qint8, true, false, true);
2347 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
2348     "_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize",
2349     MklQuantizedConv2DReluOp, quint8, quint8, quint8, true, true, true);
2350 #undef BIAS_TYPE_CONSTRAINT
2351 #undef TEMPLATE_ARGS
2352 #undef LABEL
2353 
2354 // Register NoOp kernels for ops that will be rewritten to their _Mkl* versions.
2355 
2356 #define REGISTER_NO_OP_CPU_2D_DEPTHWISE(T)                    \
2357   REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative") \
2358                               .Device(DEVICE_CPU)             \
2359                               .TypeConstraint<T>("T"),        \
2360                           NoOp);
2361 
2362 TF_CALL_float(REGISTER_NO_OP_CPU_2D_DEPTHWISE);
2363 TF_CALL_bfloat16(REGISTER_NO_OP_CPU_2D_DEPTHWISE);
2364 
2365 // Register 2D operations
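// Note (inferred from the instantiations below, not from the class
// definition): after the device, the five tensor types, and the Tpaddings
// integer type, the remaining MklConvOp template arguments appear to be
// boolean flags selecting, in order, bias fusion (_MklConv2DWithBias), a
// fused Pad input (_MklPadWithConv2D, which also varies Tpaddings between
// int32 and int64), depthwise convolution, and native TF (non-MKL-layout)
// format (_MklNative* ops).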
2366 #define REGISTER_MKL_CPU_2D(T)                                                 \
2367   REGISTER_KERNEL_BUILDER(                                                     \
2368       Name("_MklConv2D")                                                       \
2369           .Device(DEVICE_CPU)                                                  \
2370           .TypeConstraint<T>("T")                                              \
2371           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2372       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
2373   REGISTER_KERNEL_BUILDER(                                                     \
2374       Name("_MklConv2DWithBias")                                               \
2375           .Device(DEVICE_CPU)                                                  \
2376           .TypeConstraint<T>("T")                                              \
2377           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2378       MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>);  \
2379   REGISTER_KERNEL_BUILDER(                                                     \
2380       Name("__MklDummyConv2DWithBias")                                         \
2381           .Device(DEVICE_CPU)                                                  \
2382           .TypeConstraint<T>("T")                                              \
2383           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2384       MklDummyOp<CPUDevice, T>);                                               \
2385   REGISTER_KERNEL_BUILDER(                                                     \
2386       Name("_MklPadWithConv2D")                                                \
2387           .Device(DEVICE_CPU)                                                  \
2388           .TypeConstraint<T>("T")                                              \
2389           .TypeConstraint<int32>("Tpaddings")                                  \
2390           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2391       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>);  \
2392   REGISTER_KERNEL_BUILDER(                                                     \
2393       Name("_MklPadWithConv2D")                                                \
2394           .Device(DEVICE_CPU)                                                  \
2395           .TypeConstraint<T>("T")                                              \
2396           .TypeConstraint<int64_t>("Tpaddings")                                \
2397           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2398       MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>);  \
2399   REGISTER_KERNEL_BUILDER(                                                     \
2400       Name("__MklDummyPadWithConv2D")                                          \
2401           .Device(DEVICE_CPU)                                                  \
2402           .TypeConstraint<T>("T")                                              \
2403           .TypeConstraint<int32>("Tpaddings")                                  \
2404           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2405       MklDummyOp<CPUDevice, T>);                                               \
2406   REGISTER_KERNEL_BUILDER(                                                     \
2407       Name("_MklNativeConv2D")                                                 \
2408           .Device(DEVICE_CPU)                                                  \
2409           .TypeConstraint<T>("T")                                              \
2410           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2411       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);  \
2412   REGISTER_KERNEL_BUILDER(                                                     \
2413       Name("_MklNativeConv2DWithBias")                                         \
2414           .Device(DEVICE_CPU)                                                  \
2415           .TypeConstraint<T>("T")                                              \
2416           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2417       MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, true>);   \
2418   REGISTER_KERNEL_BUILDER(                                                     \
2419       Name("_MklNativePadWithConv2D")                                          \
2420           .Device(DEVICE_CPU)                                                  \
2421           .TypeConstraint<T>("T")                                              \
2422           .TypeConstraint<int32>("Tpaddings")                                  \
2423           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2424       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, true>);   \
2425   REGISTER_KERNEL_BUILDER(                                                     \
2426       Name("_MklNativePadWithConv2D")                                          \
2427           .Device(DEVICE_CPU)                                                  \
2428           .TypeConstraint<T>("T")                                              \
2429           .TypeConstraint<int64_t>("Tpaddings")                                \
2430           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2431       MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, true>);
2432 
2433 TF_CALL_float(REGISTER_MKL_CPU_2D);
2434 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
2435 
2436 #define REGISTER_MKL_CPU_2D_DEPTHWISE(T)                                      \
2437   REGISTER_KERNEL_BUILDER(                                                    \
2438       Name("_MklDepthwiseConv2dNative")                                       \
2439           .Device(DEVICE_CPU)                                                 \
2440           .TypeConstraint<T>("T")                                             \
2441           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                \
2442       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>); \
2443   REGISTER_KERNEL_BUILDER(                                                    \
2444       Name("_MklFusedDepthwiseConv2dNative")                                  \
2445           .Device(DEVICE_CPU)                                                 \
2446           .TypeConstraint<T>("T")                                             \
2447           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                \
2448       MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true,   \
2449                               true, false>);                                  \
2450   REGISTER_KERNEL_BUILDER(                                                    \
2451       Name("_MklNativeFusedDepthwiseConv2dNative")                            \
2452           .Device(DEVICE_CPU)                                                 \
2453           .TypeConstraint<T>("T")                                             \
2454           .Label(mkl_op_registry::kMklNameChangeOpLabel),                     \
2455       MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true,   \
2456                               true, true>);                                   \
2457   REGISTER_KERNEL_BUILDER(                                                    \
2458       Name("_MklNativeDepthwiseConv2dNative")                                 \
2459           .Device(DEVICE_CPU)                                                 \
2460           .TypeConstraint<T>("T")                                             \
2461           .Label(mkl_op_registry::kMklNameChangeOpLabel),                     \
2462       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, true>);
2463 
2464 TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);
2465 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE);
2466 
2467 // Note that we are registering _MklFusedConv2D here.
2468 // We check the fused_ops attribute to decide whether bias is enabled.
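// For example (illustrative values only), fused_ops = {"BiasAdd"} or
// {"BiasAdd", "Relu"} would select a bias-fused variant, whereas a list
// without "BiasAdd" runs the convolution without a fused bias.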
2469 #define REGISTER_MKL_CPU_2D_FUSED(T)                                  \
2470   REGISTER_KERNEL_BUILDER(                                            \
2471       Name("_MklFusedConv2D")                                         \
2472           .Device(DEVICE_CPU)                                         \
2473           .TypeConstraint<T>("T")                                     \
2474           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2475       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, false>); \
2476   REGISTER_KERNEL_BUILDER(                                            \
2477       Name("_MklPadWithFusedConv2D")                                  \
2478           .Device(DEVICE_CPU)                                         \
2479           .TypeConstraint<int32>("Tpaddings")                         \
2480           .TypeConstraint<T>("T")                                     \
2481           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2482       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, false>);  \
2483   REGISTER_KERNEL_BUILDER(                                            \
2484       Name("_MklPadWithFusedConv2D")                                  \
2485           .Device(DEVICE_CPU)                                         \
2486           .TypeConstraint<T>("T")                                     \
2487           .TypeConstraint<int64_t>("Tpaddings")                       \
2488           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2489       MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, false>);  \
2490   REGISTER_KERNEL_BUILDER(                                            \
2491       Name("__MklDummyPadWithFusedConv2D")                            \
2492           .Device(DEVICE_CPU)                                         \
2493           .TypeConstraint<T>("T")                                     \
2494           .TypeConstraint<int32>("Tpaddings")                         \
2495           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
2496       MklDummyOp<CPUDevice, T>);                                      \
2497   REGISTER_KERNEL_BUILDER(                                            \
2498       Name("_MklNativeFusedConv2D")                                   \
2499           .Device(DEVICE_CPU)                                         \
2500           .TypeConstraint<T>("T")                                     \
2501           .Label(mkl_op_registry::kMklNameChangeOpLabel),             \
2502       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, true>);  \
2503   REGISTER_KERNEL_BUILDER(                                            \
2504       Name("_MklNativePadWithFusedConv2D")                            \
2505           .Device(DEVICE_CPU)                                         \
2506           .TypeConstraint<int32>("Tpaddings")                         \
2507           .TypeConstraint<T>("T")                                     \
2508           .Label(mkl_op_registry::kMklNameChangeOpLabel),             \
2509       MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, true>);   \
2510   REGISTER_KERNEL_BUILDER(                                            \
2511       Name("_MklNativePadWithFusedConv2D")                            \
2512           .Device(DEVICE_CPU)                                         \
2513           .TypeConstraint<T>("T")                                     \
2514           .TypeConstraint<int64_t>("Tpaddings")                       \
2515           .Label(mkl_op_registry::kMklNameChangeOpLabel),             \
2516       MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, true>);
2517 
2518 TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED);
2519 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED);
2520 
2521 // Register 3D operations
2522 #define REGISTER_MKL_CPU_3D(T)                                                 \
2523   REGISTER_KERNEL_BUILDER(                                                     \
2524       Name("_MklConv3D")                                                       \
2525           .Device(DEVICE_CPU)                                                  \
2526           .TypeConstraint<T>("T")                                              \
2527           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
2528       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
2529   REGISTER_KERNEL_BUILDER(                                                     \
2530       Name("_MklNativeConv3D")                                                 \
2531           .Device(DEVICE_CPU)                                                  \
2532           .TypeConstraint<T>("T")                                              \
2533           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2534       MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);  \
2535   REGISTER_KERNEL_BUILDER(                                                     \
2536       Name("_MklNativeFusedConv3D")                                            \
2537           .Device(DEVICE_CPU)                                                  \
2538           .TypeConstraint<T>("T")                                              \
2539           .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
2540       MklFusedConv3DOp<CPUDevice, T, T, T, T, T, int32, false, true>);
2541 TF_CALL_float(REGISTER_MKL_CPU_3D);
2542 TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);
2543 
2544 REGISTER_KERNEL_BUILDER(
2545     Name("_FusedConv3D").Device(DEVICE_CPU).TypeConstraint<float>("T"), NoOp);
2546 REGISTER_KERNEL_BUILDER(
2547     Name("_FusedConv3D").Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"),
2548     NoOp);
2549 }  // namespace tensorflow
2550 #endif  // INTEL_MKL
2551