xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_fused_batch_norm_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifdef INTEL_MKL
16 
17 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
18 #include "dnnl.hpp"
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_types.h"
23 #include "tensorflow/core/kernels/fused_batch_norm_op.h"
24 #include "tensorflow/core/kernels/no_op.h"
25 #include "tensorflow/core/util/mkl_util.h"
26 #include "tensorflow/core/util/tensor_format.h"
27 #ifdef DNNL_AARCH64_USE_ACL
28 #include "tensorflow/core/platform/mutex.h"
29 #endif
30 
31 #define GET_FLAG(bn_flag) static_cast<int>(dnnl::normalization_flags::bn_flag)
32 #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
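// GET_FLAG converts a dnnl::normalization_flags enumerator to its integer
// value so that several flags can be OR-ed together into context_.flags;
// IS_SET then tests whether a given flag bit is present in context_.flags.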
33 
34 using dnnl::batch_normalization_backward;
35 using dnnl::batch_normalization_forward;
36 using dnnl::prop_kind;
37 using dnnl::stream;
38 
39 using BatchNormFwdPd = dnnl::batch_normalization_forward::primitive_desc;
40 using BatchNormBwdPd = dnnl::batch_normalization_backward::primitive_desc;
41 
42 namespace tensorflow {
43 using CPUDevice = Eigen::ThreadPoolDevice;
44 
45 using FusedBNActivationMode = functor::FusedBatchNormActivationMode;
46 
47 struct MklBatchNormFwdParams {
48   memory::dims src_dims;
49   int depth;
50   float eps;
51   bool training;
52   TensorFormat data_format;
53   FusedBNActivationMode activation_mode;
54   memory::desc src_md;
55 
56   MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
57                         bool training, TensorFormat data_format,
58                         memory::desc src_md,
59                         FusedBNActivationMode activation_mode)
60       : src_dims(src_dims),
61         depth(depth),
62         eps(eps),
63         training(training),
64         data_format(data_format),
65         activation_mode(activation_mode),
66         src_md(src_md) {}
67 };
68 
69 template <typename T, typename U>
70 class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
71  public:
72   explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
73       : MklPrimitive(engine(engine::kind::cpu, 0)) {
74     if (context_.bn_fwd == nullptr) Setup(fwdParams);
75   }
76 
77   ~MklFusedBatchNormFwdPrimitive() {}
78 
79   // BatchNormalization forward execute
80   //   src_data:     input data buffer of src
81   //   weights_data: input data buffer of weights
82   //   dst_data:     output data buffer of dst
83   //   mean_data:     output data buffer of means
84   //   variance_data: output data buffer of variances
85   void Execute(const T* src_data, const U* weights_data, T* dst_data,
86                U* mean_data, U* variance_data,
87                std::shared_ptr<stream> fwd_stream, U* workspace_data) {
88 #ifdef DNNL_AARCH64_USE_ACL
89     mutex_lock lock(primitive_execution_mu_);
90 #endif
91 #ifndef ENABLE_ONEDNN_OPENMP
92     // TODO(intel-tf): Create a common function and avoid the duplicate code
93     context_.src_mem->set_data_handle(
94         static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
95     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
96                                       *fwd_stream);
97 
98     if (IS_SET(use_scale_shift))
99       context_.weights_mem->set_data_handle(
100           static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);
101 
102     if ((context_.pkind == prop_kind::forward_training) ||
103         (IS_SET(use_global_stats))) {
104       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
105                                          *fwd_stream);
106       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
107                                              *fwd_stream);
108     }
109     if (workspace_data != nullptr) {
110       context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
111     }
112 #else
113     context_.src_mem->set_data_handle(
114         static_cast<void*>(const_cast<T*>(src_data)));
115     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
116 
117     if (IS_SET(use_scale_shift))
118       context_.weights_mem->set_data_handle(
119           static_cast<void*>(const_cast<U*>(weights_data)));
120 
121     if ((context_.pkind == prop_kind::forward_training) ||
122         (IS_SET(use_global_stats))) {
123       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
124       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
125     }
126     if (workspace_data != nullptr) {
127       context_.ws_mem->set_data_handle(workspace_data);
128     }
129 #endif  // !ENABLE_ONEDNN_OPENMP
130 
131     // Execute batch-normalization forward primitives.
132     execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
133 
134     context_.src_mem->set_data_handle(DummyData);
135     context_.dst_mem->set_data_handle(DummyData);
136 
137     if (IS_SET(use_scale_shift))
138       context_.weights_mem->set_data_handle(DummyData);
139 
140     if ((context_.pkind == prop_kind::forward_training) ||
141         (IS_SET(use_global_stats))) {
142       context_.mean_mem->set_data_handle(DummyData);
143       context_.variance_mem->set_data_handle(DummyData);
144     }
145 
146     if (workspace_data != nullptr) {
147       context_.ws_mem->set_data_handle(DummyData);
148     }
149   }
150 
151   memory::desc GetDstPd() const { return context_.dst_mem->get_desc(); }
152 
153   std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
154     return context_.fwd_pd;
155   }
156 
157  private:
158   // Primitive reuse context for BatchNorm forward op.
159   struct BatchNormFwdContext {
160     // Flags indicating whether this is training or inference mode.
161     int64 flags;
162 
163     // Propagation kind (forward_training or forward_scoring).
164     dnnl::prop_kind pkind;
165 
166     // Inputs/outputs memory.
167     std::shared_ptr<dnnl::memory> src_mem;
168     std::shared_ptr<dnnl::memory> weights_mem;
169     std::shared_ptr<dnnl::memory> dst_mem;
170     std::shared_ptr<dnnl::memory> mean_mem;
171     std::shared_ptr<dnnl::memory> variance_mem;
172     std::shared_ptr<dnnl::memory> ws_mem;
173 
174     // Forward BatchNorm primitive descriptor.
175     std::shared_ptr<BatchNormFwdPd> fwd_pd;
176 
177     // BatchNorm forward primitive.
178     std::shared_ptr<dnnl::primitive> bn_fwd;
179     std::vector<dnnl::primitive> fwd_primitives;
180 
181     std::vector<std::unordered_map<int, memory>> net_args;
182 
183     BatchNormFwdContext()
184         : flags(0),
185           pkind(prop_kind::forward_training),
186           src_mem(nullptr),
187           weights_mem(nullptr),
188           dst_mem(nullptr),
189           mean_mem(nullptr),
190           variance_mem(nullptr),
191           ws_mem(nullptr),
192           bn_fwd(nullptr) {}
193   };
194 
195   void Setup(const MklBatchNormFwdParams& fwdParams) {
196     context_.flags =
197         fwdParams.training
198             ? GET_FLAG(use_scale_shift)
199             : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
200     context_.pkind = fwdParams.training ? prop_kind::forward_training
201                                         : prop_kind::forward_scoring;
202 
203     if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
204       context_.flags |= GET_FLAG(fuse_norm_relu);
205     }
206     // Memory descriptor
207     auto src_md = fwdParams.src_md;
208     // Create forward BatchNorm descriptor and primitive descriptor.
209     auto fwd_desc = batch_normalization_forward::desc(
210         context_.pkind, src_md, fwdParams.eps,
211         static_cast<dnnl::normalization_flags>(context_.flags));
212 
213     context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));
214 
215     // Create memory primitive based on dummy data
216     context_.src_mem.reset(
217         new memory(context_.fwd_pd->src_desc(), cpu_engine_, DummyData));
218     context_.dst_mem.reset(
219         new memory(context_.fwd_pd->dst_desc(), cpu_engine_, DummyData));
220 
221     memory::dims s_dims = {2, fwdParams.depth};
222     memory::dims m_dims = {1, fwdParams.depth};
223     if (IS_SET(use_scale_shift)) {
224       context_.weights_mem.reset(
225           new memory({{s_dims}, MklDnnType<U>(), memory::format_tag::nc},
226                      cpu_engine_, DummyData));
227     }
228 
229     if (fwdParams.training || (IS_SET(use_global_stats))) {
230       context_.mean_mem.reset(
231           new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
232                      cpu_engine_, DummyData));
233 
234       context_.variance_mem.reset(
235           new memory({{m_dims}, MklDnnType<U>(), memory::format_tag::nc},
236                      cpu_engine_, DummyData));
237     }
238 
239     if (IS_SET(fuse_norm_relu)) {
240       context_.ws_mem.reset(new memory(context_.fwd_pd->workspace_desc(),
241                                        cpu_engine_, DummyData));
242     }
243 
244     // BatchNorm forward primitive.
245     // TODO(intel-tf): Merge all the #ifdefs and simplify code
246     if (!fwdParams.training && !(IS_SET(use_global_stats))) {
247       if (IS_SET(use_scale_shift)) {
248         context_.net_args.push_back({{DNNL_ARG_SRC, *context_.src_mem},
249                                      {DNNL_ARG_WEIGHTS, *context_.weights_mem},
250                                      {DNNL_ARG_DST, *context_.dst_mem}});
251       } else {
252         context_.net_args.push_back({{DNNL_ARG_SRC, *context_.src_mem},
253                                      {DNNL_ARG_DST, *context_.dst_mem}});
254       }
255       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
256     } else if (IS_SET(use_global_stats)) {
257       if (IS_SET(use_scale_shift)) {
258         if (IS_SET(fuse_norm_relu)) {
259           context_.net_args.push_back(
260               {{DNNL_ARG_SRC, *context_.src_mem},
261                {DNNL_ARG_MEAN, *context_.mean_mem},
262                {DNNL_ARG_VARIANCE, *context_.variance_mem},
263                {DNNL_ARG_WEIGHTS, *context_.weights_mem},
264                {DNNL_ARG_DST, *context_.dst_mem},
265                {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
266         } else {
267           context_.net_args.push_back(
268               {{DNNL_ARG_SRC, *context_.src_mem},
269                {DNNL_ARG_MEAN, *context_.mean_mem},
270                {DNNL_ARG_VARIANCE, *context_.variance_mem},
271                {DNNL_ARG_WEIGHTS, *context_.weights_mem},
272                {DNNL_ARG_DST, *context_.dst_mem}});
273         }
274       } else {
275         if (IS_SET(fuse_norm_relu)) {
276           context_.net_args.push_back(
277               {{DNNL_ARG_SRC, *context_.src_mem},
278                {DNNL_ARG_MEAN, *context_.mean_mem},
279                {DNNL_ARG_VARIANCE, *context_.variance_mem},
280                {DNNL_ARG_DST, *context_.dst_mem},
281                {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
282         } else {
283           context_.net_args.push_back(
284               {{DNNL_ARG_SRC, *context_.src_mem},
285                {DNNL_ARG_MEAN, *context_.mean_mem},
286                {DNNL_ARG_VARIANCE, *context_.variance_mem},
287                {DNNL_ARG_DST, *context_.dst_mem}});
288         }
289       }
290       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
291     } else {
292       if (IS_SET(use_scale_shift)) {
293         if (IS_SET(fuse_norm_relu)) {
294           context_.net_args.push_back(
295               {{DNNL_ARG_SRC, *context_.src_mem},
296                {DNNL_ARG_WEIGHTS, *context_.weights_mem},
297                {DNNL_ARG_DST, *context_.dst_mem},
298                {DNNL_ARG_MEAN, *context_.mean_mem},
299                {DNNL_ARG_VARIANCE, *context_.variance_mem},
300                {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
301         } else {
302           context_.net_args.push_back(
303               {{DNNL_ARG_SRC, *context_.src_mem},
304                {DNNL_ARG_WEIGHTS, *context_.weights_mem},
305                {DNNL_ARG_DST, *context_.dst_mem},
306                {DNNL_ARG_MEAN, *context_.mean_mem},
307                {DNNL_ARG_VARIANCE, *context_.variance_mem}});
308         }
309       } else {
310         if (IS_SET(fuse_norm_relu)) {
311           context_.net_args.push_back(
312               {{DNNL_ARG_SRC, *context_.src_mem},
313                {DNNL_ARG_DST, *context_.dst_mem},
314                {DNNL_ARG_MEAN, *context_.mean_mem},
315                {DNNL_ARG_VARIANCE, *context_.variance_mem},
316                {DNNL_ARG_WORKSPACE, *context_.ws_mem}});
317         } else {
318           context_.net_args.push_back(
319               {{DNNL_ARG_SRC, *context_.src_mem},
320                {DNNL_ARG_DST, *context_.dst_mem},
321                {DNNL_ARG_MEAN, *context_.mean_mem},
322                {DNNL_ARG_VARIANCE, *context_.variance_mem}});
323         }
324       }
325       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
326     }
327 
328     context_.fwd_primitives.push_back(*context_.bn_fwd);
329   }
330 
331   struct BatchNormFwdContext context_;
332 
333 #ifdef DNNL_AARCH64_USE_ACL
334   mutex primitive_execution_mu_;
335 #endif
336 };
337 
338 template <typename T, typename U>
339 class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
340  public:
341   static MklFusedBatchNormFwdPrimitive<T, U>* Get(
342       const MklBatchNormFwdParams& fwdParams) {
343     auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T, U>*>(
344         MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance()
345             .GetBatchNormFwd(fwdParams));
346 
347     if (bn_fwd == nullptr) {
348       bn_fwd = new MklFusedBatchNormFwdPrimitive<T, U>(fwdParams);
349       MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormFwd(
350           fwdParams, bn_fwd);
351     }
352     return bn_fwd;
353   }
354 
355   static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
356     static MklFusedBatchNormFwdPrimitiveFactory instance_;
357     return instance_;
358   }
359 
360  private:
361   MklFusedBatchNormFwdPrimitiveFactory() {}
362   ~MklFusedBatchNormFwdPrimitiveFactory() {}
363 
364   static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
365     string prefix = "bn_fwd";
366     FactoryKeyCreator key_creator;
367     key_creator.AddAsKey(prefix);
368     key_creator.AddAsKey(fwdParams.src_dims);
369     key_creator.AddAsKey<int>(fwdParams.depth);
370     key_creator.AddAsKey<float>(fwdParams.eps);
371     key_creator.AddAsKey<bool>(fwdParams.training);
372     key_creator.AddAsKey<TensorFormat>(fwdParams.data_format);
373     key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode);
374     key_creator.AddAsKey(typeid(T).name());
375     key_creator.AddAsKey(typeid(U).name());
376     return key_creator.GetKey();
377   }
378 
379   MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
380     string key = CreateKey(fwdParams);
381     return this->GetOp(key);
382   }
383 
384   void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
385                        MklPrimitive* op) {
386     string key = CreateKey(fwdParams);
387     this->SetOp(key, op);
388   }
389 };
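// Typical use of this factory (a sketch with illustrative variable names;
// MklFusedBatchNormOp::Compute below performs these steps with real tensor
// buffers):
//   MklBatchNormFwdParams fwdParams(src_dims, depth, epsilon, is_training,
//                                   data_format, src_md, activation_mode);
//   MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
//       MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
//   bn_fwd->Execute(src_data, weights_data, dst_data, mean_data,
//                   variance_data, fwd_stream, ws_data);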
390 
391 struct MklBatchNormBwdParams {
392   memory::dims src_dims;
393   memory::dims diff_dst_dims;
394   int depth;
395   float eps;
396   bool training;
397   TensorFormat data_format;
398   memory::desc src_md;
399   memory::desc diff_dst_md;
400 
401   MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
402                         int depth, float eps, bool training,
403                         TensorFormat data_format, memory::desc src_md,
404                         memory::desc diff_dst_md)
405       : src_dims(src_dims),
406         diff_dst_dims(diff_dst_dims),
407         depth(depth),
408         eps(eps),
409         training(training),
410         data_format(data_format),
411         src_md(src_md),
412         diff_dst_md(diff_dst_md) {}
413 };
414 
415 template <typename T, typename U>
416 class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
417  public:
418   explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
419       : MklPrimitive(engine(engine::kind::cpu, 0)) {
420     if (context_.bn_bwd == nullptr) Setup(bwdParams);
421   }
422 
423   ~MklFusedBatchNormBwdPrimitive() {}
424 
425   // BatchNormalization backward execute
426   //   src_data:       input data buffer of src
427   //   mean_data:      input data buffer of mean
428   //   variance_data:  input data buffer of variance
429   //   diff_dst_data:  input data buffer of diff_dst
430   //   weights_data:   input data buffer of weights
431   //   diff_src_data:      output data buffer of diff_src
432   //   diff_weights_data:  output data buffer of diff_weights
433   //   res_space_data:     output data buffer of reserved_space_3.
434   //                       TODO: reserved_space_3 (temporary memory to
435   //                          hold intermediate results) is not
436   //                          implemented on CPU as of now.
437   void Execute(const T* src_data, const U* mean_data, const U* variance_data,
438                const T* diff_dst_data, const U* weights_data, T* diff_src_data,
439                U* diff_weights_data, U* res_space_data,
440                std::shared_ptr<stream> bwd_stream) {
441 #ifdef DNNL_AARCH64_USE_ACL
442     mutex_lock lock(primitive_execution_mu_);
443 #endif
444 #ifndef ENABLE_ONEDNN_OPENMP
445     // TODO(intel-tf): Create a common function and avoid the duplicate code
446     context_.src_mem->set_data_handle(
447         static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
448     context_.mean_mem->set_data_handle(
449         static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
450     context_.variance_mem->set_data_handle(
451         static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
452     context_.diff_dst_mem->set_data_handle(
453         static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
454 
455     if (IS_SET(use_scale_shift)) {
456       context_.weights_mem->set_data_handle(
457           static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
458       context_.diff_weights_mem->set_data_handle(
459           static_cast<void*>(diff_weights_data), *bwd_stream);
460     }
461 
462     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
463                                            *bwd_stream);
464 #else
465     context_.src_mem->set_data_handle(
466         static_cast<void*>(const_cast<T*>(src_data)));
467     context_.mean_mem->set_data_handle(
468         static_cast<void*>(const_cast<U*>(mean_data)));
469     context_.variance_mem->set_data_handle(
470         static_cast<void*>(const_cast<U*>(variance_data)));
471     context_.diff_dst_mem->set_data_handle(
472         static_cast<void*>(const_cast<T*>(diff_dst_data)));
473 
474     if (IS_SET(use_scale_shift)) {
475       context_.weights_mem->set_data_handle(
476           static_cast<void*>(const_cast<U*>(weights_data)));
477       context_.diff_weights_mem->set_data_handle(
478           static_cast<void*>(diff_weights_data));
479     }
480 
481     context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
482 #endif  // !ENABLE_ONEDNN_OPENMP
483     // Execute backward batch-normalization primitives.
484     DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
485     execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
486 
487     // After execution, set data handle back to DummyData.
488     context_.src_mem->set_data_handle(DummyData);
489     context_.mean_mem->set_data_handle(DummyData);
490     context_.variance_mem->set_data_handle(DummyData);
491     context_.diff_dst_mem->set_data_handle(DummyData);
492     if (IS_SET(use_scale_shift)) {
493       context_.weights_mem->set_data_handle(DummyData);
494       context_.diff_weights_mem->set_data_handle(DummyData);
495     }
496     context_.diff_src_mem->set_data_handle(DummyData);
497   }
498 
499   std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
500     return context_.bwd_pd;
501   }
502 
503   memory::desc GetDiffSrcPd() { return context_.diff_src_mem->get_desc(); }
504 
505  private:
506   struct BatchNormBwdContext {
507     // Flags to indicate whether it is training or inference.
508     int64 flags;
509 
510     // Inputs/output memory.
511     std::shared_ptr<dnnl::memory> src_mem;
512     std::shared_ptr<dnnl::memory> mean_mem;
513     std::shared_ptr<dnnl::memory> variance_mem;
514     std::shared_ptr<dnnl::memory> diff_dst_mem;
515     std::shared_ptr<dnnl::memory> weights_mem;
516     std::shared_ptr<dnnl::memory> diff_weights_mem;
517     std::shared_ptr<dnnl::memory> diff_src_mem;
518 
519     // Backward batch-normalization primitive descriptor.
520     std::shared_ptr<BatchNormBwdPd> bwd_pd;
521 
522     // Backward batch-normalization primitive.
523     std::shared_ptr<dnnl::primitive> bn_bwd;
524     std::vector<dnnl::primitive> bwd_primitives;
525 
526     std::vector<std::unordered_map<int, memory>> net_args;
527 
528     BatchNormBwdContext()
529         : src_mem(nullptr),
530           mean_mem(nullptr),
531           variance_mem(nullptr),
532           diff_dst_mem(nullptr),
533           weights_mem(nullptr),
534           diff_weights_mem(nullptr),
535           diff_src_mem(nullptr) {}
536   };
537 
538   void Setup(const MklBatchNormBwdParams& bwdParams) {
539     context_.flags =
540         bwdParams.training
541             ? GET_FLAG(use_scale_shift)
542             : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats));
543 
544     // Memory descriptors.
545     auto src_md = bwdParams.src_md;
546     auto diff_dst_md = bwdParams.diff_dst_md;
547     auto variance_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
548                                       memory::format_tag::nc);
549     auto mean_desc = memory::desc({1, bwdParams.depth}, MklDnnType<U>(),
550                                   memory::format_tag::nc);
551     auto weights_desc = memory::desc({2, bwdParams.depth}, MklDnnType<U>(),
552                                      memory::format_tag::nc);
553     auto diff_weights_desc = weights_desc;
554 
555     // Forward batch-normalization descriptor and primitive descriptor.
556     // Flags are recomputed here because context_.flags is an int64, not dnnl::normalization_flags.
557     auto bn_flags = bwdParams.training
558                         ? dnnl::normalization_flags::use_scale_shift
559                         : (dnnl::normalization_flags::use_scale_shift |
560                            dnnl::normalization_flags::use_global_stats);
561     auto fwd_desc = batch_normalization_forward::desc(
562         prop_kind::forward_training, src_md, bwdParams.eps, bn_flags);
563     auto fwd_pd = BatchNormFwdPd(fwd_desc, cpu_engine_);
564 
565     // Backward batch-normalization primitive.
566     // For inference, specify use_global_stats
567     //   1. on fwd propagation, use mean and variance provided as inputs.
568     //   2. on bwd propagation, mean and variance are considered as constants.
569     //      This reduces the amount of MKL computation.
570     auto bwd_desc = batch_normalization_backward::desc(
571         prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, bn_flags);
572     context_.bwd_pd.reset(new BatchNormBwdPd(bwd_desc, cpu_engine_, fwd_pd));
573 
574     // Create memory primitives.
575     context_.src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
576     context_.diff_dst_mem.reset(
577         new memory(diff_dst_md, cpu_engine_, DummyData));
578     context_.variance_mem.reset(
579         new memory(variance_desc, cpu_engine_, DummyData));
580     context_.mean_mem.reset(new memory(mean_desc, cpu_engine_, DummyData));
581     context_.weights_mem.reset(
582         new memory(weights_desc, cpu_engine_, DummyData));
583     context_.diff_weights_mem.reset(
584         new memory(diff_weights_desc, cpu_engine_, DummyData));
585     context_.diff_src_mem.reset(new memory(src_md, cpu_engine_, DummyData));
586 
587     context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd));
588     context_.net_args.push_back(
589         {{DNNL_ARG_SRC, *context_.src_mem},
590          {DNNL_ARG_MEAN, *context_.mean_mem},
591          {DNNL_ARG_VARIANCE, *context_.variance_mem},
592          {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem},
593          {DNNL_ARG_WEIGHTS, *context_.weights_mem},
594          {DNNL_ARG_DIFF_SRC, *context_.diff_src_mem},
595          {DNNL_ARG_DIFF_WEIGHTS, *context_.diff_weights_mem}});
596     context_.bwd_primitives.push_back(*context_.bn_bwd);
597   }
598 
599   struct BatchNormBwdContext context_;
600 
601 #ifdef DNNL_AARCH64_USE_ACL
602   mutex primitive_execution_mu_;
603 #endif
604 };
605 
606 template <typename T, typename U>
607 class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
608  public:
609   static MklFusedBatchNormBwdPrimitive<T, U>* Get(
610       const MklBatchNormBwdParams& bwdParams) {
611     auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T, U>*>(
612         MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance()
613             .GetBatchNormBwd(bwdParams));
614     if (bn_bwd == nullptr) {
615       bn_bwd = new MklFusedBatchNormBwdPrimitive<T, U>(bwdParams);
616       MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormBwd(
617           bwdParams, bn_bwd);
618     }
619     return bn_bwd;
620   }
621 
622   static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
623     static MklFusedBatchNormBwdPrimitiveFactory instance_;
624     return instance_;
625   }
626 
627  private:
628   MklFusedBatchNormBwdPrimitiveFactory() {}
629   ~MklFusedBatchNormBwdPrimitiveFactory() {}
630 
631   static string CreateKey(const MklBatchNormBwdParams& bwdParams) {
632     string prefix = "bn_bwd";
633     FactoryKeyCreator key_creator;
634     key_creator.AddAsKey(prefix);
635     key_creator.AddAsKey(bwdParams.src_dims);
636     key_creator.AddAsKey(bwdParams.diff_dst_dims);
637     key_creator.AddAsKey<int>(bwdParams.depth);
638     key_creator.AddAsKey<float>(bwdParams.eps);
639     key_creator.AddAsKey<bool>(bwdParams.training);
640     key_creator.AddAsKey<TensorFormat>(bwdParams.data_format);
641     key_creator.AddAsKey(typeid(T).name());
642     key_creator.AddAsKey(typeid(U).name());
643     return key_creator.GetKey();
644   }
645 
646   MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
647     string key = CreateKey(bwdParams);
648     return this->GetOp(key);
649   }
650 
651   void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
652                        MklPrimitive* op) {
653     string key = CreateKey(bwdParams);
654     this->SetOp(key, op);
655   }
656 };
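// The backward factory follows the same pattern (sketch): construct
// MklBatchNormBwdParams, fetch a cached primitive via
// MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams), then call
// Execute() with src, mean, variance, diff_dst and weights buffers to
// produce diff_src and diff_weights.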
657 
658 //  A third template parameter is added to support FusedBatchNormV3 with
659 //  MKL. This differs from the default implementation, where the classes
660 //  are derived; here the choice is made at compile time rather than runtime.
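//  Template parameters (as used in this file):
//    reserved_space   - when true, the op allocates the reserve_space_3
//                       output used by FusedBatchNormV3.
//    is_batch_norm_ex - when true, activation attributes are parsed and a
//                       Relu activation may be fused.
//    native_format    - when true, tensors are handled in plain TF layout
//                       rather than as MKL-layout tensors.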
661 template <typename Device, typename T, typename U, bool reserved_space,
662           bool is_batch_norm_ex = false, bool native_format = false>
663 class MklFusedBatchNormOp : public OpKernel {
664  public:
665   explicit MklFusedBatchNormOp(OpKernelConstruction* context)
666       : OpKernel(context) {
667     float epsilon;
668     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
669     epsilon_ = epsilon;
670     float exponential_avg_factor;
671     OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
672                                              &exponential_avg_factor));
673     exponential_avg_factor_ = static_cast<U>(exponential_avg_factor);
674     string tensor_format;
675     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
676     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
677                 errors::InvalidArgument("Invalid data format"));
678     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
679     depth_ = 0;
680     mean_values_ = nullptr;
681     variance_values_ = nullptr;
682 
683     if (!is_batch_norm_ex) {
684       activation_mode_ = FusedBNActivationMode::kIdentity;
685     } else {
686       int num_side_inputs;
687       OP_REQUIRES_OK(context,
688                      context->GetAttr("num_side_inputs", &num_side_inputs));
689       // Currently _MKLFusedBatchNormEx does not support "SideInput".
690       OP_REQUIRES(context, num_side_inputs == 0,
691                   errors::InvalidArgument(
692                       "_MKLFusedBatchNorm does not support side input now."));
693 
694       OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
695       OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu,
696                   errors::InvalidArgument(
697                       "_MKLFusedBatchNorm only supports Relu activation"));
698     }
699   }
700 
701   void Compute(OpKernelContext* context) override {
702     try {
703       const size_t kSrcIndex = 0;       // index of src input tensor
704       const size_t kScaleIndex = 1;     // index of scale tensor
705       const size_t kShiftIndex = 2;     // index of shift tensor
706       const size_t kMeanIndex = 3;      // index of est_mean tensor
707       const size_t kVarianceIndex = 4;  // index of est_variance tensor
708 
709       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
710       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
711       const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
712       const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
713       const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);
714 
715       TensorShape tf_shape_src;
716       MklDnnShape dnn_shape_src;
717       GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
718 
719       if (dnn_shape_src.IsMklTensor()) {
720         tf_shape_src = dnn_shape_src.GetTfShape();
721         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
722                     errors::InvalidArgument("input must be 4-dimensional",
723                                             src_tensor.shape().DebugString()));
724       } else {
725         tf_shape_src = src_tensor.shape();
726         OP_REQUIRES(context, src_tensor.dims() == 4,
727                     errors::InvalidArgument("input must be 4-dimensional",
728                                             src_tensor.shape().DebugString()));
729       }
730       OP_REQUIRES(context, scale_tensor.dims() == 1,
731                   errors::InvalidArgument("scale must be 1-dimensional",
732                                           scale_tensor.shape().DebugString()));
733       OP_REQUIRES(context, shift_tensor.dims() == 1,
734                   errors::InvalidArgument("offset must be 1-dimensional",
735                                           shift_tensor.shape().DebugString()));
736       OP_REQUIRES(
737           context, est_mean_tensor.dims() == 1,
738           errors::InvalidArgument("estimated_mean must be 1-dimensional",
739                                   est_mean_tensor.shape().DebugString()));
740       OP_REQUIRES(
741           context, est_variance_tensor.dims() == 1,
742           errors::InvalidArgument("estimated_variance must be 1-dimensional",
743                                   est_variance_tensor.shape().DebugString()));
744 
745       int num_channels;
746       if (dnn_shape_src.IsMklTensor()) {
747         num_channels = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
748       } else {
749         num_channels = GetTensorDim(src_tensor, tensor_format_, 'C');
750       }
751 
752       OP_REQUIRES(context, scale_tensor.NumElements() == num_channels,
753                   errors::InvalidArgument(
754                       "scale must have the same number of elements "
755                       "as the channels of x, got ",
756                       scale_tensor.NumElements(), " and ", num_channels));
757 
758       OP_REQUIRES(context, shift_tensor.NumElements() == num_channels,
759                   errors::InvalidArgument(
760                       "offset must have the same number of elements "
761                       "as the channels of x, got ",
762                       shift_tensor.NumElements(), " and ", num_channels));
763       if (!is_training_ || exponential_avg_factor_ != 1.) {
764         std::string prefix_msg = is_training_
765                                      ? "When exponential_avg_factor != 1"
766                                      : "When is_training=false";
767         OP_REQUIRES(context, est_mean_tensor.NumElements() == num_channels,
768                     errors::InvalidArgument(
769                         prefix_msg,
770                         ", mean must have the same number "
771                         "of elements as the channels of x, got ",
772                         est_mean_tensor.NumElements(), " and ", num_channels));
773         OP_REQUIRES(
774             context, est_variance_tensor.NumElements() == num_channels,
775             errors::InvalidArgument(
776                 prefix_msg,
777                 ", variance must have the same "
778                 "number of elements as the channels of x, got ",
779                 est_variance_tensor.NumElements(), " and ", num_channels));
780       }
781 
782       // Handle the special case: input with 0 elements and 0 batch size.
783       Tensor* dst_tensor = nullptr;
784       TensorShape workspace_tf_shape;
785       if (tf_shape_src.num_elements() == 0) {
786         size_t workspace_bytes = 0;
787         workspace_tf_shape.AddDim(workspace_bytes);
788         HandleEmptyInput(context, tf_shape_src, workspace_tf_shape,
789                          scale_tensor.shape(), &dst_tensor);
790         return;
791       }
792 
793       if (dnn_shape_src.IsMklTensor())
794         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
795       else
796         ExtractParams(context);
797 
798       // Index of the output tensor (dst).
799       const size_t kDstIndex = 0;
800 
801       // Allocate 5 output TF tensors.
802       Tensor* batch_mean_tensor = nullptr;
803       Tensor* batch_variance_tensor = nullptr;
804       Tensor* saved_mean_tensor = nullptr;
805       Tensor* saved_variance_tensor = nullptr;
806       Tensor* reserved_space_tensor = nullptr;
807 
808       MklDnnData<T> src(&cpu_engine_);
809       MklDnnData<U> weights(&cpu_engine_);
810       MklDnnData<U> wksp(&cpu_engine_);
811 
812       memory::format_tag dnn_fmt;
813       MklTensorFormat mkl_tensor_fmt;
814       if (dnn_shape_src.IsMklTensor()) {
815         if (dnn_shape_src.IsTensorInNCHWFormat()) {
816           dnn_fmt = memory::format_tag::nchw;
817           mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
818         } else {
819           dnn_fmt = memory::format_tag::nhwc;
820           mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
821         }
822       } else {
823         mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
824         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
825       }
826 
827       // Set src memory descriptor.
828       memory::dims src_dims =
829           dnn_shape_src.IsMklTensor()
830               ? dnn_shape_src.GetSizesAsMklDnnDims()
831               : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
832 
833       auto src_md = dnn_shape_src.IsMklTensor()
834                         ? dnn_shape_src.GetMklLayout()
835                         : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
836 
837       MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_,
838                                       tensor_format_, src_md, activation_mode_);
839 
840       // Get forward batch-normalization op from the primitive caching pool.
841       MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd =
842           MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams);
843 
844       // Allocate workspace tensor
845       U* ws_data = nullptr;
846       if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) {
847         memory::desc workspace_md =
848             bn_fwd->GetBatchNormFwdPd()->workspace_desc();
849         size_t workspace_bytes = workspace_md.get_size();
850         workspace_tf_shape.AddDim(workspace_bytes);
851 
852         AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
853                           &batch_mean_tensor, &batch_variance_tensor,
854                           &saved_mean_tensor, &saved_variance_tensor,
855                           &reserved_space_tensor);
856         if (reserved_space) {
857           wksp.SetUsrMem(workspace_md, reserved_space_tensor);
858           ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle());
859         }
860       } else {
861         // There is actually no workspace tensor out, so we make a dummy one.
862         size_t workspace_bytes = 0;
863         workspace_tf_shape.AddDim(workspace_bytes);
864         AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape,
865                           &batch_mean_tensor, &batch_variance_tensor,
866                           &saved_mean_tensor, &saved_variance_tensor,
867                           &reserved_space_tensor);
868       }
869 
870       if (is_training_)
871         SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
872       else
873         SetMeanVariance(est_mean_tensor, est_variance_tensor);
874 
875       // oneDNN packs scale & shift as "weights":
876       // <scale>...<scale><shift>...<shift>
877       weights.AllocateBuffer(2 * depth_ * sizeof(U));
878       U* weights_data = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
879       const U* scale_tf = scale_tensor.flat<U>().data();
880       const U* shift_tf = shift_tensor.flat<U>().data();
881 
882       std::memcpy(weights_data, scale_tf, depth_ * sizeof(U));
883       std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U));
884       char* saved_mean_data_tf =
885           reinterpret_cast<char*>(saved_mean_tensor->flat<U>().data());
886       std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
887                   depth_ * sizeof(U));
888 
889       char* saved_variance_data_tf =
890           reinterpret_cast<char*>(saved_variance_tensor->flat<U>().data());
891       std::memcpy(saved_variance_data_tf,
892                   reinterpret_cast<char*>(variance_values_),
893                   depth_ * sizeof(U));
894 
895       // Check if reorder is needed for src.
896       const T* src_data = nullptr;
897       std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd();
898       if (!native_format && src_md != bn_fwd_pd->src_desc()) {
899         src.SetUsrMem(src_md, &src_tensor);
900         src.CheckReorderToOpMem(bn_fwd_pd->src_desc(), cpu_engine_, context);
901         src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
902       } else {
903         src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
904       }
905 
906       // Allocate output (dst) tensor
907       MklDnnShape dnn_shape_dst;
908       TensorShape tf_shape_dst;
909       dnn_shape_dst.SetMklTensor(true);
910       auto dst_pd = bn_fwd->GetDstPd();
911       dnn_shape_dst.SetMklLayout(&dst_pd);
912       dnn_shape_dst.SetElemType(MklDnnType<T>());
913       auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
914                                                : src_tensor.shape().dims();
915       dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt);
916       tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
917       if (native_format) {
918         tf_shape_dst = dnn_shape_dst.GetTfShape();
919       }
920       AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
921                                 dnn_shape_dst, native_format);
922 
923       U* weights_op_data = weights_data;
924       U* mean_op_data = saved_mean_tensor->flat<U>().data();
925       U* variance_op_data = saved_variance_tensor->flat<U>().data();
926       T* dst_data = dst_tensor->flat<T>().data();
927 
928       // Execute
929       std::shared_ptr<stream> fwd_cpu_stream;
930       MklDnnThreadPool eigen_tp(context);
931       fwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_fwd->GetEngine()));
932       bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
933                       variance_op_data, fwd_cpu_stream, ws_data);
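      // adjust_factor applies Bessel's correction (N / (N - 1), where N is
      // the batch * spatial element count) so that batch_variance is reported
      // as the unbiased sample variance; the saved variance kept for the
      // backward pass remains uncorrected.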
934       float adjust_factor = 1.0;
935       if (is_training_) {
936         size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
937         size_t adjust_size = (orig_size > 1) ? (orig_size - 1) : 1;
938         adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
939       }
940 
941       auto mean_data = reinterpret_cast<U*>(saved_mean_data_tf);
942       auto variance_data = reinterpret_cast<U*>(saved_variance_data_tf);
943       auto batch_mean_data = batch_mean_tensor->flat<U>().data();
944       auto batch_variance_data = batch_variance_tensor->flat<U>().data();
945       auto est_mean_data = est_mean_tensor.flat<U>().data();
946       auto est_variance_data = est_variance_tensor.flat<U>().data();
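      // Update the running statistics: with exponential_avg_factor == 1 the
      // outputs are simply the current batch statistics; otherwise
      //   new_value = (1 - factor) * estimated_value + factor * batch_value.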
947       if (is_training_) {
948         if (exponential_avg_factor_ == U(1.0)) {
949           for (int k = 0; k < depth_; k++) {
950             batch_mean_data[k] = mean_data[k];
951             batch_variance_data[k] =
952                 static_cast<U>(adjust_factor) * variance_data[k];
953           }
954         } else {
955           U one_minus_factor = U(1.0) - exponential_avg_factor_;
956           for (int k = 0; k < depth_; k++) {
957             batch_mean_data[k] = one_minus_factor * est_mean_data[k] +
958                                  exponential_avg_factor_ * mean_data[k];
959             batch_variance_data[k] = one_minus_factor * est_variance_data[k] +
960                                      exponential_avg_factor_ *
961                                          static_cast<U>(adjust_factor) *
962                                          variance_data[k];
963           }
964         }
965       } else {
966         std::memcpy(batch_mean_data, mean_data, depth_ * sizeof(U));
967         std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U));
968       }
969     } catch (dnnl::error& e) {
970       string error_msg = "Status: " + std::to_string(e.status) +
971                          ", message: " + string(e.message) + ", in file " +
972                          string(__FILE__) + ":" + std::to_string(__LINE__);
973       OP_REQUIRES_OK(
974           context,
975           errors::Aborted("Operation received an exception:", error_msg));
976     }
977   }
978 
979  private:
980   float epsilon_;
981   U exponential_avg_factor_;
982   TensorFormat tensor_format_;
983   bool is_training_;
984   U* mean_values_;
985   U* variance_values_;
986   size_t depth_;  // Batch normalization is performed per channel.
987   FusedBNActivationMode activation_mode_;
988   engine cpu_engine_ = engine(engine::kind::cpu, 0);
989 
990   void ExtractParams(OpKernelContext* context) {
991     const Tensor& input = MklGetInput(context, 0);
992     depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
993   }
994 
995   void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
996     mean_values_ = reinterpret_cast<U*>(const_cast<U*>(mean.flat<U>().data()));
997     variance_values_ =
998         reinterpret_cast<U*>(const_cast<U*>(variance.flat<U>().data()));
999   }
1000 
1001   void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
1002                         TensorShape workspace_tf_shape,
1003                         TensorShape tf_shape_scale, Tensor** dst_tensor) {
1004     DCHECK(dst_tensor);
1005 
1006     const size_t kDstIndex = 0;
1007     MklDnnShape dnn_shape_dst;
1008     dnn_shape_dst.SetMklTensor(false);
1009     AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
1010                               dnn_shape_dst, native_format);
1011     DCHECK(*dst_tensor);
1012     memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
1013            (*dst_tensor)->tensor_data().size());
1014 
1015     Tensor* batch_mean_tensor = nullptr;
1016     Tensor* batch_variance_tensor = nullptr;
1017     Tensor* saved_mean_tensor = nullptr;
1018     Tensor* saved_variance_tensor = nullptr;
1019     Tensor* reserved_space_tensor = nullptr;
1020     AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape,
1021                       &batch_mean_tensor, &batch_variance_tensor,
1022                       &saved_mean_tensor, &saved_variance_tensor,
1023                       &reserved_space_tensor);
1024   }
1025 
1026   void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
1027                          TensorShape workspace_tf_shape,
1028                          Tensor** batch_mean_tensor,
1029                          Tensor** batch_variance_tensor,
1030                          Tensor** saved_mean_tensor,
1031                          Tensor** saved_variance_tensor,
1032                          Tensor** reserved_space_tensor) {
1033     DCHECK(batch_mean_tensor);
1034     DCHECK(batch_variance_tensor);
1035     DCHECK(saved_mean_tensor);
1036     DCHECK(saved_variance_tensor);
1037 
1038     const size_t kBatchMeanIndex = 1;
1039     const size_t kBatchVarianceIndex = 2;
1040     const size_t kSavedMeanIndex = 3;
1041     const size_t kSavedVarianceIndex = 4;
1042     const size_t kReservedSpaceIndex = 5;
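    // Output index 0 is the normalized output (dst); it is allocated by the
    // caller, so only outputs 1 through 5 are handled here.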
1043 
1044     // Allocate batch mean output tensor.
1045     MklDnnShape mkl_shape_batch_mean;
1046     mkl_shape_batch_mean.SetMklTensor(false);
1047     AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
1048                               tf_shape_scale, mkl_shape_batch_mean,
1049                               native_format);
1050     DCHECK(*batch_mean_tensor);
1051 
1052     // Set NAN mean value in case of empty input tensor
1053     int num_elements = tf_shape_scale.num_elements();
1054     auto batch_mean_data = (*batch_mean_tensor)->flat<U>().data();
1055     std::fill_n(batch_mean_data, num_elements, static_cast<U>(NAN));
1056 
1057     // Allocate batch variance output tensor.
1058     MklDnnShape mkl_shape_batch_variance;
1059     mkl_shape_batch_variance.SetMklTensor(false);
1060     AllocateOutputSetMklShape(context, kBatchVarianceIndex,
1061                               batch_variance_tensor, tf_shape_scale,
1062                               mkl_shape_batch_variance, native_format);
1063     DCHECK(*batch_variance_tensor);
1064 
1065     // Set NAN variance value in case of empty input tensor
1066     auto batch_variance_data = (*batch_variance_tensor)->flat<U>().data();
1067     std::fill_n(batch_variance_data, num_elements, static_cast<U>(NAN));
1068     // Mean and variance (without Bessel's correction) saved for backward
1069     // computation to serve as pre-computed mean and variance.
1070     MklDnnShape mkl_shape_saved_mean;
1071     mkl_shape_saved_mean.SetMklTensor(false);
1072     AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
1073                               tf_shape_scale, mkl_shape_saved_mean,
1074                               native_format);
1075     DCHECK(*saved_mean_tensor);
1076 
1077     // Set 0 mean value in case of empty input tensor
1078     auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data();
1079     std::fill_n(saved_mean_data, num_elements, static_cast<U>(0));
1080 
1081     MklDnnShape mkl_shape_saved_variance;
1082     mkl_shape_saved_variance.SetMklTensor(false);
1083     AllocateOutputSetMklShape(context, kSavedVarianceIndex,
1084                               saved_variance_tensor, tf_shape_scale,
1085                               mkl_shape_saved_variance, native_format);
1086     DCHECK(*saved_variance_tensor);
1087 
1088     // Set 0 variance value in case of empty input tensor
1089     auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data();
1090     std::fill_n(saved_variance_data, num_elements, static_cast<U>(0));
1091 
1092     // Changes to support reserved_space_3 parameter in FusedBatchNormV3.
1093     if (reserved_space) {
1094       DCHECK(reserved_space_tensor != nullptr);
1095 
1096       MklDnnShape mkl_shape_reserved_space;
1097       mkl_shape_reserved_space.SetMklTensor(false);
1098       AllocateOutputSetMklShape(context, kReservedSpaceIndex,
1099                                 reserved_space_tensor, workspace_tf_shape,
1100                                 mkl_shape_reserved_space, native_format);
1101       DCHECK((*reserved_space_tensor) != nullptr);
1102     }
1103   }
1104 };
1105 
1106 template <typename Device, typename T, typename U, bool reserved_space,
1107           bool native_format = false>
1108 class MklFusedBatchNormGradOp : public OpKernel {
1109  public:
1110   explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
1111       : OpKernel(context) {
1112     float epsilon;
1113     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1114     epsilon_ = epsilon;
1115     string tensor_format;
1116     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1117     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1118                 errors::InvalidArgument("Invalid data format"));
1119     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1120     depth_ = 0;
1121   }
1122 
1123   void Compute(OpKernelContext* context) override {
1124     try {
1125       const size_t kDiffDstIndex = 0;        // index of diff_dst tensor
1126       const size_t kSrcIndex = 1;            // index of src input tensor
1127       const size_t kScaleIndex = 2;          // index of scale tensor
1128       const size_t kMeanIndex = 3;           // index of saved_mean tensor
1129       const size_t kVarianceIndex = 4;       // index of saved_variance tensor
1130       const size_t kReservedSpaceIndex = 5;  // index of reserved space 3 tensor
1131 
1132       const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
1133       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
1134       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
1135       const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
1136       const Tensor& saved_variance_tensor =
1137           MklGetInput(context, kVarianceIndex);
1138       const Tensor& reserved_space_tensor =
1139           (reserved_space) ? MklGetInput(context, kReservedSpaceIndex)
1140                            : Tensor();
1141 
1142       MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
1143       GetMklShape(context, kSrcIndex, &dnn_shape_src, native_format);
1144       GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst, native_format);
1145 
1146       TensorShape tf_shape_src, tf_shape_diff_dst;
1147       if (dnn_shape_diff_dst.IsMklTensor()) {
1148         tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
1149         OP_REQUIRES(
1150             context, dnn_shape_diff_dst.GetDimension() == 4,
1151             errors::InvalidArgument("input must be 4-dimensional",
1152                                     diff_dst_tensor.shape().DebugString()));
1153       } else {
1154         tf_shape_diff_dst = diff_dst_tensor.shape();
1155         OP_REQUIRES(
1156             context, diff_dst_tensor.dims() == 4,
1157             errors::InvalidArgument("input must be 4-dimensional",
1158                                     diff_dst_tensor.shape().DebugString()));
1159       }
1160 
1161       if (dnn_shape_src.IsMklTensor()) {
1162         tf_shape_src = dnn_shape_src.GetTfShape();
1163         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
1164                     errors::InvalidArgument("input must be 4-dimensional",
1165                                             src_tensor.shape().DebugString()));
1166       } else {
1167         tf_shape_src = src_tensor.shape();
1168         OP_REQUIRES(context, src_tensor.dims() == 4,
1169                     errors::InvalidArgument("input must be 4-dimensional",
1170                                             src_tensor.shape().DebugString()));
1171       }
1172 
1173       OP_REQUIRES(context, scale_tensor.dims() == 1,
1174                   errors::InvalidArgument("scale must be 1-dimensional",
1175                                           scale_tensor.shape().DebugString()));
1176       OP_REQUIRES(
1177           context, saved_mean_tensor.dims() == 1,
1178           errors::InvalidArgument("saved mean must be 1-dimensional",
1179                                   saved_mean_tensor.shape().DebugString()));
1180 
1181       OP_REQUIRES(
1182           context, saved_variance_tensor.dims() == 1,
1183           errors::InvalidArgument("saved variance must be 1-dimensional",
1184                                   saved_variance_tensor.shape().DebugString()));
1185 
1186       OP_REQUIRES(context, tf_shape_src == tf_shape_diff_dst,
1187                   errors::InvalidArgument(
1188                       "x and y_backprop must have same shape, but x has shape ",
1189                       src_tensor.shape(), " and y_backprop has shape ",
1190                       diff_dst_tensor.shape()));
1191 
1192       int num_channels;
1193       if (dnn_shape_src.IsMklTensor()) {
1194         num_channels = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
1195       } else {
1196         num_channels = GetTensorDim(src_tensor, tensor_format_, 'C');
1197       }
1198       OP_REQUIRES(context, scale_tensor.NumElements() == num_channels,
1199                   errors::InvalidArgument(
1200                       "scale must have the same number of elements "
1201                       "as the channels of x, got ",
1202                       scale_tensor.NumElements(), " and ", num_channels));
1203       OP_REQUIRES(context, saved_mean_tensor.NumElements() == num_channels,
1204                   errors::InvalidArgument(
1205                       "reserve_space_1 must have the same number of "
1206                       "elements as the channels of x, got ",
1207                       saved_mean_tensor.NumElements(), " and ", num_channels));
1208       OP_REQUIRES(
1209           context, saved_variance_tensor.NumElements() == num_channels,
1210           errors::InvalidArgument(
1211               "reserve_space_2 must have the same number of "
1212               "elements as the channels of x, got ",
1213               saved_variance_tensor.NumElements(), " and ", num_channels));
1214 
1215       // Handle the special case: input with 0 elements and 0 batch size.
1216       Tensor* diff_src_tensor = nullptr;
1217       if (tf_shape_src.num_elements() == 0 ||
1218           tf_shape_diff_dst.num_elements() == 0) {
1219         HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
1220                          &diff_src_tensor);
1221         return;
1222       }
1223 
1224       if (dnn_shape_src.IsMklTensor()) {
1225         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
1226       } else if (dnn_shape_diff_dst.IsMklTensor()) {
1227         depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C);
1228       } else {
1229         ExtractParams(context);
1230       }
1231 
1232       memory::format_tag dnn_fmt;
1233       MklTensorFormat mkl_tensor_fmt;
1234       if (dnn_shape_src.IsMklTensor()) {
1235         if (dnn_shape_src.IsTensorInNCHWFormat()) {
1236           dnn_fmt = memory::format_tag::nchw;
1237           mkl_tensor_fmt = MklTensorFormat::FORMAT_NCHW;
1238         } else {
1239           dnn_fmt = memory::format_tag::nhwc;
1240           mkl_tensor_fmt = MklTensorFormat::FORMAT_NHWC;
1241         }
1242       } else {
1243         mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_);
1244         dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt);
1245       }
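           // dnn_fmt / mkl_tensor_fmt record the logical data layout (NCHW vs.
           // NHWC); they are used below to build plain (non-blocked) memory
           // descriptors and to set the TF dimension order of the output shape.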
1246 
1247       MklDnnData<T> src(&cpu_engine_);
1248       MklDnnData<T> diff_dst(&cpu_engine_);
1249       MklDnnData<U> weights(&cpu_engine_);
1250       MklDnnData<U> diff_weights(&cpu_engine_);
1251 
1252       memory::dims src_dims =
1253           dnn_shape_src.IsMklTensor()
1254               ? dnn_shape_src.GetSizesAsMklDnnDims()
1255               : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
1256       memory::dims diff_dst_dims =
1257           dnn_shape_diff_dst.IsMklTensor()
1258               ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
1259               : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
1260                                           tensor_format_);
1261 
1262       // Set src and diff_dst primitive descriptors.
1263       memory::desc src_md =
1264           dnn_shape_src.IsMklTensor()
1265               ? dnn_shape_src.GetMklLayout()
1266               : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt);
1267       memory::desc diff_dst_md =
1268           dnn_shape_diff_dst.IsMklTensor()
1269               ? dnn_shape_diff_dst.GetMklLayout()
1270               : memory::desc(diff_dst_dims, MklDnnType<T>(), dnn_fmt);
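           // Inputs already in oneDNN (blocked) layout reuse the memory
           // descriptor stored in their MklDnnShape; plain TF inputs get a
           // simple NCHW/NHWC descriptor built from their dims.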
1271 
1272       MklDnnData<T> reorder_src(&cpu_engine_);
1273       MklDnnData<T> reorder_diff_dst(&cpu_engine_);
1274       T* diff_dst_data =
1275           static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
1276       T* src_data =
1277           static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
1278 
1279       if (!native_format) {
1280         // oneDNN requires src and diff_dst to be in same memory layout, either
1281         // blocked or native format. If these inputs are in different formats,
1282         // convert the one in native format to blocked format as oneDNN gives
1283         // better performance for blocked format.
1284         if (dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
1285           reorder_diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
1286           reorder_diff_dst.CheckReorderToOpMem(src_md, cpu_engine_, context);
1287           diff_dst_md = src_md;
1288           diff_dst_data =
1289               static_cast<T*>(reorder_diff_dst.GetOpMem().get_data_handle());
1290         } else if (!dnn_shape_src.IsMklTensor() &&
1291                    dnn_shape_diff_dst.IsMklTensor()) {
1292           reorder_src.SetUsrMem(src_md, &src_tensor);
1293           reorder_src.CheckReorderToOpMem(diff_dst_md, cpu_engine_, context);
1294           src_md = diff_dst_md;
1295           src_data = static_cast<T*>(reorder_src.GetOpMem().get_data_handle());
1296         }
1297       }
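           // For example, if src arrives blocked but diff_dst is plain NHWC,
           // diff_dst is reordered into src's blocked layout and both inputs
           // then share src_md; the symmetric case reorders src instead.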
1298 
1299       // weights -- oneDNN packs scale and shift values into a single
1300       // weights buffer: scale_0..scale_{C-1} followed by shift_0..shift_{C-1}.
1301       weights.AllocateBuffer(2 * depth_ * sizeof(U));
1302       U* weights_data_tf = reinterpret_cast<U*>(weights.GetAllocatedBuffer());
1303       const U* scale_tf = scale_tensor.flat<U>().data();
1304       for (int k = 0; k < depth_; k++) {
1305         weights_data_tf[k] = scale_tf[k];
1306         weights_data_tf[k + depth_] = static_cast<U>(0);
1307       }
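           // Only the scale (gamma) values influence the computed gradients;
           // the shift half of the packed buffer is zero-filled because the
           // primitive still expects the full 2 * depth_ weights input.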
1308 
1309       diff_weights.AllocateBuffer(2 * depth_ * sizeof(U));
1310 
1311       MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
1312                                       is_training_, tensor_format_, src_md,
1313                                       diff_dst_md);
1314       MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd =
1315           MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams);
1316 
1317       // Check if diff_dst input needs to be reordered
1318       std::shared_ptr<BatchNormBwdPd> bn_bwd_pd = bn_bwd->GetBatchNormBwdPd();
1319       if (!native_format && diff_dst_md != bn_bwd_pd->diff_dst_desc()) {
1320         diff_dst.SetUsrMem(diff_dst_md, diff_dst_data);
1321         diff_dst.CheckReorderToOpMem(bn_bwd_pd->diff_dst_desc(), cpu_engine_,
1322                                      context);
1323         diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
1324       }
1325 
1326       if (!native_format && (src_md != bn_bwd_pd->src_desc())) {
1327         src.SetUsrMem(src_md, src_data);
1328         src.CheckReorderToOpMem(bn_bwd_pd->src_desc(), cpu_engine_, context);
1329         src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
1330       }
1331 
1332       // Indices of output tensors
1333       const size_t kDiffSrcIndex = 0;
1334 
1335       // Allocate output diff_src as an oneDNN-layout tensor (plain TF shape if native_format).
1336       MklDnnShape dnn_shape_diff_src;
1337       TensorShape tf_shape_diff_src;
1338       dnn_shape_diff_src.SetMklTensor(true);
1339       auto diff_src_pd = bn_bwd->GetDiffSrcPd();
1340       dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
1341       dnn_shape_diff_src.SetElemType(MklDnnType<T>());
1342       dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt);
1343       dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_);
1344       tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
1345       if (native_format) {
1346         tf_shape_diff_src = dnn_shape_diff_src.GetTfShape();
1347       }
1348       AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
1349                                 tf_shape_diff_src, dnn_shape_diff_src,
1350                                 native_format);
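           // On the blocked path the output is allocated as a flat buffer sized
           // from the primitive's diff_src descriptor, with dnn_shape_diff_src
           // carrying the layout metadata; on the native-format path the plain
           // TF shape is used directly.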
1351 
1352       U* mean_data =
1353           static_cast<U*>(const_cast<U*>(saved_mean_tensor.flat<U>().data()));
1354       U* variance_data = static_cast<U*>(
1355           const_cast<U*>(saved_variance_tensor.flat<U>().data()));
1356       U* weights_data = weights_data_tf;
1357       T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data());
1358       U* diff_weights_data = static_cast<U*>(diff_weights.GetAllocatedBuffer());
1359 
1360       U* res_space_data =
1361           ((reserved_space) ? static_cast<U*>(const_cast<U*>(
1362                                   reserved_space_tensor.flat<U>().data()))
1363                             : nullptr);
1364 
1365       // Execute
1366       std::shared_ptr<stream> bwd_cpu_stream;
1367       MklDnnThreadPool eigen_tp(context);
1368       bwd_cpu_stream.reset(CreateStream(&eigen_tp, bn_bwd->GetEngine()));
1369       bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
1370                       weights_data, diff_src_data, diff_weights_data,
1371                       res_space_data, bwd_cpu_stream);
1372       // Allocate output TF tensors diff_scale and diff_shift.
1373       Tensor* diff_scale_tensor = nullptr;
1374       Tensor* diff_shift_tensor = nullptr;
1375       AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
1376                         &diff_shift_tensor);
1377 
1378       // Copy data for tensors diff_scale and diff_shift.
1379       auto diff_scale_data = diff_scale_tensor->flat<U>().data();
1380       auto diff_shift_data = diff_shift_tensor->flat<U>().data();
1381       std::memcpy(reinterpret_cast<char*>(diff_scale_data),
1382                   reinterpret_cast<char*>(diff_weights_data),
1383                   depth_ * sizeof(U));
1384       std::memcpy(reinterpret_cast<char*>(diff_shift_data),
1385                   reinterpret_cast<char*>(diff_weights_data + depth_),
1386                   depth_ * sizeof(U));
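           // diff_weights uses the same packing as weights: the first depth_
           // elements hold the scale gradient and the next depth_ elements the
           // shift gradient, hence the split at depth_ in the copies above.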
1387     } catch (dnnl::error& e) {
1388       string error_msg = "Status: " + std::to_string(e.status) +
1389                          ", message: " + string(e.message) + ", in file " +
1390                          string(__FILE__) + ":" + std::to_string(__LINE__);
1391       OP_REQUIRES_OK(
1392           context,
1393           errors::Aborted("Operation received an exception:", error_msg));
1394     }
1395   }
1396 
1397  private:
1398   float epsilon_;
1399   TensorFormat tensor_format_;
1400   size_t depth_;  // Batch normalization is performed per channel.
1401   bool is_training_;
1402   engine cpu_engine_ = engine(engine::kind::cpu, 0);
1403 
1404   void ExtractParams(OpKernelContext* context) {
1405     const Tensor& input = MklGetInput(context, 0);
1406     depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
1407   }
1408 
1409   void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
1410                         TensorShape tf_shape_scale_shift,
1411                         Tensor** diff_src_tensor) {
1412     const size_t kDiffSrcIndex = 0;
1413 
1414     MklDnnShape dnn_shape_diff_src;
1415     dnn_shape_diff_src.SetMklTensor(false);
1416     AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
1417                               tf_shape_src, dnn_shape_diff_src, native_format);
1418     auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
1419     std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(),
1420                 static_cast<T>(0));
1421 
1422     Tensor* diff_scale_tensor = nullptr;
1423     Tensor* diff_shift_tensor = nullptr;
1424     AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
1425                       &diff_shift_tensor);
1426   }
1427 
1428   void AllocateTFOutputs(OpKernelContext* context,
1429                          TensorShape tf_shape_scale_shift,
1430                          Tensor** diff_scale_tensor,
1431                          Tensor** diff_shift_tensor) {
1432     DCHECK(diff_scale_tensor);
1433     DCHECK(diff_shift_tensor);
1434 
1435     const size_t kDiffScaleIndex = 1;
1436     const size_t kDiffShiftIndex = 2;
1437     const size_t kP1Index = 3;
1438     const size_t kP2Index = 4;
1439 
1440     // Separate out scale and shift grad and copy to individual tensors
1441     MklDnnShape mkl_shape_diff_scale;
1442     mkl_shape_diff_scale.SetMklTensor(false);
1443     AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
1444                               tf_shape_scale_shift, mkl_shape_diff_scale,
1445                               native_format);
1446     DCHECK(*diff_scale_tensor);
1447 
1448     auto diff_scale_data = (*diff_scale_tensor)->flat<U>().data();
1449     std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(),
1450                 static_cast<U>(0));
1451 
1452     MklDnnShape mkl_shape_diff_shift;
1453     mkl_shape_diff_shift.SetMklTensor(false);
1454     AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
1455                               tf_shape_scale_shift, mkl_shape_diff_shift,
1456                               native_format);
1457     DCHECK(*diff_shift_tensor);
1458 
1459     auto diff_shift_data = (*diff_shift_tensor)->flat<U>().data();
1460     std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(),
1461                 static_cast<U>(0));
1462 
1463     // Placeholders for estimated_mean and estimated_variance, which are
1464     // used for inference and thus not needed here for gradient computation.
1465     Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
1466     MklDnnShape mkl_shape_p;
1467     mkl_shape_p.SetMklTensor(false);
1468     AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
1469                               mkl_shape_p, native_format);
1470     std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
1471                 static_cast<U>(0));
1472     AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
1473                               mkl_shape_p, native_format);
1474     std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
1475                 static_cast<U>(0));
1476   }
1477 
1478   memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
1479 };
1480 
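     // Each kernel below is registered twice: under its "_Mkl..." name with the
     // MKL layout-dependent label (inputs may carry oneDNN blocked layouts plus
     // MklDnnShape metadata) and under its "_MklNative..." name with the
     // name-change label, where tensors stay in plain TF layout (the trailing
     // `true` template argument).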
1481 #define REGISTER_MKL_FUSED_BATCHNORM_CPU(T)                    \
1482   REGISTER_KERNEL_BUILDER(                                     \
1483       Name("_MklFusedBatchNorm")                               \
1484           .Device(DEVICE_CPU)                                  \
1485           .TypeConstraint<T>("T")                              \
1486           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1487       MklFusedBatchNormOp<CPUDevice, T, T, false, false>);     \
1488   REGISTER_KERNEL_BUILDER(                                     \
1489       Name("_MklNativeFusedBatchNorm")                         \
1490           .Device(DEVICE_CPU)                                  \
1491           .TypeConstraint<T>("T")                              \
1492           .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
1493       MklFusedBatchNormOp<CPUDevice, T, T, false, false, true>);
1494 
1495 TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU);
1496 TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU);
1497 #undef REGISTER_MKL_FUSED_BATCHNORM_CPU
1498 
1499 #define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U)              \
1500   REGISTER_KERNEL_BUILDER(                                     \
1501       Name("_MklFusedBatchNormV2")                             \
1502           .Device(DEVICE_CPU)                                  \
1503           .TypeConstraint<T>("T")                              \
1504           .TypeConstraint<U>("U")                              \
1505           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1506       MklFusedBatchNormOp<CPUDevice, T, U, false, false>);     \
1507   REGISTER_KERNEL_BUILDER(                                     \
1508       Name("_MklNativeFusedBatchNormV2")                       \
1509           .Device(DEVICE_CPU)                                  \
1510           .TypeConstraint<T>("T")                              \
1511           .TypeConstraint<U>("U")                              \
1512           .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
1513       MklFusedBatchNormOp<CPUDevice, T, U, false, false, true>);
1514 
1515 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float);
1516 REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float);
1517 #undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU
1518 
1519 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T)               \
1520   REGISTER_KERNEL_BUILDER(                                     \
1521       Name("_MklFusedBatchNormGrad")                           \
1522           .Device(DEVICE_CPU)                                  \
1523           .TypeConstraint<T>("T")                              \
1524           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1525       MklFusedBatchNormGradOp<CPUDevice, T, T, false>);        \
1526   REGISTER_KERNEL_BUILDER(                                     \
1527       Name("_MklNativeFusedBatchNormGrad")                     \
1528           .Device(DEVICE_CPU)                                  \
1529           .TypeConstraint<T>("T")                              \
1530           .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
1531       MklFusedBatchNormGradOp<CPUDevice, T, T, false, true>);
1532 
1533 TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
1534 TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU);
1535 #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU
1536 
1537 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U)         \
1538   REGISTER_KERNEL_BUILDER(                                     \
1539       Name("_MklFusedBatchNormGradV2")                         \
1540           .Device(DEVICE_CPU)                                  \
1541           .TypeConstraint<T>("T")                              \
1542           .TypeConstraint<U>("U")                              \
1543           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1544       MklFusedBatchNormGradOp<CPUDevice, T, U, false>);        \
1545   REGISTER_KERNEL_BUILDER(                                     \
1546       Name("_MklNativeFusedBatchNormGradV2")                   \
1547           .Device(DEVICE_CPU)                                  \
1548           .TypeConstraint<T>("T")                              \
1549           .TypeConstraint<U>("U")                              \
1550           .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
1551       MklFusedBatchNormGradOp<CPUDevice, T, U, false, true>);
1552 
1553 REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float);
1554 REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float);
1555 #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU
1556 
1557 // TODO(intel-tf): FusedBatchNormV3 has an additional output that
1558 //       is used to hold intermediate results. This parameter
1559 //       functionality is not implemented on CPU.
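     // The _MklFusedBatchNormEx / _MklNativeFusedBatchNormEx entries reuse the
     // same kernel as V3 with one extra `true` template argument, which enables
     // the Ex op's fused-activation behavior.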
1560 #define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U)               \
1561   REGISTER_KERNEL_BUILDER(                                      \
1562       Name("_MklFusedBatchNormV3")                              \
1563           .Device(DEVICE_CPU)                                   \
1564           .TypeConstraint<T>("T")                               \
1565           .TypeConstraint<U>("U")                               \
1566           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
1567       MklFusedBatchNormOp<CPUDevice, T, U, true, false>);       \
1568   REGISTER_KERNEL_BUILDER(                                      \
1569       Name("_MklFusedBatchNormEx")                              \
1570           .Device(DEVICE_CPU)                                   \
1571           .TypeConstraint<T>("T")                               \
1572           .TypeConstraint<U>("U")                               \
1573           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),  \
1574       MklFusedBatchNormOp<CPUDevice, T, U, true, true>);        \
1575   REGISTER_KERNEL_BUILDER(                                      \
1576       Name("_MklNativeFusedBatchNormV3")                        \
1577           .Device(DEVICE_CPU)                                   \
1578           .TypeConstraint<T>("T")                               \
1579           .TypeConstraint<U>("U")                               \
1580           .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
1581       MklFusedBatchNormOp<CPUDevice, T, U, true, false, true>); \
1582   REGISTER_KERNEL_BUILDER(                                      \
1583       Name("_MklNativeFusedBatchNormEx")                        \
1584           .Device(DEVICE_CPU)                                   \
1585           .TypeConstraint<T>("T")                               \
1586           .TypeConstraint<U>("U")                               \
1587           .Label(mkl_op_registry::kMklNameChangeOpLabel),       \
1588       MklFusedBatchNormOp<CPUDevice, T, U, true, true, true>);
1589 
1590 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float);
1591 REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float);
1592 #undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU
1593 
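     // _FusedBatchNormEx itself is registered as a NoOp on CPU; the real work
     // is expected to be handled by the _Mkl* / _MklNative* kernels registered
     // above once the graph rewrite substitutes them.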
1594 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
1595                             .Device(DEVICE_CPU)
1596                             .TypeConstraint<float>("T")
1597                             .TypeConstraint<float>("U"),
1598                         NoOp);
1599 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
1600                             .Device(DEVICE_CPU)
1601                             .TypeConstraint<bfloat16>("T")
1602                             .TypeConstraint<float>("U"),
1603                         NoOp);
1604 
1605 #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U)         \
1606   REGISTER_KERNEL_BUILDER(                                     \
1607       Name("_MklFusedBatchNormGradV3")                         \
1608           .Device(DEVICE_CPU)                                  \
1609           .TypeConstraint<T>("T")                              \
1610           .TypeConstraint<U>("U")                              \
1611           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
1612       MklFusedBatchNormGradOp<CPUDevice, T, U, true>);         \
1613   REGISTER_KERNEL_BUILDER(                                     \
1614       Name("_MklNativeFusedBatchNormGradV3")                   \
1615           .Device(DEVICE_CPU)                                  \
1616           .TypeConstraint<T>("T")                              \
1617           .TypeConstraint<U>("U")                              \
1618           .Label(mkl_op_registry::kMklNameChangeOpLabel),      \
1619       MklFusedBatchNormGradOp<CPUDevice, T, U, true, true>);
1620 
1621 REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float);
1622 REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float);
1623 #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU
1624 
1625 }  // namespace tensorflow
1626 
1627 #undef GET_FLAG
1628 #undef IS_SET
1629 
1630 #endif  // INTEL_MKL
1631