/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc.
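//
// For reference, the kernels below compute (this restates the formula that is
// also spelled out in the gradient comments further down; the index names are
// illustrative only):
//
//   output[b, r, c, d] = input[b, r, c, d] /
//       (bias + alpha * sum_{k = d - depth_radius}^{d + depth_radius}
//                           input[b, r, c, k]^2)^beta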

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#if TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#endif
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

namespace {

// When the depth is large and beta_ is 0.5 or 1.0, Single-threaded
// LRN is faster than the main band matrix approach used
// below. Benchmarks suggest switching to SingleThreadedLRN when depth > 384.
const int kSingleThreadedLRNDepthCutoff = 384;

// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}
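
// For illustration only (not used by the code): with depth = 5 and
// depth_radius = 1, GetBandMatrix fills `result` with
//
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1
//
// so contracting a row of squared activations with this matrix sums, for each
// channel, the squares in its (2 * depth_radius + 1)-wide window (clipped at
// the edges).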

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchLRN;

template <typename T>
struct LaunchLRN<CPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

#if defined(IS_MOBILE_PLATFORM)
    SingleThreadedLRN(in, batch, rows, cols, depth, output);
#else
    const int nodes = cols * rows;
    if (depth > kSingleThreadedLRNDepthCutoff &&
        (beta_ == T(0.5) || beta_ == T(1))) {
      SingleThreadedLRN(in, batch, rows, cols, depth, output);
      return;
    }

    auto in_shaped = in.shaped<T, 2>({nodes * batch, depth});

    // Contracting the squared input with the band matrix sums, for each
    // element, the squared activations in its (2 * depth_radius + 1)-wide
    // window along the depth dimension.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
#endif
  }

 private:
  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;

  void SingleThreadedLRN(const Tensor& in, const int batch, const int rows,
                         const int cols, const int depth, Tensor* out) {
    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_in(
        in.flat<T>().data(), depth, batch * rows * cols);

    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_out(
        out->flat<T>().data(), depth, batch * rows * cols);

    const int double_depth_radius = depth_radius_ * 2;
    Eigen::Matrix<T, Eigen::Dynamic, 1> padded_square(data_in.rows() +
                                                      double_depth_radius);
    padded_square.setZero();
    for (int r = 0; r < data_in.cols(); ++r) {
      // Do local response normalization for data_in(:, r). First, compute the
      // squares and store them in a buffer for repeated use.
      padded_square.block(depth_radius_, 0, data_out.rows(), 1) =
          data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_;
      // Then, compute the scale and write it to data_out.
      T accumulated_scale(0);
      for (int i = 0; i < double_depth_radius; ++i) {
        accumulated_scale += padded_square(i);
      }
      for (int i = 0; i < data_in.rows(); ++i) {
        accumulated_scale += padded_square(i + double_depth_radius);
        data_out(i, r) = bias_ + accumulated_scale;
        accumulated_scale -= padded_square(i);
      }
    }

    if (beta_ == T(1)) {
      data_out.array() = data_in.array() * data_out.array().inverse();
    } else if (beta_ == T(0.5)) {
      data_out.array() = data_in.array() * data_out.array().rsqrt();
    } else {
      data_out.array() =
          data_in.array() * (data_out.array().log() * -beta_).exp();
    }
  }
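
  // A small symbolic example of the sliding window above (illustration only,
  // writing a for alpha_): with depth_radius_ = 1 and a depth-3 column
  // x = [x0, x1, x2], padded_square holds [0, a*x0^2, a*x1^2, a*x2^2, 0] and
  // the loop writes
  //   data_out(0, r) = bias_ + a*(x0^2 + x1^2)
  //   data_out(1, r) = bias_ + a*(x0^2 + x1^2 + x2^2)
  //   data_out(2, r) = bias_ + a*(x1^2 + x2^2)
  // i.e. bias_ plus alpha_ times the sum of squares over each channel's
  // (clipped) window; the beta_ branches at the end of the method then
  // replace data_out with data_in * data_out^(-beta_).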

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
struct LaunchLRN<GPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
#if GOOGLE_CUDA
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    se::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(se::dnn::DataLayout::kBatchYXDepth);

    se::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_data = StreamExecutorUtil::AsDeviceMemory<T>(in);
    auto output_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
                                          input_data, &output_data)
            .ok();
    OP_REQUIRES(context, status,
                errors::Internal("NormalizeWithDimensions launch failed"));
#elif TENSORFLOW_USE_ROCM
    // For NHWC input/output tensors, convert to NCHW because it's the only
    // supported format in MIOpen for now.

    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    Tensor transformed_input;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       ShapeFromFormat(FORMAT_NCHW, in.shape(), FORMAT_NHWC),
                       &transformed_input));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
                                           in.tensor<T, 4>(),
                                           transformed_input.tensor<T, 4>());

    Tensor transformed_output;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DataTypeToEnum<T>::value,
                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
                     &transformed_output));

    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_data =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());
    auto output_data =
        AsDeviceMemory(transformed_output.template flat<T>().data(),
                       transformed_output.template flat<T>().size());

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
                                          input_data, &output_data)
            .ok();
    OP_REQUIRES(context, status,
                errors::Internal("NormalizeWithDimensions launch failed"));

    // Convert back to NHWC once the MIOpen kernel finishes.
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(transformed_output).template tensor<T, 4>(),
        output->tensor<T, 4>());
#endif
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
class LRNOp : public OpKernel {
 public:
  explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) {
    int64_t depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in = context->input(0);
    OP_REQUIRES(context, in.dims() == 4,
                errors::InvalidArgument("in must be 4-dimensional"));
    OP_REQUIRES(
        context,
        FastBoundsCheck(in.NumElements(), std::numeric_limits<int>::max()),
        errors::InvalidArgument("argument to LRN too large"));
    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    OP_REQUIRES(context,
                (depth + depth_radius_) <= std::numeric_limits<int>::max(),
                errors::InvalidArgument("depth ", depth, " + depth_radius ",
                                        depth_radius_, " exceeds int max."));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRN<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#define REGISTER_CPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if !defined(IS_MOBILE_PLATFORM)

template <typename Device, typename T>
struct LaunchLRNGrad;

template <typename T>
struct LaunchLRNGrad<CPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius),
        bias_(bias),
        alpha_(alpha),
        beta_(beta),
        alpha_beta_2_(T(-2) * alpha * beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
    const int64_t batch = in_grads.dim_size(0);
    const int64_t rows = in_grads.dim_size(1);
    const int64_t cols = in_grads.dim_size(2);
    const int64_t depth = in_grads.dim_size(3);
    const auto nodes = cols * rows;
    auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
    auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
    auto activations = out_image.shaped<T, 2>({nodes * batch, depth});

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();

    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; ++i) {
        for (int64_t j = 0; j < depth; ++j) {
          // Let y be the LRN activations and x be the inputs along the depth
          // dimension. (LRN operates independently along rows, cols, and
          // batch.)
          // We have
          // y_i = x_i / (bias + alpha*(sum_{j=i-depth_radius}^{i+depth_radius}
          //       x_j^2))^beta
          //
          // Let N = (bias + alpha*(sum_{j=i-depth_radius}^{i+depth_radius}
          //          x_j^2))
          // dy_i/dx_i = (N^beta - x_i*beta*N^(beta-1)*2*alpha*x_i)/N^(2*beta)
          // dy_i/dx_j = (       - x_i*beta*N^(beta-1)*2*alpha*x_j)/N^(2*beta)
          //
          // NOTE(keveman): We can compute N by doing (y_i / x_i)^(1/beta).
          // However, this is numerically unstable for small values of x_i. We
          // compute N explicitly here to avoid that.
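          //
          // In the loop below, j plays the role of i above and k the role of
          // j. Using y_j = x_j*N^(-beta), the cross term
          // -2*alpha*beta*x_j*x_k/N^(beta+1) can be written as
          // x_k*(-2*alpha*beta*y_j)/N, which is what activations_ab2 and dyi
          // compute; the extra N^(-beta) term is added only when k == j.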

          T gs = grads_shaped(i, j);
          if (gs == T(0)) continue;

          int64_t depth_begin = std::max<int64_t>(0, j - depth_radius_);
          int64_t depth_end = std::min<int64_t>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64_t k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
          T pre_computed_pow = Eigen::numext::pow(norm, -beta_);
          T activations_ab2 = alpha_beta_2_ * activations(i, j);
          for (int64_t k = depth_begin; k < depth_end; ++k) {
            T dyi = in_shaped(i, k) * activations_ab2 / norm;
            if (k == j) {
              dyi += pre_computed_pow;
            }
            dyi *= gs;
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) += dyi;
          }
        }
      }
    };
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
  T alpha_beta_2_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
struct LaunchLRNGrad<GPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
#if GOOGLE_CUDA
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    const int64_t batch = in_grads.dim_size(0);
    const int64_t rows = in_grads.dim_size(1);
    const int64_t cols = in_grads.dim_size(2);
    const int64_t depth = in_grads.dim_size(3);

    se::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(se::dnn::DataLayout::kBatchYXDepth);

    se::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(in_grads);
    auto input_image_data = StreamExecutorUtil::AsDeviceMemory<T>(in_image);
    auto output_image_data = StreamExecutorUtil::AsDeviceMemory<T>(out_image);
    auto output_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeBackwardWithDimensions(
                normalize_desc, dimensions_desc, input_image_data,
                output_image_data, input_grads_data, &output_grads_data)
            .ok();
    OP_REQUIRES(
        context, status,
        errors::Internal("NormalizeBackwardWithDimensions launch failed"));
#elif TENSORFLOW_USE_ROCM
    // For NHWC input/output tensors, convert to NCHW because it's the only
    // supported format in MIOpen for now.
    const int64 batch = in_grads.dim_size(0);
    const int64 rows = in_grads.dim_size(1);
    const int64 cols = in_grads.dim_size(2);
    const int64 depth = in_grads.dim_size(3);

    Tensor transformed_in_grads;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, in_grads.shape(),
                                                FORMAT_NHWC),
                                &transformed_in_grads));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
                                           in_grads.tensor<T, 4>(),
                                           transformed_in_grads.tensor<T, 4>());

    Tensor transformed_in_image;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, in_image.shape(),
                                                FORMAT_NHWC),
                                &transformed_in_image));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(context->eigen_device<GPUDevice>(),
                                           in_image.tensor<T, 4>(),
                                           transformed_in_image.tensor<T, 4>());

    Tensor transformed_out_image;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                ShapeFromFormat(FORMAT_NCHW, out_image.shape(),
                                                FORMAT_NHWC),
                                &transformed_out_image));
    functor::NHWCToNCHW<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(), out_image.tensor<T, 4>(),
        transformed_out_image.tensor<T, 4>());

    Tensor transformed_output;
    OP_REQUIRES_OK(
        context, context->allocate_temp(
                     DataTypeToEnum<T>::value,
                     ShapeFromFormat(FORMAT_NCHW, output->shape(), FORMAT_NHWC),
                     &transformed_output));

    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_grads_data =
        AsDeviceMemory(transformed_in_grads.template flat<T>().data(),
                       transformed_in_grads.template flat<T>().size());
    auto input_image_data =
        AsDeviceMemory(transformed_in_image.template flat<T>().data(),
                       transformed_in_image.template flat<T>().size());
    auto output_image_data =
        AsDeviceMemory(transformed_out_image.template flat<T>().data(),
                       transformed_out_image.template flat<T>().size());
    auto output_grads_data =
        AsDeviceMemory(transformed_output.template flat<T>().data(),
                       transformed_output.template flat<T>().size());

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    static int64 NormalizeBackwardScratchSize = GetDnnWorkspaceLimit(
        // default value is in bytes despite the name of the environment
        // variable
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
    );

    DnnScratchAllocator scratch_allocator(NormalizeBackwardScratchSize,
                                          context);
    bool status = stream
                      ->ThenNormalizeBackwardWithDimensions(
                          normalize_desc, dimensions_desc, input_image_data,
                          output_image_data, input_grads_data,
                          &output_grads_data, &scratch_allocator)
                      .ok();
    OP_REQUIRES(
        context, status,
        errors::Internal("NormalizeBackwardWithDimensions launch failed"));

    // Convert back to NHWC once the MIOpen kernel finishes.
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(transformed_output).template tensor<T, 4>(),
        output->tensor<T, 4>());
#endif
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T>
class LRNGradOp : public OpKernel {
 public:
  explicit LRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    int64_t depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in_grads = context->input(0);
    const Tensor& in_image = context->input(1);
    const Tensor& out_image = context->input(2);

    OP_REQUIRES(context, in_grads.dims() == 4 && in_image.dims() == 4,
                errors::InvalidArgument("inputs must be 4-dimensional"));
    const int64_t batch = in_grads.dim_size(0);
    const int64_t rows = in_grads.dim_size(1);
    const int64_t cols = in_grads.dim_size(2);
    const int64_t depth = in_grads.dim_size(3);
    OP_REQUIRES(
        context,
        in_image.dim_size(0) == batch && in_image.dim_size(1) == rows &&
            in_image.dim_size(2) == cols && in_image.dim_size(3) == depth &&
            out_image.dim_size(0) == batch && out_image.dim_size(1) == rows &&
            out_image.dim_size(2) == cols && out_image.dim_size(3) == depth &&
            out_image.dims() == 4,
        errors::InvalidArgument(
            "input_grads, input_image, and out_image should have the same "
            "shape"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRNGrad<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in_grads, in_image, out_image, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNGradOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNGradOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // !defined(IS_MOBILE_PLATFORM)

}  // namespace tensorflow