/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc. This OpKernel uses the MKL library to create
// MKL layouts and primitives, and uses MKL-DNN primitives to compute local
// response normalization.

#ifdef INTEL_MKL

#define EIGEN_USE_THREADS

#include <unordered_map>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

using dnnl::lrn_backward;
using dnnl::lrn_forward;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {

namespace {
// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
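// For illustration (hypothetical values), depth = 5 and depth_radius = 1
// produce the band matrix
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1
// so contracting squared inputs with it sums each 3-wide window along depth.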
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}

}  // namespace

template <typename T>
class MklLRNOp : public OpKernel {
 public:
  ~MklLRNOp() {}

  explicit MklLRNOp(OpKernelConstruction* context)
      : OpKernel(context), cpu_engine_(engine::kind::cpu, 0) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    // depth_radius_ is an int, so cast to int (not size_t) to match its type.
    depth_radius_ = static_cast<int>(depth_radius64);

    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      const Tensor& src_tensor = MklGetInput(context, kIdxInput);
      MklDnnShape src_dnn_shape;
      GetMklShape(context, kIdxInput, &src_dnn_shape);

      // MKL-DNN has a notion of kernel_size, not depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;
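      // Rationale (based on the oneDNN definition of across-channel LRN):
      // oneDNN divides alpha by the local size inside its formula, while
      // TensorFlow does not, so alpha is pre-multiplied by kernel_size here
      // to make the two formulas agree.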

      // If the input tensor is not an MKL tensor, or if the last
      // dimension is not the channel dimension, then just use Eigen.
      // MKL only supports normalization over the channel dimension.
      if (!src_dnn_shape.IsMklTensor()) {
        MklDefaultToEigen(context, src_tensor);
        return;
      } else if (!src_dnn_shape.IsMklChannelDim(src_dnn_shape.GetDimension() -
                                                1)) {
        Tensor converted_tensor;
        OP_REQUIRES_OK(context,
                       ConvertMklToTF<T>(context, src_tensor, src_dnn_shape,
                                         &converted_tensor));
        MklDefaultToEigen(context, converted_tensor);
        return;
      }
      // At this point, we can assume that the src is an MKL tensor
      // and we can enable the workspace.
      workspace_enabled_ = true;

      MklDnnData<T> src_dnn_data(&cpu_engine_);
      MklDnnData<T> dst_dnn_data(&cpu_engine_);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);

      TensorShape tf_output_shape = src_tensor.shape();

      memory::desc src_md = src_dnn_shape.GetCurLayout();
      memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();

      // Create memory for the user input.
      // Since TensorFlow always performs normalization over the last
      // dimension, and MKL-DNN performs normalization over the channel
      // dimension, we tell MKL-DNN that the input is in NHWC layout with
      // channel as the last dimension.
      src_dnn_data.SetUsrMem(src_md, &src_tensor);
      src_dnn_data.SetOpMemDesc(input_dims, memory::format_tag::nhwc);
      src_dnn_data.SetUsrMemDataHandle(&src_tensor, fwd_stream_);

      // dst_dnn_data has the same shape as the input.
      dst_dnn_data.SetUsrMem(src_md);
      dst_dnn_data.SetOpMemDesc(input_dims, memory::format_tag::nhwc);

      // Create the LRN primitive descriptor.
      // TensorFlow's normalization semantics are across channels;
      // MKL-DNN also supports normalization within a channel.
      auto lrn_desc = lrn_forward::desc(
          prop_kind::forward, dnnl::algorithm::lrn_across_channels,
          src_dnn_data.GetUsrMemDesc(), kernel_size, new_alpha, beta_, bias_);
      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine_);
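      // Note (oneDNN detail): prop_kind::forward is an alias of
      // forward_training, which is what lets this primitive produce the
      // workspace that MklLRNGradOp consumes.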

      // Allocate the output_dnn_data tensor.
      Tensor* output_tensor = nullptr;
      auto input_format = src_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
                           &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      DCHECK(output_tensor != nullptr);
      dst_dnn_data.SetUsrMemDataHandle(output_tensor, fwd_stream_);

      // Handle the workspace required by MKL-DNN.
      AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
      OP_REQUIRES_OK(context, context->status());

      // Check for an input reorder.
      src_dnn_data.CheckReorderToOpMem(lrn_prim_desc.src_desc(), cpu_engine_);

      std::vector<primitive> net;
      MklDnnThreadPool eigen_tp(context);
      fwd_stream_.reset(CreateStream(&eigen_tp, cpu_engine_));
      net.push_back(lrn_forward(lrn_prim_desc));
      std::vector<std::unordered_map<int, memory>> net_args;
      net_args.push_back({{DNNL_ARG_SRC, src_dnn_data.GetOpMem()},
                          {DNNL_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
                          {DNNL_ARG_DST, dst_dnn_data.GetOpMem()}});
      net.at(0).execute(*fwd_stream_, net_args.at(0));
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
  }

 private:
  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const MklTensorFormat& output_tf_format, Tensor** output_tensor) {
    DCHECK(output_tensor != nullptr);
    memory::desc dst_pd = lrn_fwd_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    // We only handle the case when the inputs and output are in MKL format.
    // Any other case is handled by Eigen.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;
    // Only allocate enough space for the elements we need.
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  // Fallback implementation - taken from lrn_op.cc.
  // TODO(intel-tf) Check if we can use EigenLRNOp directly instead of making a
  // copy.
  void MklDefaultToEigen(OpKernelContext* context, const Tensor& input) {
    const int batch = static_cast<int>(input.dim_size(0));
    const int rows = static_cast<int>(input.dim_size(1));
    const int cols = static_cast<int>(input.dim_size(2));
    const int depth = static_cast<int>(input.dim_size(3));
    const int nodes = cols * rows;

    auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
    // Multiplying the input with the band matrix has the effect of summing
    // over the correct window along the depth dimension.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    Tensor* output_dnn_data = nullptr;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input.shape(), mkl_output_mkl_shape);
    DCHECK(output_dnn_data != nullptr);

    Tensor* workspace_tensor = nullptr;
    MklDnnShape workspace_mkl_shape;
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(0);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    DCHECK(workspace_tensor);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
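    // tmp is the normalizer N = bias + alpha * sum(x^2) over each depth
    // window; beta == 1 and beta == 0.5 get the cheaper inverse()/rsqrt()
    // paths, and everything else falls back to exp(-beta * log(N)).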
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
  }

  void AllocateWorkspaceTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      MklDnnData<uint8>* dnn_data_wksp) {
    DCHECK(dnn_data_wksp != nullptr);
    Tensor* workspace_tensor = nullptr;
    memory::desc workspace_pd = lrn_fwd_prim_desc.workspace_desc();
    size_t workspace_bytes = workspace_pd.get_size();
    MklDnnShape workspace_mkl_shape;
    // The workspace tensor is a uint8 tensor that has
    // exactly the number of bytes necessary.
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(workspace_bytes);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    DCHECK(workspace_tensor != nullptr);
    dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& src_tensor = MklGetInput(context, kIdxInput);
    MklDnnShape src_dnn_shape;
    GetMklShape(context, kIdxInput, &src_dnn_shape);
    if (src_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    } else {
      OP_REQUIRES(context, src_tensor.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    }
  }
  const int kIdxInput = 0, kIdxOutput = 0, kIdxWorkspace = 1;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
  engine cpu_engine_;
  std::shared_ptr<stream> fwd_stream_;
};

template <typename T>
class MklLRNGradOp : public OpKernel {
 public:
  explicit MklLRNGradOp(OpKernelConstruction* context)
      : OpKernel(context), cpu_engine_(engine::kind::cpu, 0) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
    bwd_stream_.reset(new stream(cpu_engine_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      MklDnnData<T> input_grad_dnn_data(&cpu_engine_);
      MklDnnData<T> orig_input_dnn_data(&cpu_engine_);
      MklDnnData<T> orig_output_dnn_data(&cpu_engine_);
      MklDnnData<T> output_dnn_data(&cpu_engine_);

      MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
          orig_output_dnn_shape;
      GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
      GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
      GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

      // We only use oneDNN if all of the necessary inputs are present
      // in oneDNN format and channel is the last dimension.
      bool can_use_mkldnn = workspace_enabled_ &&
                            input_grad_dnn_shape.IsMklTensor() &&
                            orig_input_dnn_shape.IsMklTensor() &&
                            orig_output_dnn_shape.IsMklTensor() &&
                            input_grad_dnn_shape.IsMklChannelDim(
                                input_grad_dnn_shape.GetDimension() - 1) &&
                            orig_input_dnn_shape.IsMklChannelDim(
                                orig_input_dnn_shape.GetDimension() - 1) &&
                            orig_output_dnn_shape.IsMklChannelDim(
                                orig_output_dnn_shape.GetDimension() - 1);

      if (!can_use_mkldnn) {
        // Fall back to Eigen.
        MklDefaultToEigen(context);
        return;
      }
      // At this point, we have the all clear to use MklDnn constructs.
      // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
      const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
      const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);

      // Get input sizes in the NCHW order that MKL-DNN requires.
      // LRN does not have a data_format attribute; by default it uses
      // NHWC format.
      memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
      memory::desc target_diff_dst_md = ConfigureInputGradient(
          input_grad_tensor, input_grad_dnn_shape, &input_grad_dnn_data);

      memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
      memory::dims orig_input_dims =
          orig_input_dnn_shape.GetSizesAsMklDnnDims();
      orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
      orig_input_dnn_data.SetOpMemDesc(orig_input_dims,
                                       memory::format_tag::nhwc);
      orig_input_dnn_data.SetUsrMemDataHandle(&orig_input_tensor, bwd_stream_);

      // output_dnn_data has the same shape as the original input.
      output_dnn_data.SetUsrMem(orig_input_md);
      output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format_tag::nhwc);

      // MKL-DNN has a notion of kernel_size, not depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;

      // Create the LRN backward primitive descriptor. It also requires the
      // LRN forward primitive descriptor.
      auto lrn_fwd_desc = lrn_forward::desc(
          prop_kind::forward, dnnl::algorithm::lrn_across_channels,
          orig_input_md, kernel_size, new_alpha, beta_, bias_);
      auto lrn_fwd_prim_desc =
          lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine_);
      auto lrn_bwd_desc = lrn_backward::desc(
          dnnl::algorithm::lrn_across_channels, original_output_md,
          target_diff_dst_md, kernel_size, new_alpha, beta_, bias_);
      auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
          lrn_bwd_desc, cpu_engine_, lrn_fwd_prim_desc);

      Tensor* output_tensor = nullptr;
      auto orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
                           orig_input_format, &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      DCHECK(output_tensor != nullptr);
      output_dnn_data.SetUsrMemDataHandle(output_tensor, bwd_stream_);

      // Create the LRN backward primitive and add it to the net.
      // At this point, the workspace is enabled, so we don't need to check.
      // Pass the input workspace to the LRN backward primitive.
      const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
      ConfigureWorkspace(workspace_tensor, lrn_fwd_prim_desc.workspace_desc(),
                         &workspace_dnn_data);

      // Check for input reordering on the diff dst input.
      input_grad_dnn_data.CheckReorderToOpMem(lrn_bwd_prim_desc.diff_dst_desc(),
                                              cpu_engine_);

      // Check for input reordering on the original input.
      orig_input_dnn_data.CheckReorderToOpMem(lrn_fwd_prim_desc.src_desc(),
                                              cpu_engine_);

      std::vector<primitive> net;
      std::vector<std::unordered_map<int, memory>> net_args;
      net.push_back(lrn_backward(lrn_bwd_prim_desc));
      // lrn_backward consumes the forward workspace and writes the computed
      // input gradient to DIFF_SRC.
      net_args.push_back({{DNNL_ARG_SRC, orig_input_dnn_data.GetOpMem()},
                          {DNNL_ARG_DIFF_DST, input_grad_dnn_data.GetOpMem()},
                          {DNNL_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
                          {DNNL_ARG_DIFF_SRC, output_dnn_data.GetOpMem()}});
      net.at(0).execute(*bwd_stream_, net_args.at(0));
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const MklTensorFormat& output_tf_format, Tensor** output_tensor) {
    DCHECK(output_tensor != nullptr);
    memory::desc dst_pd = lrn_bkwd_prim_desc.diff_src_desc();
    MklDnnShape output_mkl_shape;

    // We assume that all outputs at this point are MKL tensors.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
                                      const MklDnnShape& input_grad_dnn_shape,
                                      MklDnnData<T>* input_grad_dnn_data) {
    DCHECK(input_grad_dnn_data != nullptr);
    // This shouldn't be necessary at this point, but just in case.
    DCHECK(input_grad_dnn_shape.IsMklTensor() == true);

    memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
    memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
    input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
    input_grad_dnn_data->SetOpMemDesc(orig_input_dims,
                                      memory::format_tag::nhwc);
    return input_grad_md;
  }

  void ConfigureWorkspace(const Tensor& workspace_tensor,
                          memory::desc workspace_pd,
                          MklDnnData<uint8>* workspace_dnn_data) {
    DCHECK(workspace_dnn_data);

    workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
  }

  // Fallback implementation - taken from lrn_op.cc.
  // TODO(intel-tf) Check if we can use EigenLRNOp directly
  // instead of making a copy.
  void MklDefaultToEigen(OpKernelContext* context) {
    Tensor input_gradient_tensor;
    Tensor orig_input_tensor;
    Tensor orig_output_tensor;

    MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
        orig_output_dnn_shape;
    GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

    if (input_grad_dnn_shape.IsMklTensor()) {
      OP_REQUIRES_OK(
          context,
          ConvertMklToTF<T>(context, MklGetInput(context, kIdxGradient),
                            input_grad_dnn_shape, &input_gradient_tensor));
    } else {
      input_gradient_tensor = MklGetInput(context, kIdxGradient);
    }

    if (orig_input_dnn_shape.IsMklTensor()) {
      OP_REQUIRES_OK(context, ConvertMklToTF<T>(
                                  context, MklGetInput(context, kIdxOrigInput),
                                  orig_input_dnn_shape, &orig_input_tensor));
    } else {
      orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    }

    if (orig_output_dnn_shape.IsMklTensor()) {
      OP_REQUIRES_OK(context, ConvertMklToTF<T>(
                                  context, MklGetInput(context, kIdxOrigOutput),
                                  orig_output_dnn_shape, &orig_output_tensor));
    } else {
      orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    }

    const int64 batch = static_cast<int64_t>(input_gradient_tensor.dim_size(0));
    const int64 rows = static_cast<int64_t>(input_gradient_tensor.dim_size(1));
    const int64 cols = static_cast<int64_t>(input_gradient_tensor.dim_size(2));
    const int64 depth = static_cast<int64_t>(input_gradient_tensor.dim_size(3));
    const auto nodes = cols * rows;

    auto grads_shaped =
        input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});

    auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
    auto activations = orig_output_tensor.shaped<T, 2>({nodes * batch, depth});

    Tensor* output_dnn_data;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input_gradient_tensor.shape(),
                              mkl_output_mkl_shape);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();
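    // The loop below applies the LRN gradient directly: with
    //   N(i, j) = bias + alpha * sum_{k in window(j)} in(i, k)^2
    // and y(i, j) = in(i, j) * N(i, j)^(-beta), each output element
    // accumulates -2 * alpha * beta * in(i, k) * y(i, j) / N(i, j), plus
    // N(i, j)^(-beta) when k == j, scaled by the incoming gradient.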
    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        for (int64 j = 0; j < depth; ++j) {
          int64 depth_begin = std::max<int64_t>(0, j - depth_radius_);
          int64 depth_end = std::min<int64_t>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64 k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
          for (int64 k = depth_begin; k < depth_end; ++k) {
            T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
                    activations(i, j) / norm;
            if (k == j) {
              dyi += Eigen::numext::pow(norm, -beta_);
            }
            dyi *= grads_shaped(i, j);
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) += dyi;
          }
        }
      }
    };
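    // The Shard call below distributes the nodes * batch rows across the
    // worker threads; depth * depth is passed as the estimated cost per row.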
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
    const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
    MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
        workspace_dnn_shape;
    GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
    GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
    if (in_grads_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("Input gradient must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, input_gradient_tensor.dims() == 4,
          errors::InvalidArgument("input gradient must be 4-dimensional"));
    }

    if (in_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(context, orig_input_tensor.dims() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    }

    if (out_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("Output image must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, orig_output_tensor.dims() == 4,
          errors::InvalidArgument("Output image must be 4-dimensional"));
    }

    if (workspace_enabled_) {
      if (workspace_dnn_shape.IsMklTensor()) {
        OP_REQUIRES(
            context, workspace_dnn_shape.IsMklTensor() == false,
            errors::InvalidArgument("Workspace should not be MKL Tensor."));
      } else {
        OP_REQUIRES(context, workspace_tensor.dims() == 1,
                    errors::InvalidArgument("Workspace must be 1-dimensional"));
      }
    }
  }

  // Input("input_grads: T")
  // Input("input_image: T")
  // Input("output_image: T")
  // Input("workspace: uint8")
  const int kIdxGradient = 0, kIdxOrigInput = 1, kIdxOrigOutput = 2,
            kIdxWorkspace = 3, kIdxOutput = 0;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
  engine cpu_engine_;
  std::shared_ptr<stream> bwd_stream_;
};

#define REGISTER_MKL_LRN_CPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklLRN")                                          \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklLRNOp<T>);                                            \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklLRNGrad")                                      \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklLRNGradOp<T>);

TF_CALL_float(REGISTER_MKL_LRN_CPU);

}  // namespace tensorflow

#endif  // INTEL_MKL