/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

#define GET_FLAG(bn_flag) static_cast<int>(dnnl::normalization_flags::bn_flag)
#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))

using dnnl::batch_normalization_backward;
using dnnl::batch_normalization_forward;
using dnnl::prop_kind;
using dnnl::stream;

using BatchNormFwdPd = dnnl::batch_normalization_forward::primitive_desc;
using BatchNormBwdPd = dnnl::batch_normalization_backward::primitive_desc;

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;

using FusedBNActivationMode = functor::FusedBatchNormActivationMode;

struct MklBatchNormFwdParams {
  memory::dims src_dims;
  int depth;
  float eps;
  bool training;
  TensorFormat data_format;
  FusedBNActivationMode activation_mode;
  memory::desc src_md;

  MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
                        bool training, TensorFormat data_format,
                        memory::desc src_md,
                        FusedBNActivationMode activation_mode)
      : src_dims(src_dims),
        depth(depth),
        eps(eps),
        training(training),
        data_format(data_format),
        activation_mode(activation_mode),
        src_md(src_md) {}
};

template <typename T, typename U>
class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bn_fwd == nullptr) Setup(fwdParams);
  }

  ~MklFusedBatchNormFwdPrimitive() {}

  // BatchNormalization forward execute
  //   src_data:      input data buffer of src
  //   weights_data:  input data buffer of weights
  //   dst_data:      output data buffer of dst
  //   mean_data:     output data buffer of means
  //   variance_data: output data buffer of variances
  void Execute(const T* src_data, const U* weights_data, T* dst_data,
               U* mean_data, U* variance_data,
               std::shared_ptr<stream> fwd_stream, U* workspace_data) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    // TODO(intel-tf): Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
                                      *fwd_stream);

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
                                         *fwd_stream);
      context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
                                             *fwd_stream);
    }
    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
    }
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)));

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
      context_.variance_mem->set_data_handle(
          static_cast<void*>(variance_data));
    }
    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(workspace_data);
    }
#endif  // !ENABLE_ONEDNN_OPENMP

    // Execute batch-normalization forward primitives.
    execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);

    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);

    if (IS_SET(use_scale_shift))
      context_.weights_mem->set_data_handle(DummyData);

    if ((context_.pkind == prop_kind::forward_training) ||
        (IS_SET(use_global_stats))) {
      context_.mean_mem->set_data_handle(DummyData);
      context_.variance_mem->set_data_handle(DummyData);
    }

    if (workspace_data != nullptr) {
      context_.ws_mem->set_data_handle(DummyData);
    }
  }

  memory::desc GetDstPd() const { return context_.dst_mem->get_desc(); }

  std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
    return context_.fwd_pd;
  }

 private:
  // Primitive reuse context for BatchNorm forward op.
  struct BatchNormFwdContext {
    // Flags indicating if it is training or inference mode.
    int64 flags;

    // Algorithm kind.
    dnnl::prop_kind pkind;

    // Inputs/outputs memory.
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> weights_mem;
    std::shared_ptr<dnnl::memory> dst_mem;
    std::shared_ptr<dnnl::memory> mean_mem;
    std::shared_ptr<dnnl::memory> variance_mem;
    std::shared_ptr<dnnl::memory> ws_mem;

    // Forward BatchNorm primitive descriptor.
    std::shared_ptr<BatchNormFwdPd> fwd_pd;

    // BatchNorm forward primitive.
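    // The primitive is created once in Setup() and cached by the factory;
    // Execute() only swaps data handles before running it and resets them to
    // DummyData afterwards, so no user buffer is retained between calls.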
    std::shared_ptr<dnnl::primitive> bn_fwd;
    std::vector<dnnl::primitive> fwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    BatchNormFwdContext()
        : flags(0),
          pkind(prop_kind::forward_training),
          src_mem(nullptr),
          weights_mem(nullptr),
          dst_mem(nullptr),
          mean_mem(nullptr),
          variance_mem(nullptr),
          ws_mem(nullptr),
          bn_fwd(nullptr) {}
  };

  void Setup(const MklBatchNormFwdParams& fwdParams) {
    context_.flags =
        fwdParams.training
            ? GET_FLAG(use_scale_shift)
            : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
    context_.pkind = fwdParams.training ? prop_kind::forward_training
                                        : prop_kind::forward_scoring;

    if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
      context_.flags |= GET_FLAG(fuse_norm_relu);
    }
    // Memory descriptor
    auto src_md = fwdParams.src_md;
    // Create forward BatchNorm descriptor and primitive descriptor.
    auto fwd_desc = batch_normalization_forward::desc(
        context_.pkind, src_md, fwdParams.eps,
        static_cast<dnnl::normalization_flags>(context_.flags));

    context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));

    // Create memory primitive based on dummy data
    context_.src_mem.reset(
        new memory(context_.fwd_pd->src_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(context_.fwd_pd->dst_desc(), cpu_engine_, DummyData));

    memory::dims s_dims = {2, fwdParams.depth};
    memory::dims m_dims = {1, fwdParams.depth};
    if (IS_SET(use_scale_shift)) {
      context_.weights_mem.reset(
          new memory({{s_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));
    }

    if (fwdParams.training || (IS_SET(use_global_stats))) {
      context_.mean_mem.reset(
          new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));

      context_.variance_mem.reset(
          new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
                     cpu_engine_, DummyData));
    }

    if (IS_SET(fuse_norm_relu)) {
      context_.ws_mem.reset(new memory(context_.fwd_pd->workspace_desc(),
                                       cpu_engine_, DummyData));
    }

    // BatchNorm forward primitive.
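    // The argument map passed to the primitive depends on the mode:
    //   * plain inference (no global stats): SRC, optional WEIGHTS, DST;
    //   * inference with use_global_stats: mean/variance are passed as inputs;
    //   * training: mean/variance are produced as outputs.
    // A WORKSPACE argument is added only when ReLU is fused.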
    // TODO(intel-tf): Merge all the #ifdefs and simplify code
    if (!fwdParams.training && !(IS_SET(use_global_stats))) {
      if (IS_SET(use_scale_shift)) {
        context_.net_args.push_back({{DNNL_ARG_SRC, *context_.src_mem},
                                     {DNNL_ARG_WEIGHTS, *context_.weights_mem},
                                     {DNNL_ARG_DST, *context_.dst_mem}});
      } else {
        context_.net_args.push_back({{DNNL_ARG_SRC, *context_.src_mem},
                                     {DNNL_ARG_DST, *context_.dst_mem}});
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    } else if (IS_SET(use_global_stats)) {
      if (IS_SET(use_scale_shift)) {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem},
               {DNNL_ARG_WEIGHTS, *context_.weights_mem},
               {DNNL_ARG_DST, *context_.dst_mem},
               {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem},
               {DNNL_ARG_WEIGHTS, *context_.weights_mem},
               {DNNL_ARG_DST, *context_.dst_mem}});
        }
      } else {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem},
               {DNNL_ARG_DST, *context_.dst_mem},
               {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem},
               {DNNL_ARG_DST, *context_.dst_mem}});
        }
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    } else {
      if (IS_SET(use_scale_shift)) {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_WEIGHTS, *context_.weights_mem},
               {DNNL_ARG_DST, *context_.dst_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem},
               {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_WEIGHTS, *context_.weights_mem},
               {DNNL_ARG_DST, *context_.dst_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem}});
        }
      } else {
        if (IS_SET(fuse_norm_relu)) {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_DST, *context_.dst_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem},
               {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
        } else {
          context_.net_args.push_back(
              {{DNNL_ARG_SRC, *context_.src_mem},
               {DNNL_ARG_DST, *context_.dst_mem},
               {DNNL_ARG_MEAN, *context_.mean_mem},
               {DNNL_ARG_VARIANCE, *context_.variance_mem}});
        }
      }
      context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
    }

    context_.fwd_primitives.push_back(*context_.bn_fwd);
  }

  struct BatchNormFwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T, typename U>
class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklFusedBatchNormFwdPrimitive<T, U>* Get(
      const MklBatchNormFwdParams& fwdParams) {
    auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T, U>*>(
        MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance()
            .GetBatchNormFwd(fwdParams));

    if (bn_fwd == nullptr) {
      bn_fwd = new MklFusedBatchNormFwdPrimitive<T, U>(fwdParams);
      MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormFwd(
          fwdParams, bn_fwd);
    }
    return bn_fwd;
  }

  static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
    static MklFusedBatchNormFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklFusedBatchNormFwdPrimitiveFactory() {}
  ~MklFusedBatchNormFwdPrimitiveFactory() {}

  static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
    string prefix = "bn_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey<int>(fwdParams.depth);
    key_creator.AddAsKey<float>(fwdParams.eps);
    key_creator.AddAsKey<bool>(fwdParams.training);
    key_creator.AddAsKey<TensorFormat>(fwdParams.data_format);
    key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
    key_creator.AddAsKey(typeid(T).name());
    key_creator.AddAsKey(typeid(U).name());
    return key_creator.GetKey();
  }

  MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
                       MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};

struct MklBatchNormBwdParams {
  memory::dims src_dims;
  memory::dims diff_dst_dims;
  int depth;
  float eps;
  bool training;
  TensorFormat data_format;
  memory::desc src_md;
  memory::desc diff_dst_md;

  MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
                        int depth, float eps, bool training,
                        TensorFormat data_format, memory::desc src_md,
                        memory::desc diff_dst_md)
      : src_dims(src_dims),
        diff_dst_dims(diff_dst_dims),
        depth(depth),
        eps(eps),
        training(training),
        data_format(data_format),
        src_md(src_md),
        diff_dst_md(diff_dst_md) {}
};

template <typename T, typename U>
class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
 public:
  explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    if (context_.bn_bwd == nullptr) Setup(bwdParams);
  }

  ~MklFusedBatchNormBwdPrimitive() {}

  // BatchNormalization backward execute
  //   src_data:          input data buffer of src
  //   mean_data:         input data buffer of mean
  //   variance_data:     input data buffer of variance
  //   diff_dst_data:     input data buffer of diff_dst
  //   weights_data:      input data buffer of weights
  //   diff_src_data:     output data buffer of diff_src
  //   diff_weights_data: output data buffer of diff_weights
  //   res_space_data:    output data buffer of reserved_space_3.
  //   TODO: reserved_space_3 (temp memory to hold intermediate results) is
  //         not implemented on CPU as of now.
  void Execute(const T* src_data, const U* mean_data, const U* variance_data,
               const T* diff_dst_data, const U* weights_data, T* diff_src_data,
               U* diff_weights_data, U* res_space_data,
               std::shared_ptr<stream> bwd_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    // TODO(intel-tf): Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
    context_.mean_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
    context_.variance_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);

    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
      context_.diff_weights_mem->set_data_handle(
          static_cast<void*>(diff_weights_data), *bwd_stream);
    }

    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
                                           *bwd_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.mean_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(mean_data)));
    context_.variance_mem->set_data_handle(
        static_cast<void*>(const_cast<U*>(variance_data)));
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));

    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(
          static_cast<void*>(const_cast<U*>(weights_data)));
      context_.diff_weights_mem->set_data_handle(
          static_cast<void*>(diff_weights_data));
    }

    context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif  // !ENABLE_ONEDNN_OPENMP
    // Execute backward batch-normalization primitives.
    DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
    execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);

    // After execution, set data handle back to DummyData.
    context_.src_mem->set_data_handle(DummyData);
    context_.mean_mem->set_data_handle(DummyData);
    context_.variance_mem->set_data_handle(DummyData);
    context_.diff_dst_mem->set_data_handle(DummyData);
    if (IS_SET(use_scale_shift)) {
      context_.weights_mem->set_data_handle(DummyData);
      context_.diff_weights_mem->set_data_handle(DummyData);
    }
    context_.diff_src_mem->set_data_handle(DummyData);
  }

  std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
    return context_.bwd_pd;
  }

  memory::desc GetDiffSrcPd() { return context_.diff_src_mem->get_desc(); }

 private:
  struct BatchNormBwdContext {
    // Flags to indicate whether it is training or inference.
    int64 flags;

    // Inputs/outputs memory.
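    // Backward takes src, mean, variance and diff_dst (plus weights when
    // scale/shift is used) as inputs, and produces diff_src and diff_weights.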
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> mean_mem;
    std::shared_ptr<dnnl::memory> variance_mem;
    std::shared_ptr<dnnl::memory> diff_dst_mem;
    std::shared_ptr<dnnl::memory> weights_mem;
    std::shared_ptr<dnnl::memory> diff_weights_mem;
    std::shared_ptr<dnnl::memory> diff_src_mem;

    // Backward batch-normalization primitive descriptor.
    std::shared_ptr<BatchNormBwdPd> bwd_pd;

    // Backward batch-normalization primitive.
    std::shared_ptr<dnnl::primitive> bn_bwd;
    std::vector<dnnl::primitive> bwd_primitives;

    std::vector<std::unordered_map<int, memory>> net_args;

    BatchNormBwdContext()
        : src_mem(nullptr),
          mean_mem(nullptr),
          variance_mem(nullptr),
          diff_dst_mem(nullptr),
          weights_mem(nullptr),
          diff_weights_mem(nullptr),
          diff_src_mem(nullptr) {}
  };

  void Setup(const MklBatchNormBwdParams& bwdParams) {
    context_.flags =
        bwdParams.training
            ? GET_FLAG(use_scale_shift)
            : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));

    // Memory descriptors.
    auto src_md = bwdParams.src_md;
    auto diff_dst_md = bwdParams.diff_dst_md;
    auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
                                      memory::format_tag::nc);
    auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
                                  memory::format_tag::nc);
    auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(),
                                     memory::format_tag::nc);
    auto diff_weights_desc = weights_desc;

    // Forward batch-normalization descriptor and primitive descriptor.
    // Adding this back due to type difference with context_.flags.
    auto bn_flags = bwdParams.training
                        ? dnnl::normalization_flags::use_scale_shift
                        : (dnnl::normalization_flags::use_scale_shift |
                           dnnl::normalization_flags::use_global_stats);
    auto fwd_desc = batch_normalization_forward::desc(
        prop_kind::forward_training, src_md, bwdParams.eps, bn_flags);
    auto fwd_pd = BatchNormFwdPd(fwd_desc, cpu_engine_);

    // Backward batch-normalization primitive.
    // For inference, specify use_global_stats:
    //   1. on fwd propagation, use mean and variance provided as inputs;
    //   2. on bwd propagation, mean and variance are considered as constants.
    // This reduces the amount of MKL computation.
    auto bwd_desc = batch_normalization_backward::desc(
        prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, bn_flags);
    context_.bwd_pd.reset(new BatchNormBwdPd(bwd_desc, cpu_engine_, fwd_pd));

    // Create memory primitives.
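    // All memory objects below are created with DummyData placeholders;
    // the actual tensor buffers are attached in Execute().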
    context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
    context_.diff_dst_mem.reset(
        new memory(diff_dst_md, cpu_engine_, DummyData));
    context_.variance_mem.reset(
        new memory(variance_desc, cpu_engine_, DummyData));
    context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData));
    context_.weights_mem.reset(
        new memory(weights_desc, cpu_engine_, DummyData));
    context_.diff_weights_mem.reset(
        new memory(diff_weights_desc, cpu_engine_, DummyData));
    context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData));

    context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd));
    context_.net_args.push_back(
        {{DNNL_ARG_SRC, *context_.src_mem},
         {DNNL_ARG_MEAN, *context_.mean_mem},
         {DNNL_ARG_VARIANCE, *context_.variance_mem},
         {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem},
         {DNNL_ARG_WEIGHTS, *context_.weights_mem},
         {DNNL_ARG_DIFF_SRC, *context_.diff_src_mem},
         {DNNL_ARG_DIFF_WEIGHTS, *context_.diff_weights_mem}});
    context_.bwd_primitives.push_back(*context_.bn_bwd);
  }

  struct BatchNormBwdContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T, typename U>
class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklFusedBatchNormBwdPrimitive<T, U>* Get(
      const MklBatchNormBwdParams& bwdParams) {
    auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T, U>*>(
        MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance()
            .GetBatchNormBwd(bwdParams));
    if (bn_bwd == nullptr) {
      bn_bwd = new MklFusedBatchNormBwdPrimitive<T, U>(bwdParams);
      MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormBwd(
          bwdParams, bn_bwd);
    }
    return bn_bwd;
  }

  static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
    static MklFusedBatchNormBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklFusedBatchNormBwdPrimitiveFactory() {}
  ~MklFusedBatchNormBwdPrimitiveFactory() {}

  static string CreateKey(const MklBatchNormBwdParams& bwdParams) {
    string prefix = "bn_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.diff_dst_dims);
    key_creator.AddAsKey<int>(bwdParams.depth);
    key_creator.AddAsKey<float>(bwdParams.eps);
    key_creator.AddAsKey<bool>(bwdParams.training);
    key_creator.AddAsKey<TensorFormat>(bwdParams.data_format);
    key_creator.AddAsKey(typeid(T).name());
    key_creator.AddAsKey(typeid(U).name());
    return key_creator.GetKey();
  }

  MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
                       MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};

// Adding a third parameter to the template to support FusedBatchNormV3
// with MKL. This is different from default where the classes are
// derived.
// Moves enabling to compile-time rather than runtime.
template <typename Device, typename T, typename U, bool reserved_space,
          bool is_batch_norm_ex = false, bool native_format = false>
class MklFusedBatchNormOp : public OpKernel {
 public:
  explicit MklFusedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = epsilon;
    float exponential_avg_factor;
    OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
                                             &exponential_avg_factor));
    exponential_avg_factor_ = static_cast<U>(exponential_avg_factor);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
    depth_ = 0;
    mean_values_ = nullptr;
    variance_values_ = nullptr;

    if (!is_batch_norm_ex) {
      activation_mode_ = FusedBNActivationMode::kIdentity;
    } else {
      int num_side_inputs;
      OP_REQUIRES_OK(context,
                     context->GetAttr("num_side_inputs", &num_side_inputs));
      // Currently _MKLFusedBatchNormEx does not support "SideInput".
      OP_REQUIRES(context, num_side_inputs == 0,
                  errors::InvalidArgument(
                      "_MKLFusedBatchNorm do not support side input now."));

      OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
      OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
                  errors::InvalidArgument(
                      "_MKLFusedBatchNorm only support Relu activation"));
    }
  }

  void Compute(OpKernelContext* context) override {
    try {
      const size_t kSrcIndex = 0;       // index of src input tensor
      const size_t kScaleIndex = 1;     // index of scale tensor
      const size_t kShiftIndex = 2;     // index of shift tensor
      const size_t kMeanIndex = 3;      // index of est_mean tensor
      const size_t kVarianceIndex = 4;  // index of est_variance tensor

      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
      const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);

      TensorShape tf_shape_src;
      MklDnnShape dnn_shape_src;
      GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);

      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }
      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(context, shift_tensor.dims() == 1,
                  errors::InvalidArgument("offset must be 1-dimensional",
                                          shift_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, est_mean_tensor.dims() == 1,
          errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                  est_mean_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, est_variance_tensor.dims() == 1,
          errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                  est_variance_tensor.shape().DebugString()));

      int num_channels;
      if (dnn_shape_src.IsMklTensor()) {
        num_channels = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      } else {
        num_channels = GetTensorDim(src_tensor, tensor_format_, 'C');
      }

      OP_REQUIRES(context, scale_tensor.NumElements() == num_channels,
                  errors::InvalidArgument(
                      "scale must have the same number of elements "
                      "as the channels of x, got ",
                      scale_tensor.NumElements(), " and ", num_channels));

      OP_REQUIRES(context, shift_tensor.NumElements() == num_channels,
                  errors::InvalidArgument(
                      "offset must have the same number of elements "
                      "as the channels of x, got ",
                      shift_tensor.NumElements(), " and ", num_channels));
      if (!is_training_ || exponential_avg_factor_ != 1.) {
        std::string prefix_msg = is_training_
                                     ? "When exponential_avg_factor != 1"
                                     : "When is_training=false";
        OP_REQUIRES(context, est_mean_tensor.NumElements() == num_channels,
                    errors::InvalidArgument(
                        prefix_msg,
                        ", mean must have the same number "
                        "of elements as the channels of x, got ",
                        est_mean_tensor.NumElements(), " and ", num_channels));
        OP_REQUIRES(
            context, est_variance_tensor.NumElements() == num_channels,
            errors::InvalidArgument(
                prefix_msg,
                ", variance must have the same "
                "number of elements as the channels of x, got ",
                est_variance_tensor.NumElements(), " and ", num_channels));
      }

      // Handle the special case: input with 0 elements and 0 batch size.
      Tensor* dst_tensor = nullptr;
      TensorShape workspace_tf_shape;
      if (tf_shape_src.num_elements() == 0) {
        size_t workspace_bytes = 0;
        workspace_tf_shape.AddDim(workspace_bytes);
        HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
                         scale_tensor.shape(), &dst_tensor);
        return;
      }

      if (dnn_shape_src.IsMklTensor())
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      else
        ExtractParams(context);

      // Index of the output (dst) tensor.
      const size_t kDstIndex = 0;

      // Allocate 5 output TF tensors.
      Tensor* batch_mean_tensor = nullptr;
      Tensor* batch_variance_tensor = nullptr;
      Tensor* saved_mean_tensor = nullptr;
      Tensor* saved_variance_tensor = nullptr;
      Tensor* reserved_space_tensor = nullptr;

      MklDnnData<T> src(&cpu_engine_);
      MklDnnData<U> weights(&cpu_engine_);
      MklDnnData<U> wksp(&cpu_engine_);

      memory::format_tag dnn_fmt;
      MklTensorFormat mkl_tensor_fmt;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat()) {
          dnn_fmt = memory::format_tag::nchw;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
        } else {
          dnn_fmt = memory::format_tag::nhwc;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
        }
      } else {
        mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
      }

      // Set src memory descriptor.
      memory::dims src_dims =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);

      auto src_md = dnn_shape_src.IsMklTensor()
                        ? dnn_shape_src.GetMklLayout()
                        : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);

      MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
                                      tensor_format_, src_md, activation_mode_);

      // Get forward batch-normalization op from the primitive caching pool.
      MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
          MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);

      // Allocate workspace tensor
      U* ws_data = nullptr;
      if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
        memory::desc workspace_md =
            bn_fwd->GetBatchNormFwdPd()->workspace_desc();
        size_t workspace_bytes = workspace_md.get_size();
        workspace_tf_shape.AddDim(workspace_bytes);

        AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
                          &batch_mean_tensor, &batch_variance_tensor,
                          &saved_mean_tensor, &saved_variance_tensor,
                          &reserved_space_tensor);
        if (reserved_space) {
          wksp.SetUsrMem(workspace_md, reserved_space_tensor);
          ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
        }
      } else {
        // There is actually no workspace tensor out, so we make a dummy one.
        size_t workspace_bytes = 0;
        workspace_tf_shape.AddDim(workspace_bytes);
        AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
                          &batch_mean_tensor, &batch_variance_tensor,
                          &saved_mean_tensor, &saved_variance_tensor,
                          &reserved_space_tensor);
      }

      if (is_training_)
        SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
      else
        SetMeanVariance(est_mean_tensor, est_variance_tensor);

      // oneDNN packs scale & shift as "weights":
      // <scale>...<scale><shift>...<shift>
      weights.AllocateBuffer(2 * depth_ * sizeof(U));
      U* weights_data = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
      const U* scale_tf = scale_tensor.flat<U>().data();
      const U* shift_tf = shift_tensor.flat<U>().data();

      std::memcpy(weights_data, scale_tf, depth_ * sizeof(U));
      std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U));
      char* saved_mean_data_tf =
          reinterpret_cast<char*>(saved_mean_tensor->flat<U>().data());
      std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
                  depth_ * sizeof(U));

      char* saved_variance_data_tf =
          reinterpret_cast<char*>(saved_variance_tensor->flat<U>().data());
      std::memcpy(saved_variance_data_tf,
                  reinterpret_cast<char*>(variance_values_),
                  depth_ * sizeof(U));

      // Check if reorder is needed for src.
      const T* src_data = nullptr;
      std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
      if (!native_format && src_md != bn_fwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(bn_fwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      } else {
        src_data =
            static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
      }

      // Allocate output (dst) tensor
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      dnn_shape_dst.SetMklTensor(true);
      auto dst_pd = bn_fwd->GetDstPd();
      dnn_shape_dst.SetMklLayout(&dst_pd);
      dnn_shape_dst.SetElemType(MklDnnType<T>());
      auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
                                               : src_tensor.shape().dims();
      dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt);
      tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      if (native_format) {
        tf_shape_dst = dnn_shape_dst.GetTfShape();
      }
      AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst, native_format);

      U* weights_op_data = weights_data;
      U* mean_op_data = saved_mean_tensor->flat<U>().data();
      U* variance_op_data = saved_variance_tensor->flat<U>().data();
      T* dst_data = dst_tensor->flat<T>().data();

      // Execute
      std::shared_ptr<stream> fwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine()));
      bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
                      variance_op_data, fwd_cpu_stream, ws_data);

      // adjust_factor converts the biased (population) variance computed by
      // oneDNN into the unbiased estimate expected in the batch_variance
      // output (Bessel's correction): N / (N - 1), with N = batch * H * W.
      float adjust_factor = 1.0;
      if (is_training_) {
        size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
        size_t adjust_size = (orig_size > 1) ? (orig_size - 1) : 1;
        adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
      }

      auto mean_data = reinterpret_cast<U*>(saved_mean_data_tf);
      auto variance_data = reinterpret_cast<U*>(saved_variance_data_tf);
      auto batch_mean_data = batch_mean_tensor->flat<U>().data();
      auto batch_variance_data = batch_variance_tensor->flat<U>().data();
      auto est_mean_data = est_mean_tensor.flat<U>().data();
      auto est_variance_data = est_variance_tensor.flat<U>().data();
      if (is_training_) {
        if (exponential_avg_factor_ == U(1.0)) {
          for (int k = 0; k < depth_; k++) {
            batch_mean_data[k] = mean_data[k];
            batch_variance_data[k] =
                static_cast<U>(adjust_factor) * variance_data[k];
          }
        } else {
          U one_minus_factor = U(1.0) - exponential_avg_factor_;
          for (int k = 0; k < depth_; k++) {
            batch_mean_data[k] = one_minus_factor * est_mean_data[k] +
                                 exponential_avg_factor_ * mean_data[k];
            batch_variance_data[k] = one_minus_factor * est_variance_data[k] +
                                     exponential_avg_factor_ *
                                         static_cast<U>(adjust_factor) *
                                         variance_data[k];
          }
        }
      } else {
        std::memcpy(batch_mean_data, mean_data, depth_ * sizeof(U));
        std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U));
      }
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  float epsilon_;
  U exponential_avg_factor_;
  TensorFormat tensor_format_;
  bool is_training_;
  U* mean_values_;
  U* variance_values_;
  size_t depth_;  // Batch normalization is performed per channel.
  FusedBNActivationMode activation_mode_;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
    mean_values_ = reinterpret_cast<U*>(const_cast<U*>(mean.flat<U>().data()));
    variance_values_ =
        reinterpret_cast<U*>(const_cast<U*>(variance.flat<U>().data()));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape workspace_tf_shape,
                        TensorShape tf_shape_scale, Tensor** dst_tensor) {
    DCHECK(dst_tensor);

    const size_t kDstIndex = 0;
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
                              dnn_shape_dst, native_format);
    DCHECK(*dst_tensor);
    memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
           (*dst_tensor)->tensor_data().size());

    Tensor* batch_mean_tensor = nullptr;
    Tensor* batch_variance_tensor = nullptr;
    Tensor* saved_mean_tensor = nullptr;
    Tensor* saved_variance_tensor = nullptr;
    Tensor* reserved_space_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
                      &batch_mean_tensor, &batch_variance_tensor,
                      &saved_mean_tensor, &saved_variance_tensor,
                      &reserved_space_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
                         TensorShape workspace_tf_shape,
                         Tensor** batch_mean_tensor,
                         Tensor** batch_variance_tensor,
                         Tensor** saved_mean_tensor,
                         Tensor** saved_variance_tensor,
                         Tensor** reserved_space_tensor) {
    DCHECK(batch_mean_tensor);
    DCHECK(batch_variance_tensor);
    DCHECK(saved_mean_tensor);
    DCHECK(saved_variance_tensor);

    const size_t kBatchMeanIndex = 1;
    const size_t kBatchVarianceIndex = 2;
    const size_t kSavedMeanIndex = 3;
    const size_t kSavedVarianceIndex = 4;
    const size_t kReservedSpaceIndex = 5;

    // Allocate batch mean output tensor.
    MklDnnShape mkl_shape_batch_mean;
    mkl_shape_batch_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
                              tf_shape_scale, mkl_shape_batch_mean,
                              native_format);
    DCHECK(*batch_mean_tensor);

    // Set NAN mean value in case of empty input tensor
    int num_elements = tf_shape_scale.num_elements();
    auto batch_mean_data = (*batch_mean_tensor)->flat<U>().data();
    std::fill_n(batch_mean_data, num_elements, static_cast<U>(NAN));

    // Allocate batch variance output tensor.
    MklDnnShape mkl_shape_batch_variance;
    mkl_shape_batch_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchVarianceIndex,
                              batch_variance_tensor, tf_shape_scale,
                              mkl_shape_batch_variance, native_format);
    DCHECK(*batch_variance_tensor);

    // Set NAN variance value in case of empty input tensor
    auto batch_variance_data = (*batch_variance_tensor)->flat<U>().data();
    std::fill_n(batch_variance_data, num_elements, static_cast<U>(NAN));

    // Mean and variance (without Bessel's correction) saved for backward
    // computation to serve as pre-computed mean and variance.
    MklDnnShape mkl_shape_saved_mean;
    mkl_shape_saved_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
                              tf_shape_scale, mkl_shape_saved_mean,
                              native_format);
    DCHECK(*saved_mean_tensor);

    // Set 0 mean value in case of empty input tensor
    auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
    std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));

    MklDnnShape mkl_shape_saved_variance;
    mkl_shape_saved_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedVarianceIndex,
                              saved_variance_tensor, tf_shape_scale,
                              mkl_shape_saved_variance, native_format);
    DCHECK(*saved_variance_tensor);

    // Set 0 variance value in case of empty input tensor
    auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
    std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));

    // Changes to support reserved_space_3 parameter in FusedBatchNormV3.
    if (reserved_space) {
      DCHECK(reserved_space_tensor != nullptr);

      MklDnnShape mkl_shape_reserved_space;
      mkl_shape_reserved_space.SetMklTensor(false);
      AllocateOutputSetMklShape(context, kReservedSpaceIndex,
                                reserved_space_tensor, workspace_tf_shape,
                                mkl_shape_reserved_space, native_format);
      DCHECK((*reserved_space_tensor) != nullptr);
    }
  }
};

template <typename Device, typename T, typename U, bool reserved_space,
          bool native_format = false>
class MklFusedBatchNormGradOp : public OpKernel {
 public:
  explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = epsilon;
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
    depth_ = 0;
  }

  void Compute(OpKernelContext* context) override {
    try {
      const size_t kDiffDstIndex = 0;        // index of diff_dst tensor
      const size_t kSrcIndex = 1;            // index of src input tensor
      const size_t kScaleIndex = 2;          // index of scale tensor
      const size_t kMeanIndex = 3;           // index of saved_mean tensor
      const size_t kVarianceIndex = 4;       // index of saved_variance tensor
      const size_t kReservedSpaceIndex = 5;  // index of reserved_space_3 tensor

      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& saved_variance_tensor =
          MklGetInput(context, kVarianceIndex);
      const Tensor& reserved_space_tensor =
          (reserved_space) ? MklGetInput(context, kReservedSpaceIndex)
                           : Tensor();

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
      GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst, native_format);

      TensorShape tf_shape_src, tf_shape_diff_dst;
      if (dnn_shape_diff_dst.IsMklTensor()) {
        tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
        OP_REQUIRES(
            context, dnn_shape_diff_dst.GetDimension() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      } else {
        tf_shape_diff_dst = diff_dst_tensor.shape();
        OP_REQUIRES(
            context, diff_dst_tensor.dims() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      }

      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }

      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, saved_mean_tensor.dims() == 1,
          errors::InvalidArgument("saved mean must be 1-dimensional",
                                  saved_mean_tensor.shape().DebugString()));

      OP_REQUIRES(
          context, saved_variance_tensor.dims() == 1,
          errors::InvalidArgument("saved variance must be 1-dimensional",
                                  saved_variance_tensor.shape().DebugString()));

      OP_REQUIRES(
          context, tf_shape_src == tf_shape_diff_dst,
          errors::InvalidArgument(
              "x and y_backprop must have same shape, but x has shape ",
              src_tensor.shape(), " and y_backprop has shape ",
              diff_dst_tensor.shape()));

      int num_channels;
      if (dnn_shape_src.IsMklTensor()) {
        num_channels = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      } else {
        num_channels = GetTensorDim(src_tensor, tensor_format_, 'C');
      }
      OP_REQUIRES(context, scale_tensor.NumElements() == num_channels,
                  errors::InvalidArgument(
                      "scale must have the same number of elements "
                      "as the channels of x, got ",
                      scale_tensor.NumElements(), " and ", num_channels));
      OP_REQUIRES(context, saved_mean_tensor.NumElements() == num_channels,
                  errors::InvalidArgument(
                      "reserve_space_1 must have the same number of "
                      "elements as the channels of x, got ",
                      saved_mean_tensor.NumElements(), " and ", num_channels));
      OP_REQUIRES(
          context, saved_variance_tensor.NumElements() == num_channels,
          errors::InvalidArgument(
              "reserve_space_2 must have the same number of "
              "elements as the channels of x, got ",
              saved_variance_tensor.NumElements(), " and ", num_channels));

      // Handle the special case: input with 0 elements and 0 batch size.
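      // For empty inputs, diff_src and the scale/shift gradients are simply
      // zero-filled and the kernel returns early without invoking oneDNN.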
      Tensor* diff_src_tensor = nullptr;
      if (tf_shape_src.num_elements() == 0 ||
          tf_shape_diff_dst.num_elements() == 0) {
        HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
                         &diff_src_tensor);
        return;
      }

      if (dnn_shape_src.IsMklTensor()) {
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      } else if (dnn_shape_diff_dst.IsMklTensor()) {
        depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
      } else {
        ExtractParams(context);
      }

      memory::format_tag dnn_fmt;
      MklTensorFormat mkl_tensor_fmt;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat()) {
          dnn_fmt = memory::format_tag::nchw;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
        } else {
          dnn_fmt = memory::format_tag::nhwc;
          mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
        }
      } else {
        mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
      }

      MklDnnData<T> src(&cpu_engine_);
      MklDnnData<T> diff_dst(&cpu_engine_);
      MklDnnData<U> weights(&cpu_engine_);
      MklDnnData<U> diff_weights(&cpu_engine_);

      memory::dims src_dims =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
      memory::dims diff_dst_dims =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
              : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                          tensor_format_);

      // Set src and diff_dst primitive descriptors.
      memory::desc src_md =
          dnn_shape_src.IsMklTensor()
              ? dnn_shape_src.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
      memory::desc diff_dst_md =
          dnn_shape_diff_dst.IsMklTensor()
              ? dnn_shape_diff_dst.GetMklLayout()
              : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt);

      MklDnnData<T> reorder_src(&cpu_engine_);
      MklDnnData<T> reorder_diff_dst(&cpu_engine_);
      T* diff_dst_data =
          static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
      T* src_data =
          static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));

      if (!native_format) {
        // oneDNN requires src and diff_dst to be in the same memory layout,
        // either blocked or native format. If these inputs are in different
        // formats, convert the one in native format to blocked format, as
        // oneDNN gives better performance for the blocked format.
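        // Concretely: if src is blocked and diff_dst is native, diff_dst is
        // reordered to src's layout (and vice versa), so both share one
        // memory descriptor before the backward primitive is created.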
        if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
          reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
          reorder_diff_dst.CheckReorderToOpMem(src_md, cpu_engine_, context);
          diff_dst_md = src_md;
          diff_dst_data =
              static_cast<T*>(reorder_diff_dst.GetOpMem().get_data_handle());
        } else if (!dnn_shape_src.IsMklTensor() &&
                   dnn_shape_diff_dst.IsMklTensor()) {
          reorder_src.SetUsrMem(src_md, &src_tensor);
          reorder_src.CheckReorderToOpMem(diff_dst_md, cpu_engine_, context);
          src_md = diff_dst_md;
          src_data = static_cast<T*>(reorder_src.GetOpMem().get_data_handle());
        }
      }

      // weights -- oneDNN packs scales/shifts as weights in order
      // of scale, ..., scale, shift, ..., shift
      weights.AllocateBuffer(2 * depth_ * sizeof(U));
      U* weights_data_tf = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
      const U* scale_tf = scale_tensor.flat<U>().data();
      for (int k = 0; k < depth_; k++) {
        weights_data_tf[k] = scale_tf[k];
        weights_data_tf[k + depth_] = static_cast<U>(0);
      }

      diff_weights.AllocateBuffer(2 * depth_ * sizeof(U));

      MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
                                      is_training_, tensor_format_, src_md,
                                      diff_dst_md);
      MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
          MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);

      // Check if diff_dst input needs to be reordered
      std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
      if (!native_format && diff_dst_md != bn_bwd_pd->diff_dst_desc()) {
        diff_dst.SetUsrMem(diff_dst_md, diff_dst_data);
        diff_dst.CheckReorderToOpMem(bn_bwd_pd->diff_dst_desc(), cpu_engine_,
                                     context);
        diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
      }

      if (!native_format && (src_md != bn_bwd_pd->src_desc())) {
        src.SetUsrMem(src_md, src_data);
        src.CheckReorderToOpMem(bn_bwd_pd->src_desc(), cpu_engine_, context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      }

      // Indices of output tensors
      const size_t kDiffSrcIndex = 0;

      // Allocate output tensor diff_src, always set as oneDNN layout.
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      dnn_shape_diff_src.SetMklTensor(true);
      auto diff_src_pd = bn_bwd->GetDiffSrcPd();
      dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
      dnn_shape_diff_src.SetElemType(MklDnnType<T>());
      dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt);
      dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_);
      tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      if (native_format) {
        tf_shape_diff_src = dnn_shape_diff_src.GetTfShape();
      }
      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src,
                                native_format);

      U* mean_data =
          static_cast<U*>(const_cast<U*>(saved_mean_tensor.flat<U>().data()));
      U* variance_data = static_cast<U*>(
          const_cast<U*>(saved_variance_tensor.flat<U>().data()));
      U* weights_data = weights_data_tf;
      T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data());
      U* diff_weights_data = static_cast<U*>(diff_weights.GetAllocatedBuffer());

      U* res_space_data =
          ((reserved_space) ? static_cast<U*>(const_cast<U*>(
                                  reserved_space_tensor.flat<U>().data()))
                            : nullptr);

      // Execute
      std::shared_ptr<stream> bwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      bwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_bwd->GetEngine()));
      bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
                      weights_data, diff_src_data, diff_weights_data,
                      res_space_data, bwd_cpu_stream);

      // Allocate output TF tensors diff_scale and diff_shift.
      Tensor* diff_scale_tensor = nullptr;
      Tensor* diff_shift_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
                        &diff_shift_tensor);

      // Copy data for tensors diff_scale and diff_shift.
      auto diff_scale_data = diff_scale_tensor->flat<U>().data();
      auto diff_shift_data = diff_shift_tensor->flat<U>().data();
      std::memcpy(reinterpret_cast<char*>(diff_scale_data),
                  reinterpret_cast<char*>(diff_weights_data),
                  depth_ * sizeof(U));
      std::memcpy(reinterpret_cast<char*>(diff_shift_data),
                  reinterpret_cast<char*>(diff_weights_data + depth_),
                  depth_ * sizeof(U));
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  float epsilon_;
  TensorFormat tensor_format_;
  size_t depth_;  // Batch normalization is performed per channel.
  bool is_training_;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale_shift,
                        Tensor** diff_src_tensor) {
    const size_t kDiffSrcIndex = 0;

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                              tf_shape_src, dnn_shape_diff_src, native_format);
    auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
    std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(),
                static_cast<T>(0));

    Tensor* diff_scale_tensor = nullptr;
    Tensor* diff_shift_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
                      &diff_shift_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    DCHECK(diff_scale_tensor);
    DCHECK(diff_shift_tensor);

    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;

    // Separate out scale and shift grad and copy to individual tensors
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    DCHECK(diff_scale_tensor);
    DCHECK(diff_shift_tensor);

    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;

    // Separate out scale and shift grad and copy to individual tensors.
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_scale,
                              native_format);
    DCHECK(*diff_scale_tensor);

    auto diff_scale_data = (*diff_scale_tensor)->flat<U>().data();
    std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(),
                static_cast<U>(0));

    MklDnnShape mkl_shape_diff_shift;
    mkl_shape_diff_shift.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_shift,
                              native_format);
    DCHECK(*diff_shift_tensor);

    auto diff_shift_data = (*diff_shift_tensor)->flat<U>().data();
    std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(),
                static_cast<U>(0));

    // Placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
    MklDnnShape mkl_shape_p;
    mkl_shape_p.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
                static_cast<U>(0));
    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                              mkl_shape_p, native_format);
    std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
                static_cast<U>(0));
  }

  memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
};

#define REGISTER_MKL_FUSED_BATCHNORM_CPU(T)                     \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNorm")                                \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false>);      \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNorm")                          \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, T, false, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U)               \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormV2")                              \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false>);      \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormV2")                        \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, U, false, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU

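// The gradient kernels below follow the same registration pattern as the
// forward kernels above: the V1 variants fix U == T, while the V2/V3 variants
// take a separate U type (float) for scale, offset, and the saved statistics.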
#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T)                \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormGrad")                            \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false>);         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormGrad")                      \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormGradOp<CPUDevice, T, T, false, true>);

TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U)          \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormGradV2")                          \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false>);         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormGradV2")                    \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormGradOp<CPUDevice, T, U, false, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU

// TODO(intel-tf): FusedBatchNormV3 has an additional output that is used to
// hold intermediate results. This parameter functionality is not implemented
// on CPU.
#define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U)               \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormV3")                              \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false>);       \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormEx")                              \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true>);        \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormV3")                        \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, U, true, false, true>); \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormEx")                        \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormOp<CPUDevice, T, U, true, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU

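// _FusedBatchNormEx has no standalone CPU implementation in this file; it is
// registered as a NoOp below, presumably so that kernel lookup succeeds while
// the oneDNN graph-rewrite pass substitutes one of the _Mkl*FusedBatchNormEx
// kernels registered above.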
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);
REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<bfloat16>("T")
                            .TypeConstraint<float>("U"),
                        NoOp);

#define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U)          \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklFusedBatchNormGradV3")                          \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true>);          \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("_MklNativeFusedBatchNormGradV3")                    \
          .Device(DEVICE_CPU)                                   \
          .TypeConstraint<T>("T")                               \
          .TypeConstraint<U>("U")                               \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
      MklFusedBatchNormGradOp<CPUDevice, T, U, true, true>);

REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float);
REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float);
#undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU

}  // namespace tensorflow

#undef GET_FLAG
#undef IS_SET

#endif  // INTEL_MKL