/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc. This OpKernel uses the oneDNN (MKL-DNN)
// library: it creates MKL layouts and primitives and uses the MKL-DNN
// primitives to compute local response normalization.

#ifdef INTEL_MKL

#define EIGEN_USE_THREADS

#include <unordered_map>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

using dnnl::lrn_backward;
using dnnl::lrn_forward;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {

namespace {
// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}
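//
// For illustration only (not used by the kernel): with depth = 5 and
// depth_radius = 1, the band matrix produced above is
//
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1
//
// so contracting the squared inputs with it sums, for each depth position,
// the squared values in the window of size 2 * depth_radius + 1 around it.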

}  // namespace

template <typename T>
class MklLRNOp : public OpKernel {
 public:
  ~MklLRNOp() {}

  explicit MklLRNOp(OpKernelConstruction* context)
      : OpKernel(context), cpu_engine_(engine::kind::cpu, 0) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<size_t>(depth_radius64);

    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      const Tensor& src_tensor = MklGetInput(context, kIdxInput);
      MklDnnShape src_dnn_shape;
      GetMklShape(context, kIdxInput, &src_dnn_shape);

      // MKL-DNN has a notion of kernel_size and not depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;
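      // Explanatory note on the scaling above (based on the oneDNN
      // documentation for across-channel LRN): oneDNN computes
      //   dst = src / (k + (alpha / local_size) * sum(src^2))^beta,
      // while TensorFlow computes
      //   output = input / (bias + alpha * sum(input^2))^beta,
      // so passing alpha_ * kernel_size as oneDNN's alpha makes the two agree.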

      // If the input tensor is not an MKL tensor, or if the last dimension is
      // not the channel dimension, fall back to Eigen: MKL-DNN only supports
      // normalization over the channel dimension.
      if (!src_dnn_shape.IsMklTensor()) {
        MklDefaultToEigen(context, src_tensor);
        return;
      } else if (!src_dnn_shape.IsMklChannelDim(src_dnn_shape.GetDimension() -
                                                1)) {
        Tensor converted_tensor;
        OP_REQUIRES_OK(context,
                       ConvertMklToTF<T>(context, src_tensor, src_dnn_shape,
                                         &converted_tensor));
        MklDefaultToEigen(context, converted_tensor);
        return;
      }
      // At this point, we can assume that the src is an MKL tensor and we can
      // enable the workspace.
      workspace_enabled_ = true;

      MklDnnData<T> src_dnn_data(&cpu_engine_);
      MklDnnData<T> dst_dnn_data(&cpu_engine_);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);

      TensorShape tf_output_shape = src_tensor.shape();

      memory::desc src_md = src_dnn_shape.GetCurLayout();
      memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();

      // Create memory for user input.
      // Since TensorFlow always performs normalization over the last
      // dimension, and MKL-DNN performs normalization over Channel, we tell
      // MKL-DNN that the input is in NHWC layout with Channel being the last
      // dimension.
      src_dnn_data.SetUsrMem(src_md, &src_tensor);
      src_dnn_data.SetOpMemDesc(input_dims, memory::format_tag::nhwc);
      src_dnn_data.SetUsrMemDataHandle(&src_tensor, fwd_stream_);

      // dst_dnn_data has the same shape as the input.
      dst_dnn_data.SetUsrMem(src_md);
      dst_dnn_data.SetOpMemDesc(input_dims, memory::format_tag::nhwc);

      // Create the LRN forward primitive descriptor. TensorFlow's
      // normalization semantics are across channels; MKL-DNN also supports
      // normalization within a channel.
      auto lrn_desc = lrn_forward::desc(
          prop_kind::forward, dnnl::algorithm::lrn_across_channels,
          src_dnn_data.GetUsrMemDesc(), kernel_size, new_alpha, beta_, bias_);
      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine_);

      // Allocate output_dnn_data tensor.
      Tensor* output_tensor = nullptr;
      auto input_format = src_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
                           &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      DCHECK(output_tensor != nullptr);
      dst_dnn_data.SetUsrMemDataHandle(output_tensor, fwd_stream_);

      // Handle workspace required for MKL-DNN.
      AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
      OP_REQUIRES_OK(context, context->status());

      // Check for input reorder.
      src_dnn_data.CheckReorderToOpMem(lrn_prim_desc.src_desc(), cpu_engine_);

      std::vector<primitive> net;
      MklDnnThreadPool eigen_tp(context);
      fwd_stream_.reset(CreateStream(&eigen_tp, cpu_engine_));
      net.push_back(lrn_forward(lrn_prim_desc));
      std::vector<std::unordered_map<int, memory>> net_args;
      net_args.push_back({{DNNL_ARG_SRC, src_dnn_data.GetOpMem()},
                          {DNNL_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
                          {DNNL_ARG_DST, dst_dnn_data.GetOpMem()}});
      net.at(0).execute(*fwd_stream_, net_args.at(0));
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const MklTensorFormat& output_tf_format, Tensor** output_tensor) {
    DCHECK(output_tensor != nullptr);
    memory::desc dst_pd = lrn_fwd_prim_desc.dst_desc();

    MklDnnShape output_mkl_shape;
    // We only handle the case when the inputs and output are in MKL format.
    // Any other case is handled by Eigen.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;
    // Only allocate enough space for the elements we need.
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  // Fallback implementation - Taken from lrn_op.cc
  // TODO(intel-tf) Check if we can use EigenLRNOp directly instead of making a
  // copy.
  void MklDefaultToEigen(OpKernelContext* context, const Tensor& input) {
    const int batch = static_cast<int>(input.dim_size(0));
    const int rows = static_cast<int>(input.dim_size(1));
    const int cols = static_cast<int>(input.dim_size(2));
    const int depth = static_cast<int>(input.dim_size(3));
    const int nodes = cols * rows;

    auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
    // Multiplying the input with the band matrix has the effect of reducing
    // the correct patch along the depth.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    Tensor* output_dnn_data = nullptr;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input.shape(), mkl_output_mkl_shape);
    DCHECK(output_dnn_data != nullptr);

    Tensor* workspace_tensor = nullptr;
    MklDnnShape workspace_mkl_shape;
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(0);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    DCHECK(workspace_tensor);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
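    // At this point tmp(i, j) = bias + alpha * (sum of squared inputs in the
    // depth window around j), so the LRN output is in_shaped * tmp^(-beta).
    // The beta == 1 and beta == 0.5 branches below are just cheaper ways of
    // computing that power.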
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
  }

  void AllocateWorkspaceTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      MklDnnData<uint8>* dnn_data_wksp) {
    DCHECK(dnn_data_wksp != nullptr);
    Tensor* workspace_tensor = nullptr;
    memory::desc workspace_pd = lrn_fwd_prim_desc.workspace_desc();
    size_t workspace_bytes = workspace_pd.get_size();
    MklDnnShape workspace_mkl_shape;
    // The workspace tensor is a uint8 tensor that has exactly the number of
    // bytes necessary.
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(workspace_bytes);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    DCHECK(workspace_tensor != nullptr);
    dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& src_tensor = MklGetInput(context, kIdxInput);
    MklDnnShape src_dnn_shape;
    GetMklShape(context, kIdxInput, &src_dnn_shape);
    if (src_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    } else {
      OP_REQUIRES(context, src_tensor.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    }
  }
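  // Index mapping for this kernel (inferred from the usage above and from the
  // matching comment in MklLRNGradOp below):
  //   kIdxInput     - Input("input: T")
  //   kIdxOutput    - Output("output: T")
  //   kIdxWorkspace - Output("workspace: uint8")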
  const int kIdxInput = 0, kIdxOutput = 0, kIdxWorkspace = 1;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
  engine cpu_engine_;
  std::shared_ptr<stream> fwd_stream_;
};

template <typename T>
class MklLRNGradOp : public OpKernel {
 public:
  explicit MklLRNGradOp(OpKernelConstruction* context)
      : OpKernel(context), cpu_engine_(engine::kind::cpu, 0) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
    bwd_stream_.reset(new stream(cpu_engine_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      MklDnnData<T> input_grad_dnn_data(&cpu_engine_);
      MklDnnData<T> orig_input_dnn_data(&cpu_engine_);
      MklDnnData<T> orig_output_dnn_data(&cpu_engine_);
      MklDnnData<T> output_dnn_data(&cpu_engine_);

      MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
          orig_output_dnn_shape;
      GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
      GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
      GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

      // We only use oneDNN if all of the necessary inputs are present in
      // oneDNN format and Channel is the last dimension.
      bool can_use_mkldnn = workspace_enabled_ &&
                            input_grad_dnn_shape.IsMklTensor() &&
                            orig_input_dnn_shape.IsMklTensor() &&
                            orig_output_dnn_shape.IsMklTensor() &&
                            input_grad_dnn_shape.IsMklChannelDim(
                                input_grad_dnn_shape.GetDimension() - 1) &&
                            orig_input_dnn_shape.IsMklChannelDim(
                                orig_input_dnn_shape.GetDimension() - 1) &&
                            orig_output_dnn_shape.IsMklChannelDim(
                                orig_output_dnn_shape.GetDimension() - 1);

      if (!can_use_mkldnn) {
        // Fall back to Eigen.
        MklDefaultToEigen(context);
        return;
      }
      // At this point, we have the all-clear to use MklDnn constructs.
      // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
      const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
      const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);

      // Get input sizes in the NCHW format MKL-DNN requires. LRN does not
      // have a data_format attribute; it assumes NHWC by default.
      memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
      memory::desc target_diff_dst_md = ConfigureInputGradient(
          input_grad_tensor, input_grad_dnn_shape, &input_grad_dnn_data);

      memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
      memory::dims orig_input_dims =
          orig_input_dnn_shape.GetSizesAsMklDnnDims();
      orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
      orig_input_dnn_data.SetOpMemDesc(orig_input_dims,
                                       memory::format_tag::nhwc);
      orig_input_dnn_data.SetUsrMemDataHandle(&orig_input_tensor, bwd_stream_);

      // output_dnn_data has the same shape as the original input.
      output_dnn_data.SetUsrMem(orig_input_md);
      output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format_tag::nhwc);

      // MKL-DNN has a notion of kernel_size and not depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;
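      // (See the note in MklLRNOp::Compute above for why alpha is scaled by
      // kernel_size before being handed to oneDNN.)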

      // Create the LRN backward primitive descriptor. It requires the LRN
      // forward primitive descriptor as well.
      auto lrn_fwd_desc = lrn_forward::desc(
          prop_kind::forward, dnnl::algorithm::lrn_across_channels,
          orig_input_md, kernel_size, new_alpha, beta_, bias_);
      auto lrn_fwd_prim_desc =
          lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine_);
      auto lrn_bwd_desc = lrn_backward::desc(
          dnnl::algorithm::lrn_across_channels, original_output_md,
          target_diff_dst_md, kernel_size, new_alpha, beta_, bias_);
      auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
          lrn_bwd_desc, cpu_engine_, lrn_fwd_prim_desc);

      Tensor* output_tensor = nullptr;
      auto orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
                           orig_input_format, &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      DCHECK(output_tensor != nullptr);
      output_dnn_data.SetUsrMemDataHandle(output_tensor, bwd_stream_);

      // Create the LRN backward primitive and add it to the net. At this
      // point the workspace is known to be enabled, so we don't need to
      // check; pass the input workspace to the LRN backward primitive.
      const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
      ConfigureWorkspace(workspace_tensor, lrn_fwd_prim_desc.workspace_desc(),
                         &workspace_dnn_data);

      // Check for input reordering on the diff dst input.
      input_grad_dnn_data.CheckReorderToOpMem(lrn_bwd_prim_desc.diff_dst_desc(),
                                              cpu_engine_);

      // Check for input reordering on the original input.
      orig_input_dnn_data.CheckReorderToOpMem(lrn_fwd_prim_desc.src_desc(),
                                              cpu_engine_);

      std::vector<primitive> net;
      std::vector<std::unordered_map<int, memory>> net_args;
      net.push_back(lrn_backward(lrn_bwd_prim_desc));
      net_args.push_back({{DNNL_ARG_SRC, orig_input_dnn_data.GetOpMem()},
                          {DNNL_ARG_DIFF_DST, input_grad_dnn_data.GetOpMem()},
                          {DNNL_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
                          {DNNL_ARG_DIFF_SRC, output_dnn_data.GetOpMem()}});
      net.at(0).execute(*bwd_stream_, net_args.at(0));
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const MklTensorFormat& output_tf_format, Tensor** output_tensor) {
    DCHECK(output_tensor != nullptr);
    memory::desc dst_pd = lrn_bkwd_prim_desc.diff_src_desc();
    MklDnnShape output_mkl_shape;

    // We assume that all outputs at this point are MKL Tensors.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
                                      const MklDnnShape& input_grad_dnn_shape,
                                      MklDnnData<T>* input_grad_dnn_data) {
    DCHECK(input_grad_dnn_data != nullptr);
    // This shouldn't be necessary at this point, but just in case.
    DCHECK(input_grad_dnn_shape.IsMklTensor() == true);

    memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
    memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
    input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
    input_grad_dnn_data->SetOpMemDesc(orig_input_dims,
                                      memory::format_tag::nhwc);
    return input_grad_md;
  }

  void ConfigureWorkspace(const Tensor& workspace_tensor,
                          memory::desc workspace_pd,
                          MklDnnData<uint8>* workspace_dnn_data) {
    DCHECK(workspace_dnn_data);

    workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
  }

  // Fallback implementation - Taken from lrn_op.cc
  // TODO(intel-tf) Check if we can use EigenLRNOp directly
  // instead of making a copy.
  void MklDefaultToEigen(OpKernelContext* context) {
    Tensor input_gradient_tensor;
    Tensor orig_input_tensor;
    Tensor orig_output_tensor;

    MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
        orig_output_dnn_shape;
    GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

    if (input_grad_dnn_shape.IsMklTensor()) {
      OP_REQUIRES_OK(
          context,
          ConvertMklToTF<T>(context, MklGetInput(context, kIdxGradient),
                            input_grad_dnn_shape, &input_gradient_tensor));
    } else {
      input_gradient_tensor = MklGetInput(context, kIdxGradient);
    }

    if (orig_input_dnn_shape.IsMklTensor()) {
      OP_REQUIRES_OK(context, ConvertMklToTF<T>(
                                  context, MklGetInput(context, kIdxOrigInput),
                                  orig_input_dnn_shape, &orig_input_tensor));
    } else {
      orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    }

    if (orig_output_dnn_shape.IsMklTensor()) {
      OP_REQUIRES_OK(context, ConvertMklToTF<T>(
                                  context, MklGetInput(context, kIdxOrigOutput),
                                  orig_output_dnn_shape, &orig_output_tensor));
    } else {
      orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    }

    const int64 batch =
        static_cast<int64_t>(input_gradient_tensor.dim_size(0));
    const int64 rows = static_cast<int64_t>(input_gradient_tensor.dim_size(1));
    const int64 cols = static_cast<int64_t>(input_gradient_tensor.dim_size(2));
    const int64 depth =
        static_cast<int64_t>(input_gradient_tensor.dim_size(3));
    const auto nodes = cols * rows;

    auto grads_shaped =
        input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});

    auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
    auto activations = orig_output_tensor.shaped<T, 2>({nodes * batch, depth});

    Tensor* output_dnn_data;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input_gradient_tensor.shape(),
                              mkl_output_mkl_shape);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();
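    // Gradient derivation for the shard below: the forward op computes
    //   y_j = x_j * N_j^(-beta),  N_j = bias + alpha * sum_{k in win(j)} x_k^2,
    // so for k in win(j)
    //   dy_j/dx_k = (k == j ? N_j^(-beta) : 0)
    //               - 2 * alpha * beta * x_k * y_j / N_j,
    // and each out(i, k) accumulates grads(i, j) * dy_j/dx_k over all windows
    // j that contain k.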
    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        for (int64 j = 0; j < depth; ++j) {
          int64 depth_begin = std::max<int64_t>(0, j - depth_radius_);
          int64 depth_end = std::min<int64_t>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64 k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
          for (int64 k = depth_begin; k < depth_end; ++k) {
            T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
                    activations(i, j) / norm;
            if (k == j) {
              dyi += Eigen::numext::pow(norm, -beta_);
            }
            dyi *= grads_shaped(i, j);
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
                dyi;
          }
        }
      }
    };
    auto worker_threads =
        *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
    const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
    MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
        workspace_dnn_shape;
    GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
    GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
    if (in_grads_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("Input gradient must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, input_gradient_tensor.dims() == 4,
          errors::InvalidArgument("input gradient must be 4-dimensional"));
    }

    if (in_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(context, orig_input_tensor.dims() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    }

    if (out_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("Output image must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, orig_output_tensor.dims() == 4,
          errors::InvalidArgument("Output image must be 4-dimensional"));
    }

    if (workspace_enabled_) {
      if (workspace_dnn_shape.IsMklTensor()) {
        OP_REQUIRES(
            context, workspace_dnn_shape.IsMklTensor() == false,
            errors::InvalidArgument("Workspace should not be MKL Tensor."));
      } else {
        OP_REQUIRES(context, workspace_tensor.dims() == 1,
                    errors::InvalidArgument("Workspace must be 1-dimensional"));
      }
    }
  }

  // Input("input_grads: T")
  // Input("input_image: T")
  // Input("output_image: T")
  // Input("workspace: uint8")
  const int kIdxGradient = 0, kIdxOrigInput = 1, kIdxOrigOutput = 2,
            kIdxWorkspace = 3, kIdxOutput = 0;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
  engine cpu_engine_;
  std::shared_ptr<stream> bwd_stream_;
};

#define REGISTER_MKL_LRN_CPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklLRN")                                          \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklLRNOp<T>);                                            \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklLRNGrad")                                      \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<T>("T")                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklLRNGradOp<T>);

TF_CALL_float(REGISTER_MKL_LRN_CPU);

}  // namespace tensorflow

#endif  // INTEL_MKL
