/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include "tensorflow/core/kernels/mkl/mkl_conv_ops.h"

#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>

#include "absl/strings/str_join.h"
#include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::convolution_forward;
using dnnl::prop_kind;
using dnnl::stream;
using ConvFwdPd = dnnl::convolution_forward::primitive_desc;
using ReorderPd = dnnl::reorder::primitive_desc;

namespace tensorflow {
// This structure aggregates multiple inputs to Conv2DFwd* methods.
struct MklConvFwdParams {
  memory::dims src_dims;
  memory::dims filter_dims;
  memory::dims bias_dims;
  memory::dims dst_dims;
  memory::dims strides;
  memory::dims dilations;
  memory::dims padding_left;
  memory::dims padding_right;
  memory::dims fuse_bn_dims;
  MklTensorFormat tf_fmt;
  bool native_format;
  string dtypes = string("");
#ifdef DNNL_AARCH64_USE_ACL
  uint64 filter_hash;
#endif
  struct PostOpParam {
    string name;
    dnnl::algorithm alg;
    std::vector<float> param;
    std::string partial_key;
  };
  std::vector<PostOpParam> post_op_params;

  MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
                   memory::dims bias_dims, memory::dims dst_dims,
                   memory::dims strides, memory::dims dilations,
                   memory::dims padding_left, memory::dims padding_right,
                   memory::dims fuse_bn_dims, MklTensorFormat tf_fmt,
                   bool native_format)
      : src_dims(src_dims),
        filter_dims(filter_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        strides(strides),
        dilations(dilations),
        padding_left(padding_left),
        padding_right(padding_right),
        fuse_bn_dims(fuse_bn_dims),
        tf_fmt(tf_fmt),
        native_format(native_format) {}
};

// With quantization, input, filter, and output can have different types,
// so we use a different template parameter for each type.
template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
class MklConvFwdPrimitive : public MklPrimitive {
 public:
  explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create convolution primitive
    if (context_.conv_fwd == nullptr) {
      Setup(convFwdDims);
    }
  }

  ~MklConvFwdPrimitive() {}

  dnnl::memory::desc GetScratchPadDesc() {
    return context_.fwd_pd->scratchpad_desc();
  }

  // Convolution forward execute with bias
  //   src_data:    input data buffer of src
  //   filter_data: input data buffer of filter (weights)
  //   bias_data:   input data buffer of bias
  //   dst_data:    output data buffer of dst
  void Execute(const Tinput* src_data, const Tfilter* filter_data,
               const Tbias* bias_data, const Toutput* dst_data,
               std::shared_ptr<stream> fwd_stream, void* sp_data = nullptr) {
    Execute(src_data, filter_data, bias_data, dst_data, nullptr, nullptr,
            nullptr, nullptr, fwd_stream, sp_data);
  }

  void Execute(const Tinput* src_data, const Tfilter* filter_data,
               const Tbias* bias_data, const Toutput* dst_data,
               const Tinput* bn_scale_data, const Tinput* bn_mean_data,
               const Tinput* bn_offset_data, const Tinput* bn_rsqrt_data,
               std::shared_ptr<stream> fwd_stream, void* sp_data) {
#ifdef DNNL_AARCH64_USE_ACL
    // When we are using a single global cache, multiple threads can be
    // running the same primitive that we created, so execution should happen
    // under the lock.
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    // TODO(intel-tf): Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
    context_.filter_mem->set_data_handle(
        static_cast<void*>(const_cast<Tfilter*>(filter_data)), *fwd_stream);
    if (bias_data != nullptr) {
      context_.bias_mem->set_data_handle(
          static_cast<void*>(const_cast<Tbias*>(bias_data)), *fwd_stream);
    }
    if (bn_scale_data != nullptr) {
      context_.bn_scale_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_scale_data)), *fwd_stream);
      context_.bn_mean_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_mean_data)), *fwd_stream);
      context_.bn_rsqrt_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_rsqrt_data)), *fwd_stream);
      context_.bn_offset_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_offset_data)), *fwd_stream);
    }
    context_.dst_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(dst_data)), *fwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<Tinput*>(src_data)));
    context_.filter_mem->set_data_handle(
        static_cast<void*>(const_cast<Tfilter*>(filter_data)));
    if (bias_data != nullptr) {
      context_.bias_mem->set_data_handle(
          static_cast<void*>(const_cast<Tbias*>(bias_data)));
    }
    if (bn_scale_data != nullptr) {
      context_.bn_scale_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_scale_data)));
      context_.bn_mean_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_mean_data)));
      context_.bn_rsqrt_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_rsqrt_data)));
      context_.bn_offset_mem->set_data_handle(
          static_cast<void*>(const_cast<Tinput*>(bn_offset_data)));
    }
    context_.dst_mem->set_data_handle(
        static_cast<void*>(const_cast<Toutput*>(dst_data)));
#endif  // !ENABLE_ONEDNN_OPENMP
    if (sp_data) {
      context_.sp_mem->set_data_handle(static_cast<void*>(sp_data),
                                       *fwd_stream);
    }

    DCHECK_EQ(context_.fwd_primitives.size(),
              context_.fwd_primitives_args.size());
    for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
      context_.fwd_primitives.at(i).execute(*fwd_stream,
                                            context_.fwd_primitives_args.at(i));
    }

    // After execution, set data handle back
    context_.src_mem->set_data_handle(DummyData);
    context_.filter_mem->set_data_handle(DummyData);
    if (bias_data != nullptr) {
      context_.bias_mem->set_data_handle(DummyData);
    }
    if (bn_scale_data != nullptr) {
      context_.bn_scale_mem->set_data_handle(DummyData);
      context_.bn_mean_mem->set_data_handle(DummyData);
      context_.bn_rsqrt_mem->set_data_handle(DummyData);
      context_.bn_offset_mem->set_data_handle(DummyData);
    }
    context_.dst_mem->set_data_handle(DummyData);
    if (sp_data) {
      context_.sp_mem->set_data_handle(DummyData);
    }
  }

  // Convolution forward execute without bias
  //   src_data:    input data buffer of src
  //   filter_data: input data buffer of filter (weights)
  //   dst_data:    output data buffer of dst
  void Execute(const Tinput* src_data, const Tfilter* filter_data,
               const Toutput* dst_data, std::shared_ptr<stream> fwd_stream,
               void* sp_data) {
    Execute(src_data, filter_data, nullptr, dst_data, nullptr, nullptr, nullptr,
            nullptr, fwd_stream, sp_data);
  }

  std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for Conv2D Fwd op
  struct ConvFwdContext {
    // MKL-DNN memory
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> filter_mem;
    std::shared_ptr<dnnl::memory> bias_mem;
    std::shared_ptr<dnnl::memory> dst_mem;
    std::shared_ptr<dnnl::memory> sp_mem;

    // FusedBatchNorm related memory
    std::shared_ptr<dnnl::memory> bn_scale_mem;
    std::shared_ptr<dnnl::memory> bn_mean_mem;
    std::shared_ptr<dnnl::memory> bn_rsqrt_mem;
    std::shared_ptr<dnnl::memory> bn_offset_mem;

    // Desc & primitive desc
    std::shared_ptr<dnnl::convolution_forward::desc> fwd_desc;

    // Memory desc
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> filter_md;
    std::shared_ptr<dnnl::memory::desc> bias_md;
    std::shared_ptr<dnnl::memory::desc> dst_md;

    // TODO(intel-tf): Only need one? FusedBatchNorm related.
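    // The four descriptors below back the binary post-ops appended in Setup()
    // for the FusedBatchNorm fusion, in this order: subtract mean, multiply by
    // rsqrt(variance + epsilon), multiply by scale, add offset.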
    std::shared_ptr<dnnl::memory::desc> bn_scale_md;
    std::shared_ptr<dnnl::memory::desc> bn_mean_md;
    std::shared_ptr<dnnl::memory::desc> bn_rsqrt_md;
    std::shared_ptr<dnnl::memory::desc> bn_offset_md;

    // Convolution primitive
    std::shared_ptr<ConvFwdPd> fwd_pd;
    std::shared_ptr<dnnl::primitive> conv_fwd;

    std::vector<dnnl::primitive> fwd_primitives;
    std::vector<std::unordered_map<int, memory>> fwd_primitives_args;

    ConvFwdContext()
        : src_mem(nullptr),
          filter_mem(nullptr),
          bias_mem(nullptr),
          dst_mem(nullptr),
          sp_mem(nullptr),
          bn_scale_mem(nullptr),
          bn_mean_mem(nullptr),
          bn_rsqrt_mem(nullptr),
          bn_offset_mem(nullptr),
          fwd_desc(nullptr),
          src_md(nullptr),
          filter_md(nullptr),
          bias_md(nullptr),
          dst_md(nullptr),
          bn_scale_md(nullptr),
          bn_mean_md(nullptr),
          bn_rsqrt_md(nullptr),
          bn_offset_md(nullptr),
          fwd_pd(nullptr),
          conv_fwd(nullptr) {}
  };

  void Setup(const MklConvFwdParams& convFwdDims) {
    memory::format_tag user_data_fmt;
    if (convFwdDims.native_format) {
      user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
    } else {
      // Create memory descriptors for convolution data w/ no specified format
      user_data_fmt = memory::format_tag::any;
    }
    context_.src_md.reset(new memory::desc(
        {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));

    context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
                                              MklDnnType<Tfilter>(),
                                              memory::format_tag::any));

    context_.dst_md.reset(new memory::desc(
        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));

    if (!convFwdDims.bias_dims.empty()) {
      context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
                                              MklDnnType<Tbias>(),
                                              memory::format_tag::any));
      // Create a convolution descriptor
      context_.fwd_desc.reset(new convolution_forward::desc(
          prop_kind::forward, dnnl::algorithm::convolution_direct,
          *context_.src_md, *context_.filter_md, *context_.bias_md,
          *context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
          convFwdDims.padding_left, convFwdDims.padding_right));
    } else {
      context_.fwd_desc.reset(new convolution_forward::desc(
          prop_kind::forward, dnnl::algorithm::convolution_direct,
          *context_.src_md, *context_.filter_md, *context_.dst_md,
          convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
          convFwdDims.padding_right));
    }

    if (!convFwdDims.fuse_bn_dims.empty()) {
      const memory::format_tag fused_bn_arg_fmt =
          convFwdDims.native_format
              ? user_data_fmt
              : MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);

      context_.bn_scale_md.reset(new memory::desc(
          {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
      context_.bn_mean_md.reset(new memory::desc(
          {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
      context_.bn_rsqrt_md.reset(new memory::desc(
          {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
      context_.bn_offset_md.reset(new memory::desc(
          {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt));
    }

    // Check if there are any fusions as post-ops
    auto const& post_op_params = convFwdDims.post_op_params;
    dnnl::primitive_attr post_ops_attr;
    dnnl::post_ops post_ops;
    post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
    if (!post_op_params.empty()) {
      for (auto const& post_op_param : post_op_params) {
        if (post_op_param.name == "activation") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
          float op_beta = post_op_param.param[2];
          post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
                                  op_beta);
        } else if (post_op_param.name == "sum") {
          DCHECK_EQ(post_op_param.param.size(), 1);
          float op_scale = post_op_param.param[0];
          post_ops.append_sum(op_scale);
        } else if (post_op_param.name == "output_scale") {
          if (post_op_param.param.size() == 1) {
            post_ops_attr.set_output_scales(0, post_op_param.param);
          } else {
            post_ops_attr.set_output_scales(2, post_op_param.param);
          }
        } else if (post_op_param.name == "fuse_bn") {
          post_ops.append_binary(dnnl::algorithm::binary_sub,
                                 *context_.bn_mean_md);
          post_ops.append_binary(dnnl::algorithm::binary_mul,
                                 *context_.bn_rsqrt_md);
          post_ops.append_binary(dnnl::algorithm::binary_mul,
                                 *context_.bn_scale_md);
          post_ops.append_binary(dnnl::algorithm::binary_add,
                                 *context_.bn_offset_md);
        } else {
          DCHECK((post_op_param.name == "activation") ||
                 (post_op_param.name == "sum") ||
                 (post_op_param.name == "output_scale") ||
                 (post_op_param.name == "fuse_bn"));
        }
      }
      post_ops_attr.set_post_ops(post_ops);
    }
    context_.fwd_pd.reset(
        new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_));

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
    context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
                                         cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));

    context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
    auto scratchpad_md = context_.fwd_pd->scratchpad_desc();
    context_.sp_mem.reset(
        new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));

    // Create convolution primitive and add it to net
    if (!convFwdDims.bias_dims.empty()) {
      context_.bias_mem.reset(new memory(
          {{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
          cpu_engine_, DummyData));
      context_.fwd_primitives_args.push_back(
          {{DNNL_ARG_SRC, *context_.src_mem},
           {DNNL_ARG_WEIGHTS, *context_.filter_mem},
           {DNNL_ARG_BIAS, *context_.bias_mem},
           {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
           {DNNL_ARG_DST, *context_.dst_mem}});
    } else if (!convFwdDims.fuse_bn_dims.empty()) {
      context_.bn_scale_mem.reset(
          new memory(*context_.bn_scale_md, cpu_engine_, DummyData));
      context_.bn_mean_mem.reset(
          new memory(*context_.bn_mean_md, cpu_engine_, DummyData));
      context_.bn_offset_mem.reset(
          new memory(*context_.bn_offset_md, cpu_engine_, DummyData));
      context_.bn_rsqrt_mem.reset(
          new memory(*context_.bn_rsqrt_md, cpu_engine_, DummyData));

      context_.fwd_primitives_args.push_back(
          {{DNNL_ARG_SRC, *context_.src_mem},
           {DNNL_ARG_WEIGHTS, *context_.filter_mem},
           {DNNL_ARG_DST, *context_.dst_mem},
           {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
           {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
            *context_.bn_mean_mem},
           {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1,
            *context_.bn_rsqrt_mem},
           {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1,
            *context_.bn_scale_mem},
           {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1,
            *context_.bn_offset_mem}});
    } else {
      context_.fwd_primitives_args.push_back(
          {{DNNL_ARG_SRC, *context_.src_mem},
           {DNNL_ARG_WEIGHTS, *context_.filter_mem},
           {DNNL_ARG_SCRATCHPAD, *context_.sp_mem},
           {DNNL_ARG_DST, *context_.dst_mem}});
    }
    context_.fwd_primitives.push_back(*context_.conv_fwd);
  }

  struct ConvFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  // Guards Execute()
  mutex primitive_execution_mu_;
#endif
};

// TODO(intel-tf): We should not require passing a type to MklPrimitiveFactory.
// But removing the need for type in MklPrimitiveFactory is going to require
// change to every MKL op. So not doing it now. Instead passing float.
template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<float> {
 public:
  static MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* Get(
      const MklConvFwdParams& convFwdDims, bool do_not_cache) {
    MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* conv_fwd = nullptr;

    if (do_not_cache) {
      // Always create a new primitive
      conv_fwd =
          new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(convFwdDims);
    } else {
      // Try to find a suitable one in the pool
      conv_fwd =
          dynamic_cast<MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>*>(
              MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
                                         Toutput>::GetInstance()
                  .GetConvFwd(convFwdDims));
      if (conv_fwd == nullptr) {
        conv_fwd = new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(
            convFwdDims);
        MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias,
                                   Toutput>::GetInstance()
            .SetConvFwd(convFwdDims, conv_fwd);
      }
    }

    return conv_fwd;
  }

 private:
  MklConvFwdPrimitiveFactory() {}
  ~MklConvFwdPrimitiveFactory() {}

  static const int kDilationH = 0, kDilationW = 1;

  static MklConvFwdPrimitiveFactory& GetInstance() {
    static MklConvFwdPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklConvFwdParams& convFwdDims) {
    string prefix = "conv_fwd_";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(convFwdDims.src_dims);
    key_creator.AddAsKey(convFwdDims.filter_dims);
#ifdef DNNL_AARCH64_USE_ACL
    key_creator.AddAsKey(convFwdDims.filter_hash);
#endif
    key_creator.AddAsKey(convFwdDims.bias_dims);
    key_creator.AddAsKey(convFwdDims.dst_dims);
    key_creator.AddAsKey(convFwdDims.strides);
    key_creator.AddAsKey(convFwdDims.dilations);
    key_creator.AddAsKey(convFwdDims.padding_left);
    key_creator.AddAsKey(convFwdDims.padding_right);
    key_creator.AddAsKey(convFwdDims.dtypes);
    if (convFwdDims.native_format) {
      key_creator.AddAsKey(convFwdDims.tf_fmt);
    }

    // Generate keys for post-ops
    for (auto const& post_op_param : convFwdDims.post_op_params) {
      key_creator.AddAsKey(post_op_param.name);
      if (post_op_param.name == "activation") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        for (auto& param : post_op_param.param) {
          key_creator.AddAsKey(param);
        }
      } else if (post_op_param.name == "sum") {
        DCHECK_EQ(post_op_param.param.size(), 1);
        for (auto& param : post_op_param.param) {
          key_creator.AddAsKey(param);
        }
      } else if (post_op_param.name == "output_scale") {
        key_creator.AddAsKey(post_op_param.partial_key);
      } else if (post_op_param.name == "fuse_bn") {
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(convFwdDims.fuse_bn_dims);
      } else {
        return string("not_a_key");
      }
    }

    return key_creator.GetKey();
  }

  MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
    string key = CreateKey(convFwdDims);
    return this->GetOp(key);
  }

  void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
    string key = CreateKey(convFwdDims);
    this->SetOp(key, op);
  }
};

// Base class for convolution forward operations
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput, typename Ttemp_output, typename Tpadding,
          bool bias_enabled, bool pad_enabled, bool is_depthwise,
          bool native_format>
class MklConvOp : public OpKernel {
 public:
  ~MklConvOp() {}

  explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));

    // Conv and QuantizedConv ops have different padding attributes
    // (`padding_list` versus `explicit_paddings`). But one and only one
    // attribute is expected.
    OP_REQUIRES(
        context,
        !(context->HasAttr("padding_list") &&
          context->HasAttr("explicit_paddings")),
        errors::InvalidArgument("Can only have 1 `padding` list at most"));
    if (context->HasAttr("padding_list")) {
      OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_));
    }
    if (context->HasAttr("explicit_paddings")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("explicit_paddings", &padding_list_));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str_));
    OP_REQUIRES(context, FormatFromString(data_format_str_, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));

    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    is_filter_const_ = false;
    if (AreWeightsFrozen()) {
      is_filter_const_ = true;
    } else if (context->HasAttr("is_filter_const")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("is_filter_const", &is_filter_const_));
    }

    if (strides_.size() == 4) {
      OP_REQUIRES(context, dilations_.size() == 4,
                  errors::InvalidArgument("Sliding window dilations field must "
                                          "specify 4 dimensions"));
      const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
      const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
      const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
      const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
      OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                  errors::InvalidArgument(
                      "Current implementation does not yet support "
                      "dilations in the batch and depth dimensions."));
      OP_REQUIRES(
          context, dilation_h > 0 && dilation_w > 0,
          errors::InvalidArgument("Dilated rates should be larger than 0."));
    } else if (strides_.size() == 5) {
      OP_REQUIRES(context, dilations_.size() == 5,
                  errors::InvalidArgument("Dilation rates field must "
                                          "specify 5 dimensions"));
      OP_REQUIRES(context,
                  (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
                   GetTensorDim(dilations_, data_format_, 'C') == 1),
                  errors::InvalidArgument(
                      "Current implementation does not yet support "
                      "dilations rates in the batch and depth dimensions."));
      OP_REQUIRES(
          context,
          (GetTensorDim(dilations_, data_format_, '0') > 0 &&
           GetTensorDim(dilations_, data_format_, '1') > 0 &&
           GetTensorDim(dilations_, data_format_, '2') > 0),
          errors::InvalidArgument("Dilated rates should be larger than 0."));
    }
  }

  void Compute(OpKernelContext* context) override {
    try {
      // Input tensors
      const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
      const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);

      OP_REQUIRES(
          context, filter_tensor.NumElements() > 0,
          errors::InvalidArgument("filter must not have zero elements "
                                  "(i.e. all dimensions must be non-zero)"));

      MklDnnShape src_mkl_shape, filter_mkl_shape;
      GetMklShape(context, kInputIndex_Src, &src_mkl_shape, native_format);
      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape,
                  native_format);

      OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(),
                  errors::InvalidArgument("Filter should not be in "
                                          "Mkl Layout"));

      MklDnnData<Tinput> src(&cpu_engine_);
      MklDnnData<Tfilter> filter(&cpu_engine_);

      memory::dims src_dims, filter_dims, padding_left, padding_right,
          dilations, strides;
      memory::dims dst_dims_tf_order, dst_dims_mkl_order;

      // For any Conv with `EXPLICIT` padding, get padding from `padding_list`
      // attribute. Otherwise, get it from one of the inputs.
      bool pad_attr_enabled = false;
      for (auto const& padding_val : padding_list_) {
        if (padding_val) {
          pad_attr_enabled = true;
          break;
        }
      }

      if (fuse_pad_ || pad_attr_enabled) {
        PadWithConvFusion(context, padding_left, padding_right,
                          pad_attr_enabled, data_format_str_);
      }

      // Get shapes of input tensors in MKL-DNN order
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
                              dilations_);
      auto src_tf_shape = GetTfShape(context, kInputIndex_Src, native_format);
      auto filter_tf_shape =
          GetTfShape(context, kInputIndex_Filter, native_format);
      bool is_grouped_convolution = false;
      conv_utl.GetConvFwdSizesInMklOrder(
          src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
          &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
          &padding_right, &is_grouped_convolution,
          (fuse_pad_ || pad_attr_enabled), is_depthwise);

      if (!context->status().ok()) return;

      // Check for corner case - if there is nothing to compute, return.
      TensorShape dst_tf_shape = MklDnnDimsToTFShape(dst_dims_tf_order);

      // Corner cases: output with 0 elements and 0 batch size.
      Tensor* dst_tensor = nullptr;
      bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
                                 typeid(Tinput) == typeid(Toutput) &&
                                 (typeid(Tinput) == typeid(float) ||
                                  typeid(Tinput) == typeid(bfloat16))) &&
                                !native_format;
      if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
        MklDnnShape dst_mkl_shape;
        dst_mkl_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
                                  src_tf_shape, dst_mkl_shape, native_format);

        // MklConv2D/3D also outputs converted filter as 2nd output.
        filter_mkl_shape.SetMklTensor(false);
        Tensor* output_filter_tensor = nullptr;
        if (emit_filter_output) {
          filter_mkl_shape.SetMklTensor(false);
          AllocateOutputSetMklShape(context, kOutputIndex_Filter,
                                    &output_filter_tensor, filter_tf_shape,
                                    filter_mkl_shape);
        }
        return;
      }

      bool is_conv2d = (strides_.size() == 4);
      bool is_conv3d = (strides_.size() == 5);

      if (!is_conv2d && !is_conv3d) {
        OP_REQUIRES(
            context, !pad_enabled,
            errors::InvalidArgument("Pad + Conv fusion only works for 2D/3D"));
        OP_REQUIRES(
            context, !fuse_pad_,
            errors::InvalidArgument("Pad+Conv fusion only works for 2D/3D"));
      }

      // TODO(intel-tf) 3-D support for Depthwise is not there
      if (is_depthwise) {
        OP_REQUIRES(context, is_conv2d,
                    errors::InvalidArgument(
                        "Only 2D convolution is supported for depthwise."));
      }

      // Create memory for user data.
      // Describe how the inputs and outputs of Convolution look like. Also
      // specify buffers containing actual input and output data.
      auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
                              : TFDataFormatToMklDnn3DDataFormat(data_format_);

      auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
      // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
      OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
                  errors::InvalidArgument("Invalid data format"));

      // If input is in MKL layout, then simply grab the layout; otherwise,
      // construct TF layout for input.
      // For constructing TF layout for input, although input shape (src_dims)
      // is required to be in MKL-DNN order, the input layout is actually in
      // TF layout depending on the data format:
      //     Conv2D: NHWC or NCHW
      //     Conv3D: NDHWC or NCDHW
      auto src_md =
          src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
      src.SetUsrMem(src_md, &src_tensor);

      // Although filter shape (filter_dims) required is in MKL-DNN order,
      // the layout is Tensorflow's layout (HWIO) and (HWIGO) for
      // depthwise/group convolutions.
      auto filter_format = is_conv2d ? ((is_depthwise || is_grouped_convolution)
                                            ? memory::format_tag::hwigo
                                            : memory::format_tag::hwio)
                                     : memory::format_tag::dhwio;

      DCHECK(!filter_mkl_shape.IsMklTensor());
      auto filter_md =
          filter_mkl_shape.IsMklTensor()
              ? filter_mkl_shape.GetMklLayout()
              : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
      filter.SetUsrMem(filter_md, &filter_tensor);

      // MKL-DNN dilations start from 0.
      for (int i = 0; i < dilations.size(); ++i) --dilations[i];

      // In some cases, primitive descriptor could potentially contain
      // large buffers. As a result, we don't cache these primitives if the
      // environment variable `TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE` is set to True.
      // MKL-DNN allocates buffers in the following cases:
      //   1. Legacy CPU without AVX512/AVX2, or
      //   2. 1x1 convolution with strides != 1
      bool do_not_cache =
          MklPrimitiveFactory<Tinput>::IsPrimitiveMemOptEnabled() &&
          (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
          (MklPrimitiveFactory<Tinput>::IsLegacyPlatform() ||
           IsConv1x1StrideNot1(filter_dims, strides));

      // Get a conv2d fwd from primitive pool
      MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Ttemp_output>* conv_fwd =
          nullptr;
      memory::dims bias_dims = {};
      if (fuse_biasadd_) {
        conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
      }
      memory::dims fuse_bn_dims = {};
      TensorShape fuse_bn_shape;
      if (fuse_bn_) {
        // Inputs to FusedBatchNorm have same 1D shape
        fuse_bn_shape = MklGetInput(context, kInputIndex_BN_Mean).shape();
        OP_REQUIRES(context, fuse_bn_shape.dims() == 1,
                    errors::InvalidArgument("FusedBatchNorm must be 1D, not: ",
                                            fuse_bn_shape.DebugString()));

        // Note - MKL-DNN expects {1, C, 1, 1} for binary post-op even for NHWC
        fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1};
      }

      MklConvFwdParams convFwdDims(
          src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS,
          dst_dims_mkl_order, strides, dilations, padding_left, padding_right,
          fuse_bn_dims, tf_fmt, native_format);

      // TODO(intel-tf): Extend the basic parameters for data types and fusions
      this->ExtendConvFwdParams(context, convFwdDims);
#ifdef DNNL_AARCH64_USE_ACL
      // TODO(milpuz01): Remove once Arm Compute Library provides support for
      // in-place updates
      convFwdDims.filter_hash = Hash64(
          filter_tensor.tensor_data().data(),
          std::min(kFilterTensorHashLength,
                   static_cast<int>(filter_tensor.tensor_data().size())));
#endif

      conv_fwd =
          MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
              convFwdDims, do_not_cache);

      // Allocate output tensors `dst_tensor` and `filter_out_tensor`
      MklDnnShape output_mkl_shape;
      std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
      AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
                           &output_mkl_shape, &dst_tensor);

      Tensor* filter_out_tensor = nullptr;
      if (emit_filter_output) {
        AllocateFilterOutputTensor(context, *conv_fwd_pd,
                                   TFShapeToMklDnnDims(filter_tf_shape),
                                   &filter_out_tensor);
      }

      Ttemp_output* dst_data =
          reinterpret_cast<Ttemp_output*>(dst_tensor->flat<Toutput>().data());

      // Check whether src and filter need to be reordered.
      Tinput* src_data = nullptr;
      if (src_md != conv_fwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<Tinput*>(
            const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
      }

      Tfilter* filter_data = nullptr;
      if (filter_md != conv_fwd_pd->weights_desc()) {
        bool is_filter_cached = false;
        // If filter is a constant, we can avoid the conversion of filter from
        // Tensorflow format to MKL format by caching the filter when it is
        // converted for the first time. This cached filter can then be reused
        // in subsequent iterations.
        if (is_filter_const_) {
          if (IsFilterCacheEmpty(context)) {
            // Cache filter if it is not already cached.
            CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
                        filter, filter_md, filter_mkl_shape);
          }
          filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
          is_filter_cached = (filter_data != nullptr);
        }
        if (!is_filter_cached) {
          filter.SetUsrMem(filter_md, &filter_tensor);
          if (filter_out_tensor == nullptr) {
            filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
                                       context);
          } else {
            filter.CheckReorderToOpMem(
                conv_fwd_pd->weights_desc(),
                filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
                context);
          }
          filter_data =
              static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
        }
      } else {
        filter_data = static_cast<Tfilter*>(
            const_cast<Tfilter*>(filter_tensor.flat<Tfilter>().data()));
      }

      UserScratchPad<unsigned char> scratch_pad;
      scratch_pad.AllocateSPTensor(conv_fwd, context);

      // Execute convolution
      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, conv_fwd->GetEngine()));
      if (fuse_biasadd_) {
        const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
        Tbias* bias_data =
            this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
        conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
                          fwd_cpu_stream, scratch_pad.Get());
      } else if (fuse_bn_) {
        const Tensor& bn_scale_tensor =
            MklGetInput(context, kInputIndex_BN_Scale);
        Tinput* bn_scale_data = static_cast<Tinput*>(
            const_cast<Tinput*>(bn_scale_tensor.flat<Tinput>().data()));
        const Tensor& bn_mean_tensor =
            MklGetInput(context, kInputIndex_BN_Mean);
        Tinput* bn_mean_data = static_cast<Tinput*>(
            const_cast<Tinput*>(bn_mean_tensor.flat<Tinput>().data()));
        const Tensor& bn_offset_tensor =
            MklGetInput(context, kInputIndex_BN_Offset);
        Tinput* bn_offset_data = static_cast<Tinput*>(
            const_cast<Tinput*>(bn_offset_tensor.flat<Tinput>().data()));

        Tensor bn_rsqrt_tensor;
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<Tinput>::v(),
                                              fuse_bn_shape, &bn_rsqrt_tensor));
        Tinput* bn_rsqrt_data = static_cast<Tinput*>(
            const_cast<Tinput*>(bn_rsqrt_tensor.flat<Tinput>().data()));
        this->ComputeBNScale(context, epsilon_, kInputIndex_BN_Variance,
                             bn_rsqrt_data);
        conv_fwd->Execute(src_data, filter_data, nullptr, dst_data,
                          bn_scale_data, bn_mean_data, bn_offset_data,
                          bn_rsqrt_data, fwd_cpu_stream, scratch_pad.Get());
      } else {
        conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream,
                          scratch_pad.Get());
      }

      // Delete primitive since it is not cached.
      if (do_not_cache) delete conv_fwd;

    } catch (dnnl::error& e) {
      string error_msg = tensorflow::strings::StrCat(
          "Status: ", e.status, ", message: ", string(e.message), ", in file ",
          __FILE__, ":", __LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left,
                         memory::dims& padding_right, bool pad_attr_enabled,
                         string data_format_str_) {
    Tpadding* paddings = nullptr;
    if (pad_attr_enabled) {
      paddings = padding_list_.data();
    } else {
      const Tensor& paddings_tf = MklGetInput(context, input_index_pad_);
      OP_REQUIRES(context, paddings_tf.dims() == 2,
                  errors::InvalidArgument("paddings must be 2-dimensional: ",
                                          paddings_tf.shape().DebugString()));
      // Flatten tensor to get individual paddings.
      paddings = static_cast<Tpadding*>(
          const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
    }
    // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf)
    // will be zero.
    // Example:
    //   paddings_tf = [ [0, 0] [1, 2] [3, 4] [0, 0] ],
    //   flat method = row-major, then:
    //   paddings = {0, 0, 1, 2, 3, 4, 0, 0}.
    // Hence, the values are: top = 1, bottom = 2, left = 3, right = 4.
    //
    // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of
    // paddings(_tf) will be zero.
    // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}.
    int64 pad_top = 0, pad_left = 0, pad_front = 0;
    int64 pad_bottom = 0, pad_right = 0, pad_back = 0;
    if (data_format_str_ == "NHWC") {
      pad_top = paddings[2];
      pad_bottom = paddings[3];
      pad_left = paddings[4];
      pad_right = paddings[5];
    } else if (data_format_str_ == "NCHW") {
      pad_top = paddings[4];
      pad_bottom = paddings[5];
      pad_left = paddings[6];
      pad_right = paddings[7];
    } else if (data_format_str_ == "NDHWC") {
      pad_front = paddings[2];
      pad_back = paddings[3];
      pad_top = paddings[4];
      pad_bottom = paddings[5];
      pad_left = paddings[6];
      pad_right = paddings[7];
    } else if (data_format_str_ == "NCDHW") {
      pad_front = paddings[4];
      pad_back = paddings[5];
      pad_top = paddings[6];
      pad_bottom = paddings[7];
      pad_left = paddings[8];
      pad_right = paddings[9];
    }
    // Create padding arrays for MKL-DNN convolutions.
    // MKL-DNN uses asymmetric padding.
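    // For the NHWC example above (top = 1, bottom = 2, left = 3, right = 4),
    // this produces padding_left = {1, 3} and padding_right = {2, 4}: the
    // "left" array holds the leading (top/left, plus front for 3D) pads and
    // the "right" array holds the trailing ones.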
    if (data_format_str_ == "NHWC" || data_format_str_ == "NCHW") {
      padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
      padding_right = {static_cast<int>(pad_bottom),
                       static_cast<int>(pad_right)};
    } else if (data_format_str_ == "NDHWC" || data_format_str_ == "NCDHW") {
      padding_left = {static_cast<int>(pad_front), static_cast<int>(pad_top),
                      static_cast<int>(pad_left)};
      padding_right = {static_cast<int>(pad_back),
                       static_cast<int>(pad_bottom),
                       static_cast<int>(pad_right)};
    }
  }

 protected:
  void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }
  void set_fuse_activation(bool fuse_activation, dnnl::algorithm activation_alg,
                           float alpha_or_upbound = 0.0) {
    fuse_activation_ = fuse_activation;
    activation_alg_ = activation_alg;
    // This variable is used for alpha in leakyrelu or upper bound in relu6
    // depending on the context
    alpha_or_upbound_ = alpha_or_upbound;
  }
  void set_fuse_pad(bool fuse_pad) {
    fuse_pad_ = fuse_pad;
    if (fuse_bn_) {
      // If FusedBatchNorm is fused in PadWithFusedConv2D, pad is the 7th input
      input_index_pad_ = 6;
    } else if (fuse_add_ && fuse_biasadd_) {
      // If Bias and Add are fused in PadWithFusedConv2D, pad is the 5th input
      input_index_pad_ = 4;
    } else {
      // If only Bias is fused in PadWithFusedConv2D, pad is the 4th input
      input_index_pad_ = 3;
    }
  }
  void set_fuse_add(bool fuse_add) { fuse_add_ = fuse_add; }
  void set_fuse_bn(bool fuse_bn, float epsilon) {
    fuse_bn_ = fuse_bn;
    epsilon_ = epsilon;
  }

  virtual void ComputeBNScale(OpKernelContext* context, float epsilon,
                              int bn_variance_index, Tinput* scale_buf_ptr) {
    OP_REQUIRES(
        context, false,
        errors::Unimplemented("Compute BN scale not expected in base class"));
    return;
  }

  // This method is for the base class MklConvOp, which handles the
  // floating point implementation of Conv. The quantized conv implementations
  // will use overridden versions of this method.
  virtual void ExtendConvFwdParams(OpKernelContext* context,
                                   MklConvFwdParams& params) {
    // Create a string from data types of input, filter, bias, and output.
    params.dtypes.append(typeid(Tinput).name());
    params.dtypes.append(typeid(Tfilter).name());
    params.dtypes.append(typeid(Tbias).name());
    params.dtypes.append(typeid(Toutput).name());

    // Add fusions as post ops
    // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
    // checking `fuse_biasadd_` flag.
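    // The "sum" post-op accumulates the convolution result into whatever is
    // already in the destination buffer; AllocateOutputTensor() therefore
    // forwards (or copies) the fused Add input into the output tensor before
    // execution when `fuse_add_` is set.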
    if (fuse_add_) {
      params.post_op_params.push_back(
          {"sum", dnnl::algorithm::undef, {1.0}, ""});
    }
    // NOTE - fuse_bn post_op entry must be before fuse_activation
    if (fuse_bn_) {
      params.post_op_params.push_back(
          {"fuse_bn", dnnl::algorithm::undef, {1.0}, ""});
    }
    if (fuse_activation_) {
      params.post_op_params.push_back(
          {"activation", activation_alg_, {1.0, alpha_or_upbound_, 0.0}, ""});
    }
  }

  virtual Tbias* GetBiasHandle(OpKernelContext* context,
                               std::shared_ptr<ConvFwdPd>& conv2d_fwd_pd,
                               const Tensor& bias_tensor) {
    if (fuse_biasadd_) {
      return static_cast<Tbias*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
    }
    return nullptr;
  }

  virtual void AllocateOutputTensor(OpKernelContext* context,
                                    const ConvFwdPd& conv_prim_desc,
                                    const memory::dims& output_dims_mkl_order,
                                    MklTensorFormat output_tf_format,
                                    MklDnnShape* output_mkl_shape,
                                    Tensor** output_tensor) {
    DCHECK(output_tensor);
    auto dst_md = conv_prim_desc.dst_desc();

    if (!std::is_same<Ttemp_output, Toutput>::value) {
      dst_md.data.data_type =
          static_cast<dnnl_data_type_t>(MklDnnType<Toutput>());
    }

    // Allocate shape of MKL tensor
    output_mkl_shape->SetMklTensor(true);
    output_mkl_shape->SetMklLayout(&dst_md);
    output_mkl_shape->SetElemType(MklDnnType<Toutput>());
    output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
                                  output_dims_mkl_order, output_tf_format);

    // Allocate shape of TF tensor
    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
    if (native_format) {
      output_tf_shape = output_mkl_shape->GetTfShape();
    }

    if (fuse_add_) {
      const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add);
      MklDnnShape add_mkl_shape;
      GetMklShape(context, kInputIndex_Add, &add_mkl_shape, native_format);
      // Forward the summand tensor to the output only if it has no other
      // references, otherwise make a copy of it.
      if (native_format && context->forward_input_to_output_with_shape(
                               kInputIndex_Add, kOutputIndex_Dst,
                               output_tf_shape, output_tensor)) {
        return;
      }
      // Check if reorder is needed
      if (!native_format && add_mkl_shape == *output_mkl_shape &&
          ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add,
                                              kOutputIndex_Dst, output_tensor,
                                              add_mkl_shape, false)) {
        return;
      } else {
        AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                                  output_tf_shape, *output_mkl_shape,
                                  native_format);
        auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
            output_mkl_shape->GetTfDataFormat());
        OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
                    errors::InvalidArgument(
                        "MklConvOp: AddN fusion: Invalid data format"));
        auto add_md =
            add_mkl_shape.IsMklTensor()
                ? add_mkl_shape.GetMklLayout()
                : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
                               output_format_tag);
        void* add_buf = static_cast<void*>(
            const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
        void* dst_buf =
            static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
        if (native_format) {
          // We are simply deep copying the add_tensor to output_tensor without
          // changing memory layout, hence using same memory descriptor.
          add_md = dst_md =
              memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
                           dnnl::memory::format_tag::x);
        }
        fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
        fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
        auto reorder_desc =
            ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);

        CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
                                this->cpu_engine_, context);
      }
    } else {
      AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                                output_tf_shape, *output_mkl_shape,
                                native_format);
    }
  }

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

 private:
  std::shared_ptr<dnnl::memory> fuse_add_src_;
  std::shared_ptr<dnnl::memory> fuse_add_dst_;
  std::vector<int32> strides_;
  std::vector<int32> dilations_;
  std::vector<Tpadding> padding_list_;
  bool is_filter_const_;
  mutex mu_;
  Padding padding_;
  string data_format_str_;
  TensorFormat data_format_;
  Tensor cached_filter_data_ TF_GUARDED_BY(mu_);
  Tensor cached_filter_md_ TF_GUARDED_BY(mu_);

  // Initialize to values the template is instantiated with
  bool fuse_biasadd_ = bias_enabled;
  bool fuse_activation_ = false;
  bool fuse_pad_ = pad_enabled;
  bool fuse_add_ = false;
  bool fuse_bn_ = false;
  float epsilon_ = 0.0001;

  // This variable is used for alpha in leakyrelu or upper bound in relu6
  // depending on the context
  float alpha_or_upbound_ = 0.0;
  dnnl::algorithm activation_alg_ = dnnl::algorithm::undef;

  int input_index_pad_ = 2;

  const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
  const int kInputIndex_Add = 3;
  const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
  const int kDilationH = 0, kDilationW = 1;

  // Input indices for FusedBatchNorm
  const int kInputIndex_BN_Scale = 2, kInputIndex_BN_Offset = 3;
  const int kInputIndex_BN_Mean = 4, kInputIndex_BN_Variance = 5;
#ifdef DNNL_AARCH64_USE_ACL
  const int kFilterTensorHashLength = 1024;
#endif

  MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
                                        const ConvFwdPd& conv_prim_desc) const {
    DCHECK(filter_mkl_shape);
    return filter_mkl_shape->GetTfDataFormat();
  }

  // Allocate tensors for cached filter data and cached filter memory
  // descriptor (data format)
  void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc,
                      Tensor** filter_tensor,
                      const MklDnnShape* filter_mkl_shape)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    DCHECK(filter_tensor);
    TensorShape filter_tf_shape;
    filter_tf_shape.AddDim(
        (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<Tfilter>::value,
                                        filter_tf_shape, &cached_filter_data_));

    *filter_tensor = &cached_filter_data_;

    // There is no tensor format in DNNL 1.x. So we cache the complete filter
    // descriptor as a flat byte array.
    TensorShape cached_filter_md_shape;
    memory::desc weights_desc = conv_prim_desc.weights_desc();
    // We don't use the .get_size() method of memory::desc since it returns the
    // size required to store the primitive's input memory, which is much more
    // than the size of memory::desc itself.
    cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8));
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, cached_filter_md_shape,
                                          &cached_filter_md_));
    *reinterpret_cast<memory::desc*>(cached_filter_md_.flat<uint8>().data()) =
        weights_desc;
  }

  void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc,
                      Tensor** filter_tensor) {
    AllocateTensor(context, conv_prim_desc, filter_tensor, nullptr);
  }

  void AllocateFilterOutputTensor(OpKernelContext* context,
                                  const ConvFwdPd& conv_prim_desc,
                                  const memory::dims& filter_dims_tf_order,
                                  Tensor** filter_tensor) {
    DCHECK(filter_tensor);
    auto filter_md = conv_prim_desc.weights_desc();

    // Allocate shape of MKL tensor
    MklDnnShape filter_mkl_shape;
    filter_mkl_shape.SetMklTensor(true);
    filter_mkl_shape.SetMklLayout(&filter_md);
    filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());

    // The format of the filter is actually OIhw8i8o, but TF doesn't support
    // this format. Just use format::blocked for now because the layout
    // is stored in the MKL data.
    filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
                                 filter_dims_tf_order,
                                 MklTensorFormat::FORMAT_BLOCKED);

    // Allocate the data space for the filter to propagate as TF tensor.
    TensorShape filter_tf_shape;
    filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter)));

    AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
                              filter_tf_shape, filter_mkl_shape);
  }

  // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
  // be acquired before entering the function, since it is acquired
  // inside the function.
  inline bool IsFilterCacheEmpty(OpKernelContext* context)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& cached_filter_data_tensor = cached_filter_data_;
    return (cached_filter_data_tensor.NumElements() == 0);
  }

  // Cache the converted filter in a tensor.
  // Only one thread can execute this method at any given time.
  void CacheFilter(OpKernelContext* context,
                   const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                   Tfilter* filter_data, const Tensor& filter_tensor,
                   MklDnnData<Tfilter>& filter, const memory::desc& filter_md,
                   const MklDnnShape& filter_mkl_shape) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    const Tensor& cached_filter_data_tensor = cached_filter_data_;

    // If filter is already cached, there's nothing to do.
    if (cached_filter_data_tensor.NumElements() > 0) {
      return;
    }

    // Otherwise, cache filter
    filter.SetUsrMem(filter_md, &filter_tensor);
    filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
                               this->cpu_engine_, context);
    filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());

    Tensor* filter_tensor_ptr = nullptr;
    AllocateTensor(context, *conv_fwd_pd, &filter_tensor_ptr,
                   &filter_mkl_shape);
    void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
    size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size();
    memcpy(cached_filter_data, filter_data, cached_filter_data_size);
  }

  bool AreMemoryDescriptorsEqual(const memory::desc& filter_md,
                                 const Tensor& cached_filter_md) {
    auto filter_md_data = filter_md.data;
    const char* filter_data = reinterpret_cast<const char*>(&filter_md_data);

    auto cached_filter_md_data = cached_filter_md.scalar<int64_t>()();
    const char* cached_filter_data =
        reinterpret_cast<const char*>(&cached_filter_md_data);

    for (size_t i = 0; i < sizeof(filter_md_data); ++i) {
      if (*filter_data++ != *cached_filter_data++) {
        return false;
      }
    }
    return true;
  }

  Tfilter* GetCachedFilter(OpKernelContext* context,
                           const memory::desc& filter_md)
      TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock lock(mu_);
    const Tensor& cached_filter_data = cached_filter_data_;
    const Tensor& cached_filter_md = cached_filter_md_;

    // Check if the memory descriptor of the cached weights is the same as
    // filter_md. If so, we can use the cached weights; otherwise
    // return nullptr.
    if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
      return static_cast<Tfilter*>(
          const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
    }
    return nullptr;
  }
};

// Base class for fused convolution forward operations
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
          typename Toutput, typename Ttemp_output, typename Tpadding,
          bool pad_enabled, bool native_format>
class MklFusedConvOp
    : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                       Tpadding, false, false, false, native_format> {
 public:
  explicit MklFusedConvOp(OpKernelConstruction* context)
      : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
                  Tpadding, false, false, false, native_format>(context) {
    // Since we came here through the registration of _MklFusedConv2D, get
    // all information from 'fused_ops' and 'num_args'
    std::vector<string> fused_ops;
    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));

    int num_args;
    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
    OP_REQUIRES(context, !fused_ops.empty(),
                errors::InvalidArgument(
                    "Fused Conv2D must have at least one fused op."));

    // TODO(intel-tf): Compact the code for activation checking
    if (fused_ops == std::vector<string>{"BiasAdd"}) {
      this->set_fuse_biasadd(true);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"Relu"}) {
      this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
    } else if (fused_ops == std::vector<string>{"Relu6"}) {
      this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
                                6.0);
    } else if (fused_ops == std::vector<string>{"Elu"}) {
      this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
    } else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
      float leakyrelu_alpha;
      OP_REQUIRES_OK(context,
                     context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
      this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu,
                                leakyrelu_alpha);
    } else if (fused_ops == std::vector<string>{"FusedBatchNorm"}) {
      float epsilon;
      OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
      OP_REQUIRES(
          context, num_args == 4,
          errors::InvalidArgument(
              "Fused Conv2D with batchnorm must have 4 extra argument"));
      this->set_fuse_bn(true, epsilon);
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu,
                                6.0);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
                      "Fused Conv2D must have one extra argument: bias."));
    } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
      this->set_fuse_biasadd(true);
      this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0);
      OP_REQUIRES(context, num_args == 1,
                  errors::InvalidArgument(
1423 "Fused Conv2D must have one extra argument: bias.")); 1424 } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) { 1425 this->set_fuse_biasadd(true); 1426 float leakyrelu_alpha; 1427 OP_REQUIRES_OK(context, 1428 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); 1429 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, 1430 leakyrelu_alpha); 1431 OP_REQUIRES(context, num_args == 1, 1432 errors::InvalidArgument( 1433 "Fused Conv2D must have one extra argument: bias.")); 1434 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) { 1435 this->set_fuse_biasadd(true); 1436 this->set_fuse_add(true); 1437 OP_REQUIRES( 1438 context, num_args == 2, 1439 errors::InvalidArgument( 1440 "Fused Conv2D must have two extra arguments: bias and add.")); 1441 } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"}) { 1442 float epsilon; 1443 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); 1444 OP_REQUIRES( 1445 context, num_args == 4, 1446 errors::InvalidArgument( 1447 "Fused Conv2D with batchnorm must have 4 extra argument")); 1448 this->set_fuse_bn(true, epsilon); 1449 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); 1450 } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"}) { 1451 float epsilon; 1452 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); 1453 OP_REQUIRES( 1454 context, num_args == 4, 1455 errors::InvalidArgument( 1456 "Fused Conv2D with batchnorm must have 4 extra argument")); 1457 this->set_fuse_bn(true, epsilon); 1458 this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 1459 6.0); 1460 } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"}) { 1461 float epsilon; 1462 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); 1463 OP_REQUIRES( 1464 context, num_args == 4, 1465 errors::InvalidArgument( 1466 "Fused Conv2D with batchnorm must have 4 extra argument")); 1467 this->set_fuse_bn(true, epsilon); 1468 this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); 1469 } else if (fused_ops == 1470 std::vector<string>{"FusedBatchNorm", "LeakyRelu"}) { 1471 float epsilon, leakyrelu_alpha; 1472 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); 1473 OP_REQUIRES_OK(context, 1474 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); 1475 OP_REQUIRES( 1476 context, num_args == 4, 1477 errors::InvalidArgument( 1478 "Fused Conv2D with batchnorm must have 4 extra argument")); 1479 this->set_fuse_bn(true, epsilon); 1480 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, 1481 leakyrelu_alpha); 1482 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) { 1483 this->set_fuse_biasadd(true); 1484 this->set_fuse_add(true); 1485 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); 1486 OP_REQUIRES( 1487 context, num_args == 2, 1488 errors::InvalidArgument( 1489 "Fused Conv2D must have two extra arguments: bias and add.")); 1490 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) { 1491 this->set_fuse_biasadd(true); 1492 this->set_fuse_add(true); 1493 this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 1494 6.0); 1495 OP_REQUIRES( 1496 context, num_args == 2, 1497 errors::InvalidArgument( 1498 "Fused Conv2D must have two extra arguments: bias and add.")); 1499 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) { 1500 this->set_fuse_biasadd(true); 1501 this->set_fuse_add(true); 1502 this->set_fuse_activation(true, 
dnnl::algorithm::eltwise_elu, 1.0); 1503 OP_REQUIRES( 1504 context, num_args == 2, 1505 errors::InvalidArgument( 1506 "Fused Conv2D must have two extra arguments: bias and add.")); 1507 } else if (fused_ops == 1508 std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) { 1509 this->set_fuse_biasadd(true); 1510 this->set_fuse_add(true); 1511 float leakyrelu_alpha; 1512 OP_REQUIRES_OK(context, 1513 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); 1514 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, 1515 leakyrelu_alpha); 1516 OP_REQUIRES( 1517 context, num_args == 2, 1518 errors::InvalidArgument( 1519 "Fused Conv2D must have two extra arguments: bias and add.")); 1520 } else { 1521 OP_REQUIRES(context, false, 1522 errors::Unimplemented("Fusion is not implemented: [", 1523 absl::StrJoin(fused_ops, ","), "]")); 1524 } 1525 1526 if (pad_enabled) { 1527 this->set_fuse_pad(true); 1528 } 1529 } 1530 ComputeBNScale(OpKernelContext * context,float epsilon,int bn_variance_index,Tinput * scale_buf_ptr)1531 void ComputeBNScale(OpKernelContext* context, float epsilon, 1532 int bn_variance_index, Tinput* scale_buf_ptr) override { 1533 const Tensor& bn_var_tensor = MklGetInput(context, bn_variance_index); 1534 1535 Eigen::Tensor<Tinput, 1, Eigen::RowMajor> bn_rsqrt = 1536 (bn_var_tensor.flat<Tinput>() + static_cast<Tinput>(epsilon)).rsqrt(); 1537 Tinput* bn_rsqrt_data = bn_rsqrt.data(); 1538 size_t num_elem = bn_var_tensor.shape().dim_size(0); 1539 for (size_t i = 0; i < num_elem; i++) { 1540 scale_buf_ptr[i] = bn_rsqrt_data[i]; 1541 } 1542 return; 1543 } 1544 ~MklFusedConvOp()1545 virtual ~MklFusedConvOp() {} 1546 }; 1547 1548 template <typename Device, typename Tinput, typename Tfilter, typename Tbias, 1549 typename Toutput, typename Ttemp_output, typename Tpadding, 1550 bool pad_enabled, bool bias_enabled, bool is_depthwise, 1551 bool native_format> 1552 class MklFusedDepthwiseConvOp 1553 : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, 1554 Tpadding, bias_enabled, false, is_depthwise, 1555 native_format> { 1556 public: MklFusedDepthwiseConvOp(OpKernelConstruction * context)1557 explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context) 1558 : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, 1559 Tpadding, bias_enabled, false, is_depthwise, native_format>( 1560 context) { 1561 // Since we came here through the registration of 1562 // _MklFusedDepthwiseConv2dNative, get all 1563 // information from 'fused_ops' and 'num_args' 1564 std::vector<string> fused_ops; 1565 OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops)); 1566 1567 int num_args; 1568 OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); 1569 OP_REQUIRES(context, !fused_ops.empty(), 1570 errors::InvalidArgument( 1571 "Fused DepthwiseConv2D must have at least one fused op.")); 1572 1573 if (fused_ops == std::vector<string>{"BiasAdd"}) { 1574 this->set_fuse_biasadd(true); 1575 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) { 1576 this->set_fuse_biasadd(true); 1577 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); 1578 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) { 1579 this->set_fuse_biasadd(true); 1580 this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 1581 6.0); 1582 } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) { 1583 this->set_fuse_biasadd(true); 1584 this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); 1585 } else { 1586 OP_REQUIRES(context, false, 
1587 errors::Unimplemented("Fusion is not implemented: [", 1588 absl::StrJoin(fused_ops, ","), "]")); 1589 } 1590 1591 OP_REQUIRES( 1592 context, num_args == 1, 1593 errors::InvalidArgument( 1594 "Fused DepthwiseConv2D must have one extra argument: bias.")); 1595 1596 if (pad_enabled) { 1597 this->set_fuse_pad(true); 1598 } 1599 } 1600 ~MklFusedDepthwiseConvOp()1601 virtual ~MklFusedDepthwiseConvOp() {} 1602 }; 1603 1604 // We create new class for each version of Quantized Convolution and inherit 1605 // from the FP32 version of the base class 1606 template <typename Device, typename Tinput, typename Tbias, typename Toutput, 1607 typename Ttemp_output, bool bias_enabled, bool is_depthwise, 1608 bool native_format = false> 1609 class MklQuantizedConv2DOp 1610 : public MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, 1611 int32, bias_enabled, false, is_depthwise, 1612 native_format> { 1613 public: ~MklQuantizedConv2DOp()1614 virtual ~MklQuantizedConv2DOp() { 1615 if (this->input_bias_ != nullptr) { 1616 delete this->input_bias_; 1617 input_bias_ = nullptr; 1618 } 1619 1620 if (this->scaled_bias_ != nullptr) { 1621 delete this->scaled_bias_; 1622 scaled_bias_ = nullptr; 1623 } 1624 } 1625 MklQuantizedConv2DOp(OpKernelConstruction * context)1626 explicit MklQuantizedConv2DOp(OpKernelConstruction* context) 1627 : MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32, 1628 bias_enabled, false, is_depthwise, native_format>(context) { 1629 bool is_filter_const; 1630 OP_REQUIRES_OK(context, 1631 context->GetAttr("is_filter_const", &is_filter_const)); 1632 1633 if (bias_enabled) { 1634 OP_REQUIRES_OK(context, 1635 context->GetAttr("is_bias_const", &is_bias_const_)); 1636 } 1637 1638 OP_REQUIRES(context, is_filter_const, 1639 errors::InvalidArgument("Filter must be a constant")); 1640 } 1641 Compute(OpKernelContext * context)1642 void Compute(OpKernelContext* context) override { 1643 // Compute int32 output tensor 1644 MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32, 1645 bias_enabled, false, is_depthwise, 1646 native_format>::Compute(context); 1647 1648 // Compute additional outputs: min/max scalars. 1649 int bias_index_offset; 1650 bias_index_offset = bias_enabled ? 1 : 0; 1651 1652 const float min_input = 1653 context->input(2 + bias_index_offset).flat<float>()(0); 1654 const float max_input = 1655 context->input(3 + bias_index_offset).flat<float>()(0); 1656 1657 MklDnnShape output_min_mkl_shape, output_max_mkl_shape; 1658 output_min_mkl_shape.SetMklTensor(false); 1659 output_max_mkl_shape.SetMklTensor(false); 1660 1661 Tensor* output_min = nullptr; 1662 Tensor* output_max = nullptr; 1663 if (std::is_same<Toutput, quint8>::value || 1664 std::is_same<Toutput, qint8>::value) { 1665 AllocateOutputSetMklShape(context, 1, &output_min, {}, 1666 output_min_mkl_shape, native_format); 1667 AllocateOutputSetMklShape(context, 2, &output_max, {}, 1668 output_max_mkl_shape, native_format); 1669 // This is the case the convolution and requantization are fused. 
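      // In this fused-requantize path the final output range is supplied
      // directly by the min_freezed_output / max_freezed_output inputs, so
      // those values are simply forwarded to the min/max output tensors.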
      output_min->flat<float>()(0) =
          context->input(6 + bias_index_offset).flat<float>()(0);
      output_max->flat<float>()(0) =
          context->input(7 + bias_index_offset).flat<float>()(0);
    } else {
      const Tensor& min_filter = context->input(4 + bias_index_offset);
      const Tensor& max_filter = context->input(5 + bias_index_offset);
      if (min_filter.dims() == 0) {
        float min_output_value;
        float max_output_value;
        MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
            min_input, max_input, min_filter.flat<float>()(0),
            max_filter.flat<float>()(0), &min_output_value, &max_output_value);
        AllocateOutputSetMklShape(context, 1, &output_min, {},
                                  output_min_mkl_shape, native_format);
        AllocateOutputSetMklShape(context, 2, &output_max, {},
                                  output_max_mkl_shape, native_format);
        output_min->flat<float>()(0) = min_output_value;
        output_max->flat<float>()(0) = max_output_value;
      } else {
        size_t depth = min_filter.NumElements();
        AllocateOutputSetMklShape(context, 1, &output_min,
                                  {static_cast<ptrdiff_t>(depth)},
                                  output_min_mkl_shape, native_format);
        AllocateOutputSetMklShape(context, 2, &output_max,
                                  {static_cast<ptrdiff_t>(depth)},
                                  output_max_mkl_shape, native_format);
        MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
            min_input, max_input, min_filter, max_filter, &output_min,
            &output_max);
      }
    }
  }

 protected:
  void ExtendConvFwdParams(OpKernelContext* context,
                           MklConvFwdParams& params) override {
    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false, is_depthwise,
              native_format>::ExtendConvFwdParams(context, params);

    // When the output type is quint8 or qint8, the output data is requantized
    // into that type. A post_op "output_scale" is added to do the conversion.
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value) {
      int bias_index_offset;
      bias_index_offset = bias_enabled ? 1 : 0;

      const float min_input =
          context->input(2 + bias_index_offset).flat<float>()(0);
      const float max_input =
          context->input(3 + bias_index_offset).flat<float>()(0);
      const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
      const Tensor& max_filter_vector = context->input(5 + bias_index_offset);

      // min_freezed_output and max_freezed_output are the actual range
      // for the output.
      const float min_freezed_output =
          context->input(6 + bias_index_offset).flat<float>()(0);
      const float max_freezed_output =
          context->input(7 + bias_index_offset).flat<float>()(0);

      float int_output_limit =
          std::is_same<Toutput, quint8>::value ? 255.0f : 127.0f;
      size_t depth = min_filter_vector.NumElements();
      const float* min_filter = min_filter_vector.flat<float>().data();
      const float* max_filter = max_filter_vector.flat<float>().data();
      std::vector<float> scales(depth);
      float float_input_range =
          std::max(std::abs(min_input), std::abs(max_input));
      float float_output_range =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
      const float int_const_scale_limit =
          (std::is_same<Tinput, quint8>::value) ?
255.0 * 127.0 : 127.0 * 127.0; 1744 for (size_t i = 0; i < depth; ++i) { 1745 // For simplicity and symmetry, we set filter range to be outer 1746 // bounds of min_filter and max_filter. 1747 float float_filter_range = 1748 std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); 1749 // To understand the scaling, please see mkl_requantize_ops_test. 1750 scales[i] = int_output_limit * float_input_range * float_filter_range / 1751 (int_const_scale_limit * float_output_range); 1752 } 1753 // we are creating a partial key here to use with primitive key caching to 1754 // improve key creation performance. Instead of using actual values we are 1755 // using the pointers for min/max_filter_vector, and this works since the 1756 // filter vector here is a constant. 1757 FactoryKeyCreator param_key; 1758 param_key.AddAsKey<float>(min_input); 1759 param_key.AddAsKey<float>(max_input); 1760 param_key.AddAsKey<float>(min_freezed_output); 1761 param_key.AddAsKey<float>(max_freezed_output); 1762 param_key.AddAsKey<const float*>(min_filter); 1763 param_key.AddAsKey<const float*>(max_filter); 1764 params.post_op_params.push_back( 1765 {"output_scale", dnnl::algorithm::undef, scales, param_key.GetKey()}); 1766 } 1767 } 1768 GetBiasHandle(OpKernelContext * context,std::shared_ptr<ConvFwdPd> & conv_fwd_pd,const Tensor & bias_tensor)1769 Tbias* GetBiasHandle(OpKernelContext* context, 1770 std::shared_ptr<ConvFwdPd>& conv_fwd_pd, 1771 const Tensor& bias_tensor) override { 1772 if (!bias_enabled) { 1773 return nullptr; 1774 } 1775 if (std::is_same<Tbias, qint32>::value) { 1776 return static_cast<Tbias*>( 1777 const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); 1778 } 1779 int bias_index_offset; 1780 bias_index_offset = bias_enabled ? 1 : 0; 1781 1782 const float min_input = 1783 context->input(2 + bias_index_offset).flat<float>()(0); 1784 const float max_input = 1785 context->input(3 + bias_index_offset).flat<float>()(0); 1786 const Tensor& min_filter_vector = context->input(4 + bias_index_offset); 1787 const Tensor& max_filter_vector = context->input(5 + bias_index_offset); 1788 const float* min_filter = min_filter_vector.flat<float>().data(); 1789 const float* max_filter = max_filter_vector.flat<float>().data(); 1790 1791 const float int_const_scale_limit = 1792 (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0; 1793 // Re-scale bias if either of following 2 conditions are met: 1794 // 1. Bias is not const; 1795 // 2. Bias is const, but bias cache is empty (first iteration). 
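    // (The reorder is also redone if the per-channel scales no longer match
    // the cached ones.)
    // The float bias is rescaled per output channel by
    //   scale[i] = int_const_scale_limit /
    //              (max(|min_input|, |max_input|) *
    //               max(|min_filter[i]|, |max_filter[i]|))
    // so that it lines up with the scaled int32 accumulator of the quantized
    // convolution before being handed to the primitive.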
1796 1797 size_t depth = min_filter_vector.NumElements(); 1798 bool scales_are_valid = (depth == scales_.size()); 1799 scales_.resize(depth); 1800 for (size_t i = 0; i < depth; ++i) { 1801 float tmp_scale = 1802 int_const_scale_limit / 1803 (std::max(std::abs(max_input), std::abs(min_input)) * 1804 std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); 1805 if (scales_are_valid && std::abs(tmp_scale - scales_[i]) > 1e-6) { 1806 scales_are_valid = false; 1807 } 1808 scales_[i] = tmp_scale; 1809 } 1810 if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) { 1811 dnnl::primitive_attr bias_attr; 1812 if (depth == 1) { 1813 bias_attr.set_output_scales(0, scales_); 1814 } else { 1815 bias_attr.set_output_scales(1, scales_); 1816 } 1817 1818 auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())}, 1819 MklDnnType<Tbias>(), memory::format_tag::x); 1820 void* bias_buf = static_cast<void*>( 1821 const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); 1822 if (!input_bias_) { 1823 input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf); 1824 } else { 1825 input_bias_->set_data_handle(bias_buf); 1826 } 1827 1828 if (!scaled_bias_buf_) 1829 AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_, 1830 conv_fwd_pd->bias_desc(), &scaled_bias_buf_); 1831 if (!scaled_bias_) { 1832 scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_); 1833 } else { 1834 scaled_bias_->set_data_handle(scaled_bias_buf_); 1835 } 1836 auto reorder_desc = 1837 ReorderPd(this->cpu_engine_, input_bias_->get_desc(), 1838 this->cpu_engine_, scaled_bias_->get_desc(), bias_attr); 1839 CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_, 1840 this->cpu_engine_, context); 1841 1842 Tbias* bias_data = 1843 reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle()); 1844 if (is_bias_const_) 1845 CacheBias(context, conv_fwd_pd, bias_data, scaled_bias_); 1846 1847 return bias_data; 1848 } 1849 return GetCachedBias(context); 1850 } 1851 1852 bool is_bias_const_; 1853 Tensor cached_bias_data_ TF_GUARDED_BY(bias_cache_mu_); 1854 1855 memory* input_bias_ = nullptr; 1856 memory* scaled_bias_ = nullptr; 1857 1858 Tensor scaled_bias_tensor_; 1859 void* scaled_bias_buf_ = nullptr; 1860 1861 private: 1862 std::vector<float> scales_; 1863 mutex bias_cache_mu_; 1864 // Allocate tensors for cached bias data and 1865 // cached bias memory descriptor (data format) AllocateTensor(OpKernelContext * context,const ConvFwdPd & conv_prim_desc,Tensor ** bias_tensor)1866 void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, 1867 Tensor** bias_tensor) { 1868 DCHECK(bias_tensor); 1869 TensorShape bias_tf_shape; 1870 bias_tf_shape.AddDim( 1871 (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias))); 1872 OP_REQUIRES_OK(context, 1873 context->allocate_temp(DataTypeToEnum<Tbias>::value, 1874 bias_tf_shape, &cached_bias_data_)); 1875 *bias_tensor = &cached_bias_data_; 1876 } 1877 1878 // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot 1879 // be acquired before entering the function, since it is acquired 1880 // inside the function. IsBiasCacheEmpty(OpKernelContext * context)1881 inline bool IsBiasCacheEmpty(OpKernelContext* context) 1882 TF_LOCKS_EXCLUDED(bias_cache_mu_) { 1883 tf_shared_lock lock(bias_cache_mu_); 1884 return (cached_bias_data_.NumElements() == 0); 1885 } 1886 1887 // Cache the converted bias in a tensor. 1888 // Only one thread can execute this method at any given time. 
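  // The cached bias is written once under an exclusive mutex_lock; the read
  // paths (IsBiasCacheEmpty, GetCachedBias) only take a tf_shared_lock, so
  // concurrent Compute() calls can reuse the cached bias without serializing.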
CacheBias(OpKernelContext * context,const std::shared_ptr<ConvFwdPd> & conv_fwd_pd,Tbias * bias_data,const memory * scaled_bias)1889 void CacheBias(OpKernelContext* context, 1890 const std::shared_ptr<ConvFwdPd>& conv_fwd_pd, 1891 Tbias* bias_data, const memory* scaled_bias) 1892 TF_LOCKS_EXCLUDED(bias_cache_mu_) { 1893 mutex_lock lock(bias_cache_mu_); 1894 1895 // If bias is already cached, there's nothing to do. 1896 if (cached_bias_data_.NumElements() > 0) { 1897 return; 1898 } 1899 1900 // Otherwise, cache bias 1901 Tensor* bias_tensor_ptr = nullptr; 1902 AllocateTensor(context, *conv_fwd_pd, &bias_tensor_ptr); 1903 void* cached_bias_data = const_cast<void*>( 1904 static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data())); 1905 size_t cached_bias_data_size = scaled_bias->get_desc().get_size(); 1906 memcpy(cached_bias_data, bias_data, cached_bias_data_size); 1907 } 1908 GetCachedBias(OpKernelContext * context)1909 Tbias* GetCachedBias(OpKernelContext* context) 1910 TF_LOCKS_EXCLUDED(bias_cache_mu_) { 1911 tf_shared_lock lock(bias_cache_mu_); 1912 const Tensor& cached_bias_data = cached_bias_data_; 1913 1914 return static_cast<Tbias*>( 1915 const_cast<Tbias*>(cached_bias_data.flat<Tbias>().data())); 1916 } 1917 }; 1918 1919 template <typename Device, typename Tinput, typename Tbias, typename Toutput, 1920 typename Ttemp_output, bool bias_enabled, bool is_depthwise, 1921 bool native_format = false> 1922 class MklQuantizedConv2DReluOp 1923 : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, 1924 bias_enabled, is_depthwise, native_format> { 1925 public: ~MklQuantizedConv2DReluOp()1926 virtual ~MklQuantizedConv2DReluOp() {} 1927 MklQuantizedConv2DReluOp(OpKernelConstruction * context)1928 explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context) 1929 : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, 1930 bias_enabled, is_depthwise, native_format>( 1931 context) {} 1932 1933 protected: ExtendConvFwdParams(OpKernelContext * context,MklConvFwdParams & params)1934 void ExtendConvFwdParams(OpKernelContext* context, 1935 MklConvFwdParams& params) override { 1936 MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, 1937 bias_enabled, is_depthwise, 1938 native_format>::ExtendConvFwdParams(context, params); 1939 1940 params.post_op_params.push_back( 1941 {"activation", dnnl::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""}); 1942 } 1943 }; 1944 1945 template <typename Device, typename Tinput, typename Tbias, typename Toutput, 1946 typename Ttemp_output, bool bias_enabled, bool is_depthwise, 1947 bool native_format = false> 1948 class MklQuantizedConv2DSumReluOp 1949 : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, 1950 bias_enabled, is_depthwise, native_format> { 1951 public: ~MklQuantizedConv2DSumReluOp()1952 virtual ~MklQuantizedConv2DSumReluOp() {} 1953 MklQuantizedConv2DSumReluOp(OpKernelConstruction * context)1954 explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context) 1955 : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, 1956 bias_enabled, is_depthwise, native_format>( 1957 context) {} 1958 1959 protected: ExtendConvFwdParams(OpKernelContext * context,MklConvFwdParams & params)1960 void ExtendConvFwdParams(OpKernelContext* context, 1961 MklConvFwdParams& params) override { 1962 MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, 1963 bias_enabled, is_depthwise, 1964 native_format>::ExtendConvFwdParams(context, params); 1965 // Calculate the scale (beta in 
oneDNN API terms) for the sum post-op.
    if (std::is_same<Toutput, quint8>::value) {
      int summand_idx = native_format ? context->num_inputs() - 1 - 2
                                      : context->num_inputs() / 2 - 1 - 2;
      DataType summand_type = this->input_type(summand_idx);
      bool summand_condition =
          (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
      CHECK((summand_condition));
      int bias_index_offset = bias_enabled ? 1 : 0;
      const float min_freezed_output =
          context->input(6 + bias_index_offset).flat<float>()(0);
      const float max_freezed_output =
          context->input(7 + bias_index_offset).flat<float>()(0);
      const float min_freezed_summand =
          context->input(9 + bias_index_offset).flat<float>()(0);
      const float max_freezed_summand =
          context->input(10 + bias_index_offset).flat<float>()(0);

      float scale_output =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
      float scale_summand = std::max(std::abs(min_freezed_summand),
                                     std::abs(max_freezed_summand));
      // If the summand is also DT_QUINT8, its 255.0f scaling factor cancels
      // against that of the output, so only the ratio of the float ranges is
      // needed. Otherwise the summand is DT_QINT8 and has to be rescaled from
      // the signed [-127, 127] range to the unsigned [0, 255] output range.
      if (summand_type == DT_QUINT8) {
        params.post_op_params.push_back({"sum",
                                         dnnl::algorithm::undef,
                                         {scale_summand / scale_output},
                                         ""});
      } else {
        params.post_op_params.push_back(
            {"sum",
             dnnl::algorithm::undef,
             {255.0f * scale_summand / (scale_output * 127.0f)},
             ""});
      }
    } else {
      params.post_op_params.push_back(
          {"sum", dnnl::algorithm::undef, {1.0}, ""});
    }
    params.post_op_params.push_back(
        {"activation", dnnl::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
  }

  void AllocateOutputTensor(OpKernelContext* context,
                            const ConvFwdPd& conv_prim_desc,
                            const memory::dims& output_dims_mkl_order,
                            MklTensorFormat output_tf_format,
                            MklDnnShape* output_mkl_shape,
                            Tensor** output_tensor) override {
    int summand_idx = native_format ? context->num_inputs() - 1
                                    : context->num_inputs() / 2 - 1;
    if (std::is_same<Toutput, quint8>::value) {
      summand_idx -= 2;
      DataType summand_type = this->input_type(summand_idx);
      bool summand_condition =
          (summand_type == DT_QINT8) || (summand_type == DT_QUINT8);
      CHECK((summand_condition));
      Tensor& summand = const_cast<Tensor&>(MklGetInput(context, summand_idx));
      MklDnnShape summand_mkl_shape;
      GetMklShape(context, summand_idx, &summand_mkl_shape, native_format);
      auto dst_md = summand_mkl_shape.GetMklLayout();

      // TODO(intel-tf): Handle both non-MKL and MKL tensors
      if (summand_type == DT_QINT8) {
        OP_REQUIRES_OK(
            context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape()));
        dst_md.data.data_type =
            static_cast<dnnl_data_type_t>(MklDnnType<Toutput>());
        summand_mkl_shape.SetMklLayout(&dst_md);
        summand_mkl_shape.SetElemType(MklDnnType<Toutput>());
      }
      // TODO(intel-tf): Support cases when summand cannot be forwarded.
      OP_REQUIRES(context,
                  native_format
                      ? context->forward_input_to_output_with_shape(
                            summand_idx, 0, summand.shape(), output_tensor)
                      : ForwardMklTensorInToOutWithMklShape(
                            context, summand_idx, 0, output_tensor,
                            summand_mkl_shape, false),
                  errors::InvalidArgument(
                      "Summand cannot be forwarded in the current fusion."));
      return;
    }
    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
              bias_enabled, false, false,
              native_format>::AllocateOutputTensor(context, conv_prim_desc,
                                                   output_dims_mkl_order,
                                                   output_tf_format,
                                                   output_mkl_shape,
                                                   output_tensor);
    const Tensor& summand = MklGetInput(context, summand_idx);
    if (summand.dtype() != DT_FLOAT)
      TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
                         "Current fusion requires summand to be float"));
    MklDnnShape summand_mkl_shape;
    GetMklShape(context, summand_idx, &summand_mkl_shape, native_format);
    // Compute the scale for the summand reorder below.
    int bias_index_offset = bias_enabled ? 1 : 0;
    const float min_input =
        context->input(2 + bias_index_offset).flat<float>()(0);
    const float max_input =
        context->input(3 + bias_index_offset).flat<float>()(0);
    const Tensor& min_filter_vector = context->input(4 + bias_index_offset);
    const Tensor& max_filter_vector = context->input(5 + bias_index_offset);
    const float* min_filter = min_filter_vector.flat<float>().data();
    const float* max_filter = max_filter_vector.flat<float>().data();

    const float int_const_scale_limit =
        (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
    size_t depth = min_filter_vector.NumElements();
    std::vector<float> scales(depth);
    for (size_t i = 0; i < depth; ++i) {
      // TODO(intel-tf): The scale factors for the UINT8 (input) x INT8
      // (weights) case are computed inline here. A cleaner design would
      // handle this and other quantized type mappings in a single helper
      // function.
      scales[i] = int_const_scale_limit /
                  (std::max(std::abs(max_input), std::abs(min_input)) *
                   std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
    }
    dnnl::primitive_attr reorder_attr;
    if (depth == 1) {
      reorder_attr.set_output_scales(0, scales);
    } else {
      reorder_attr.set_output_scales(2, scales);
    }
    auto summand_md =
        summand_mkl_shape.IsMklTensor()
            ?
summand_mkl_shape.GetMklLayout() 2096 : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(), 2097 memory::format_tag::nhwc); 2098 void* summand_buf = 2099 static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data())); 2100 void* dst_buf = 2101 static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data()); 2102 summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf)); 2103 dst_.reset( 2104 new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf)); 2105 auto reorder_desc = 2106 ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_, 2107 conv_prim_desc.dst_desc(), reorder_attr); 2108 CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_, 2109 context); 2110 } 2111 2112 std::shared_ptr<dnnl::memory> summand_; 2113 std::shared_ptr<dnnl::memory> dst_; 2114 }; 2115 2116 // Base class for fused convolution forward operations 2117 template <typename Device, typename Tinput, typename Tfilter, typename Tbias, 2118 typename Toutput, typename Ttemp_output, typename Tpadding, 2119 bool pad_enabled, bool native_format> 2120 class MklFusedConv3DOp 2121 : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, 2122 Tpadding, false, false, false, native_format> { 2123 public: MklFusedConv3DOp(OpKernelConstruction * context)2124 explicit MklFusedConv3DOp(OpKernelConstruction* context) 2125 : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, 2126 Tpadding, false, false, false, native_format>(context) { 2127 // Since we came here through the registration of _MklFusedConv3D, get 2128 // all information from 'fused_ops' and 'num_args' 2129 std::vector<string> fused_ops; 2130 OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops)); 2131 2132 int num_args; 2133 OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); 2134 2135 std::vector<int> padding_list; 2136 OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list)); 2137 if (padding_list.empty()) { 2138 OP_REQUIRES(context, !fused_ops.empty(), 2139 errors::InvalidArgument("Fused Conv3D must have at least one " 2140 "fused op when Pad is not fused.")); 2141 if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == 2142 fused_ops.end()) { 2143 OP_REQUIRES(context, num_args == 1, 2144 errors::InvalidArgument( 2145 "Fused Conv3D must have one extra argument: bias.")); 2146 } else if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") == 2147 fused_ops.end() && 2148 std::find(fused_ops.begin(), fused_ops.end(), "Add") == 2149 fused_ops.end()) { 2150 OP_REQUIRES( 2151 context, num_args == 2, 2152 errors::InvalidArgument( 2153 "Fused Conv3D must have two extra arguments: bias and add.")); 2154 } 2155 } 2156 2157 if (fused_ops == std::vector<string>{"BiasAdd"}) { 2158 this->set_fuse_biasadd(true); 2159 } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) { 2160 this->set_fuse_biasadd(true); 2161 float leakyrelu_alpha; 2162 OP_REQUIRES_OK(context, 2163 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); 2164 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, 2165 leakyrelu_alpha); 2166 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) { 2167 this->set_fuse_biasadd(true); 2168 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); 2169 } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) { 2170 this->set_fuse_biasadd(true); 2171 this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 2172 6.0); 2173 } else if (fused_ops == 
std::vector<string>{"BiasAdd", "Elu"}) { 2174 this->set_fuse_biasadd(true); 2175 this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); 2176 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) { 2177 this->set_fuse_biasadd(true); 2178 this->set_fuse_add(true); 2179 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) { 2180 this->set_fuse_biasadd(true); 2181 this->set_fuse_add(true); 2182 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); 2183 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) { 2184 this->set_fuse_biasadd(true); 2185 this->set_fuse_add(true); 2186 this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, 2187 6.0); 2188 } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) { 2189 this->set_fuse_biasadd(true); 2190 this->set_fuse_add(true); 2191 this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); 2192 } else if (fused_ops == 2193 std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) { 2194 this->set_fuse_biasadd(true); 2195 this->set_fuse_add(true); 2196 float leakyrelu_alpha; 2197 OP_REQUIRES_OK(context, 2198 context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); 2199 this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, 2200 leakyrelu_alpha); 2201 } else { 2202 if (padding_list.empty()) { 2203 OP_REQUIRES(context, false, 2204 errors::Unimplemented("Fusion is not implemented: [", 2205 absl::StrJoin(fused_ops, ","), "]")); 2206 } 2207 } 2208 } 2209 ~MklFusedConv3DOp()2210 virtual ~MklFusedConv3DOp() {} 2211 }; 2212 2213 #define REGISTER_MKL_KERNEL(op, kernel, input_type, bias_type, output_type, \ 2214 accu_type, has_bias, is_depthwise, is_native) \ 2215 REGISTER_KERNEL_BUILDER( \ 2216 Name(op) \ 2217 .Device(DEVICE_CPU) \ 2218 .TypeConstraint<input_type>("Tinput") \ 2219 .TypeConstraint<qint8>("Tfilter") BIAS_TYPE_CONSTRAINT(bias_type) \ 2220 .TypeConstraint<output_type>("out_type") LABEL, \ 2221 kernel TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \ 2222 accu_type, has_bias, is_depthwise, is_native)); 2223 2224 #define REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, bias_type, \ 2225 output_type, accu_type, has_bias, \ 2226 is_depthwise, is_native) \ 2227 REGISTER_MKL_KERNEL(op, kernel, qint8, bias_type, output_type, accu_type, \ 2228 has_bias, is_depthwise, is_native); \ 2229 REGISTER_MKL_KERNEL(op, kernel, quint8, bias_type, output_type, accu_type, \ 2230 has_bias, is_depthwise, is_native); 2231 2232 #define REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(op, kernel, input_type, \ 2233 output_type, accu_type, has_bias, \ 2234 is_depthwise, is_native) \ 2235 REGISTER_MKL_KERNEL(op, kernel, input_type, qint32, output_type, accu_type, \ 2236 has_bias, is_depthwise, is_native); \ 2237 REGISTER_MKL_KERNEL(op, kernel, input_type, float, output_type, accu_type, \ 2238 has_bias, is_depthwise, is_native); 2239 2240 #define REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( \ 2241 op, kernel, output_type, accu_type, has_bias, is_depthwise, is_native) \ 2242 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, qint32, output_type, \ 2243 accu_type, has_bias, is_depthwise, \ 2244 is_native); \ 2245 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, float, output_type, \ 2246 accu_type, has_bias, is_depthwise, \ 2247 is_native); 2248 2249 #define LABEL 2250 #define TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \ 2251 accu_type, has_bias, is_depthwise, is_native) 2252 #define BIAS_TYPE_CONSTRAINT(bias_type) 2253 2254 
REGISTER_MKL_KERNEL("QuantizedConv2D", NoOp, quint8, float, qint32, qint32, 2255 false, false, false); 2256 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("QuantizedConv2DWithBias", NoOp, float, 2257 qint32, qint32, false, false, false); 2258 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("QuantizedConv2DWithBiasAndRelu", NoOp, 2259 float, qint32, qint32, false, false, false); 2260 REGISTER_MKL_KERNEL("QuantizedConv2DWithBiasSumAndRelu", NoOp, quint8, float, 2261 qint32, qint32, false, false, false); 2262 REGISTER_MKL_KERNEL("QuantizedConv2DAndRequantize", NoOp, quint8, float, qint8, 2263 qint8, false, false, false); 2264 REGISTER_MKL_KERNEL("QuantizedConv2DPerChannel", NoOp, quint8, float, qint32, 2265 qint32, false, false, false); 2266 REGISTER_MKL_KERNEL("QuantizedConv2DAndRelu", NoOp, quint8, float, qint32, 2267 qint32, false, false, false); 2268 REGISTER_MKL_KERNEL("QuantizedConv2DAndReluAndRequantize", NoOp, quint8, float, 2269 quint8, quint8, false, false, false); 2270 REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2D", NoOp, quint8, float, qint32, 2271 qint32, false, false, false); 2272 REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2DWithBias", NoOp, quint8, float, 2273 qint32, qint32, false, false, false); 2274 REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2DWithBiasAndRelu", NoOp, quint8, 2275 float, qint32, qint32, false, false, false); 2276 #undef BIAS_TYPE_CONSTRAINT 2277 2278 #define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias") 2279 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( 2280 "QuantizedConv2DWithBiasAndRequantize", NoOp, qint8, qint8, false, false, 2281 false); 2282 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( 2283 "QuantizedConv2DWithBiasAndReluAndRequantize", NoOp, quint8, quint8, false, 2284 false, false); 2285 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( 2286 "QuantizedConv2DWithBiasSumAndReluAndRequantize", NoOp, quint8, quint8, 2287 quint8, false, false, false); 2288 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( 2289 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", NoOp, quint8, 2290 quint8, qint8, false, false, false); 2291 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( 2292 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", NoOp, quint8, 2293 quint8, quint8, false, false, false); 2294 #undef BIAS_TYPE_CONSTRAINT 2295 #undef TEMPLATE_ARGS 2296 #undef LABEL 2297 2298 #define LABEL .Label(mkl_op_registry::kMklQuantizedOpLabel) 2299 #define TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \ 2300 accu_type, has_bias, is_depthwise, is_native) \ 2301 <CPUDevice, input_type, bias_type, output_type, accu_type, has_bias, \ 2302 is_depthwise, is_native> 2303 #define BIAS_TYPE_CONSTRAINT(bias_type) 2304 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2D", MklQuantizedConv2DOp, 2305 float, qint32, qint32, false, false, true); 2306 REGISTER_MKL_KERNEL("_MklQuantizedConv2DPerChannel", MklQuantizedConv2DOp, 2307 quint8, float, qint32, qint32, false, false, true); 2308 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2DWithBias", 2309 MklQuantizedConv2DOp, float, qint32, qint32, 2310 true, false, true); 2311 REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2DWithBiasAndRelu", 2312 MklQuantizedConv2DReluOp, float, qint32, 2313 qint32, true, false, true); 2314 REGISTER_MKL_KERNEL("_MklQuantizedConv2DWithBiasSumAndRelu", 2315 MklQuantizedConv2DSumReluOp, quint8, float, qint32, qint32, 2316 true, false, true); 2317 REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndRequantize", MklQuantizedConv2DOp, 2318 quint8, float, qint8, qint8, false, false, true); 2319 
REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndRelu", MklQuantizedConv2DReluOp, 2320 quint8, float, qint32, qint32, false, false, true); 2321 REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndReluAndRequantize", 2322 MklQuantizedConv2DReluOp, quint8, float, quint8, quint8, 2323 false, false, true); 2324 REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2D", MklQuantizedConv2DOp, 2325 quint8, float, qint32, qint32, false, true, true); 2326 REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2DWithBias", 2327 MklQuantizedConv2DOp, quint8, float, qint32, qint32, true, 2328 true, true); 2329 REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2DWithBiasAndRelu", 2330 MklQuantizedConv2DReluOp, quint8, float, qint32, qint32, 2331 true, true, true); 2332 #undef BIAS_TYPE_CONSTRAINT 2333 2334 #define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias") 2335 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( 2336 "_MklQuantizedConv2DWithBiasAndRequantize", MklQuantizedConv2DOp, qint8, 2337 qint8, true, false, true); 2338 REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( 2339 "_MklQuantizedConv2DWithBiasAndReluAndRequantize", MklQuantizedConv2DReluOp, 2340 quint8, quint8, true, false, true); 2341 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( 2342 "_MklQuantizedConv2DWithBiasSumAndReluAndRequantize", 2343 MklQuantizedConv2DSumReluOp, quint8, quint8, quint8, true, false, true); 2344 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( 2345 "_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize", 2346 MklQuantizedConv2DSumReluOp, quint8, quint8, qint8, true, false, true); 2347 REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( 2348 "_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", 2349 MklQuantizedConv2DReluOp, quint8, quint8, quint8, true, true, true); 2350 #undef BIAS_TYPE_CONSTRAINT 2351 #undef TEMPLATE_ARGS 2352 #undef LABEL 2353 2354 // Register NoOp kernel for ops that will be rewritten to the _Mkl* version 2355 2356 #define REGISTER_NO_OP_CPU_2D_DEPTHWISE(T) \ 2357 REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative") \ 2358 .Device(DEVICE_CPU) \ 2359 .TypeConstraint<T>("T"), \ 2360 NoOp); 2361 2362 TF_CALL_float(REGISTER_NO_OP_CPU_2D_DEPTHWISE); 2363 TF_CALL_bfloat16(REGISTER_NO_OP_CPU_2D_DEPTHWISE); 2364 2365 // Register 2D operations 2366 #define REGISTER_MKL_CPU_2D(T) \ 2367 REGISTER_KERNEL_BUILDER( \ 2368 Name("_MklConv2D") \ 2369 .Device(DEVICE_CPU) \ 2370 .TypeConstraint<T>("T") \ 2371 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2372 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \ 2373 REGISTER_KERNEL_BUILDER( \ 2374 Name("_MklConv2DWithBias") \ 2375 .Device(DEVICE_CPU) \ 2376 .TypeConstraint<T>("T") \ 2377 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2378 MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>); \ 2379 REGISTER_KERNEL_BUILDER( \ 2380 Name("__MklDummyConv2DWithBias") \ 2381 .Device(DEVICE_CPU) \ 2382 .TypeConstraint<T>("T") \ 2383 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2384 MklDummyOp<CPUDevice, T>); \ 2385 REGISTER_KERNEL_BUILDER( \ 2386 Name("_MklPadWithConv2D") \ 2387 .Device(DEVICE_CPU) \ 2388 .TypeConstraint<T>("T") \ 2389 .TypeConstraint<int32>("Tpaddings") \ 2390 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2391 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>); \ 2392 REGISTER_KERNEL_BUILDER( \ 2393 Name("_MklPadWithConv2D") \ 2394 .Device(DEVICE_CPU) \ 2395 .TypeConstraint<T>("T") \ 2396 .TypeConstraint<int64_t>("Tpaddings") \ 2397 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2398 
MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>); \ 2399 REGISTER_KERNEL_BUILDER( \ 2400 Name("__MklDummyPadWithConv2D") \ 2401 .Device(DEVICE_CPU) \ 2402 .TypeConstraint<T>("T") \ 2403 .TypeConstraint<int32>("Tpaddings") \ 2404 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2405 MklDummyOp<CPUDevice, T>); \ 2406 REGISTER_KERNEL_BUILDER( \ 2407 Name("_MklNativeConv2D") \ 2408 .Device(DEVICE_CPU) \ 2409 .TypeConstraint<T>("T") \ 2410 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2411 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>); \ 2412 REGISTER_KERNEL_BUILDER( \ 2413 Name("_MklNativeConv2DWithBias") \ 2414 .Device(DEVICE_CPU) \ 2415 .TypeConstraint<T>("T") \ 2416 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2417 MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, true>); \ 2418 REGISTER_KERNEL_BUILDER( \ 2419 Name("_MklNativePadWithConv2D") \ 2420 .Device(DEVICE_CPU) \ 2421 .TypeConstraint<T>("T") \ 2422 .TypeConstraint<int32>("Tpaddings") \ 2423 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2424 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, true>); \ 2425 REGISTER_KERNEL_BUILDER( \ 2426 Name("_MklNativePadWithConv2D") \ 2427 .Device(DEVICE_CPU) \ 2428 .TypeConstraint<T>("T") \ 2429 .TypeConstraint<int64_t>("Tpaddings") \ 2430 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2431 MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, true>); 2432 2433 TF_CALL_float(REGISTER_MKL_CPU_2D); 2434 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D); 2435 2436 #define REGISTER_MKL_CPU_2D_DEPTHWISE(T) \ 2437 REGISTER_KERNEL_BUILDER( \ 2438 Name("_MklDepthwiseConv2dNative") \ 2439 .Device(DEVICE_CPU) \ 2440 .TypeConstraint<T>("T") \ 2441 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2442 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>); \ 2443 REGISTER_KERNEL_BUILDER( \ 2444 Name("_MklFusedDepthwiseConv2dNative") \ 2445 .Device(DEVICE_CPU) \ 2446 .TypeConstraint<T>("T") \ 2447 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2448 MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true, \ 2449 true, false>); \ 2450 REGISTER_KERNEL_BUILDER( \ 2451 Name("_MklNativeFusedDepthwiseConv2dNative") \ 2452 .Device(DEVICE_CPU) \ 2453 .TypeConstraint<T>("T") \ 2454 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2455 MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true, \ 2456 true, true>); \ 2457 REGISTER_KERNEL_BUILDER( \ 2458 Name("_MklNativeDepthwiseConv2dNative") \ 2459 .Device(DEVICE_CPU) \ 2460 .TypeConstraint<T>("T") \ 2461 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2462 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, true>); 2463 2464 TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE); 2465 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE); 2466 2467 // Note we are registering _MklFusedConv2D. 2468 // We check the fused_ops attributes to decide if bias is enabled or not. 
2469 #define REGISTER_MKL_CPU_2D_FUSED(T) \ 2470 REGISTER_KERNEL_BUILDER( \ 2471 Name("_MklFusedConv2D") \ 2472 .Device(DEVICE_CPU) \ 2473 .TypeConstraint<T>("T") \ 2474 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2475 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, false>); \ 2476 REGISTER_KERNEL_BUILDER( \ 2477 Name("_MklPadWithFusedConv2D") \ 2478 .Device(DEVICE_CPU) \ 2479 .TypeConstraint<int32>("Tpaddings") \ 2480 .TypeConstraint<T>("T") \ 2481 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2482 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, false>); \ 2483 REGISTER_KERNEL_BUILDER( \ 2484 Name("_MklPadWithFusedConv2D") \ 2485 .Device(DEVICE_CPU) \ 2486 .TypeConstraint<T>("T") \ 2487 .TypeConstraint<int64_t>("Tpaddings") \ 2488 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2489 MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, false>); \ 2490 REGISTER_KERNEL_BUILDER( \ 2491 Name("__MklDummyPadWithFusedConv2D") \ 2492 .Device(DEVICE_CPU) \ 2493 .TypeConstraint<T>("T") \ 2494 .TypeConstraint<int32>("Tpaddings") \ 2495 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2496 MklDummyOp<CPUDevice, T>); \ 2497 REGISTER_KERNEL_BUILDER( \ 2498 Name("_MklNativeFusedConv2D") \ 2499 .Device(DEVICE_CPU) \ 2500 .TypeConstraint<T>("T") \ 2501 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2502 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, true>); \ 2503 REGISTER_KERNEL_BUILDER( \ 2504 Name("_MklNativePadWithFusedConv2D") \ 2505 .Device(DEVICE_CPU) \ 2506 .TypeConstraint<int32>("Tpaddings") \ 2507 .TypeConstraint<T>("T") \ 2508 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2509 MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, true>); \ 2510 REGISTER_KERNEL_BUILDER( \ 2511 Name("_MklNativePadWithFusedConv2D") \ 2512 .Device(DEVICE_CPU) \ 2513 .TypeConstraint<T>("T") \ 2514 .TypeConstraint<int64_t>("Tpaddings") \ 2515 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2516 MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, true>); 2517 2518 TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED); 2519 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED); 2520 2521 // Register 3D operations 2522 #define REGISTER_MKL_CPU_3D(T) \ 2523 REGISTER_KERNEL_BUILDER( \ 2524 Name("_MklConv3D") \ 2525 .Device(DEVICE_CPU) \ 2526 .TypeConstraint<T>("T") \ 2527 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 2528 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \ 2529 REGISTER_KERNEL_BUILDER( \ 2530 Name("_MklNativeConv3D") \ 2531 .Device(DEVICE_CPU) \ 2532 .TypeConstraint<T>("T") \ 2533 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2534 MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>); \ 2535 REGISTER_KERNEL_BUILDER( \ 2536 Name("_MklNativeFusedConv3D") \ 2537 .Device(DEVICE_CPU) \ 2538 .TypeConstraint<T>("T") \ 2539 .Label(mkl_op_registry::kMklNameChangeOpLabel), \ 2540 MklFusedConv3DOp<CPUDevice, T, T, T, T, T, int32, false, true>); 2541 TF_CALL_float(REGISTER_MKL_CPU_3D); 2542 TF_CALL_bfloat16(REGISTER_MKL_CPU_3D); 2543 2544 REGISTER_KERNEL_BUILDER( 2545 Name("_FusedConv3D").Device(DEVICE_CPU).TypeConstraint<float>("T"), NoOp); 2546 REGISTER_KERNEL_BUILDER( 2547 Name("_FusedConv3D").Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), 2548 NoOp); 2549 } // namespace tensorflow 2550 #endif // INTEL_MKL 2551
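// ---------------------------------------------------------------------------
// Illustrative sketch (not used by any kernel above): a self-contained,
// standard-library-only helper showing how the per-channel "output_scale"
// post-op values for the fused requantization path are derived from the
// input, filter, and frozen output ranges. The namespace and function names
// below are arbitrary and exist only for exposition.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

namespace mkl_conv_ops_example {

// Mirrors the loop in MklQuantizedConv2DOp::ExtendConvFwdParams:
//   scale[i] = int_output_limit * |input_range| * |filter_range[i]| /
//              (int_const_scale_limit * |output_range|)
inline std::vector<float> ExampleRequantizationScales(
    float min_input, float max_input, const std::vector<float>& min_filter,
    const std::vector<float>& max_filter, float min_freezed_output,
    float max_freezed_output, bool unsigned_input, bool unsigned_output) {
  const float int_output_limit = unsigned_output ? 255.0f : 127.0f;
  const float int_const_scale_limit =
      unsigned_input ? 255.0f * 127.0f : 127.0f * 127.0f;
  const float input_range = std::max(std::abs(min_input), std::abs(max_input));
  const float output_range =
      std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
  std::vector<float> scales(min_filter.size());
  for (std::size_t i = 0; i < scales.size(); ++i) {
    const float filter_range =
        std::max(std::abs(min_filter[i]), std::abs(max_filter[i]));
    scales[i] = int_output_limit * input_range * filter_range /
                (int_const_scale_limit * output_range);
  }
  return scales;
}

}  // namespace mkl_conv_ops_example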