1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <atomic>
17 
18 #define EIGEN_USE_THREADS
19 
20 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
21 #define EIGEN_USE_GPU
22 #if GOOGLE_CUDA
23 #include "third_party/gpus/cudnn/cudnn.h"
24 #endif  // GOOGLE_CUDA
25 
26 #include "tensorflow/core/kernels/conv_2d.h"
27 #include "tensorflow/core/platform/stream_executor.h"
28 #include "tensorflow/core/util/stream_executor_util.h"
29 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
30 
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_types.h"
36 #include "tensorflow/core/kernels/fill_functor.h"
37 #include "tensorflow/core/kernels/fused_batch_norm_op.h"
38 #include "tensorflow/core/kernels/redux_functor.h"
39 #include "tensorflow/core/kernels/transpose_functor.h"
40 #include "tensorflow/core/lib/core/blocking_counter.h"
41 #include "tensorflow/core/util/env_var.h"
42 #include "tensorflow/core/util/tensor_format.h"
43 
44 namespace tensorflow {
45 using CPUDevice = Eigen::ThreadPoolDevice;
46 using GPUDevice = Eigen::GpuDevice;
47 
48 namespace functor {
49 
50 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51 using se::DeviceMemory;
52 using se::ScratchAllocator;
53 using se::Stream;
54 using se::port::StatusOr;
55 #endif
56 
57 string ToString(FusedBatchNormActivationMode activation_mode) {
58   switch (activation_mode) {
59     case FusedBatchNormActivationMode::kIdentity:
60       return "Identity";
61     case FusedBatchNormActivationMode::kRelu:
62       return "Relu";
63   }
64 }
65 
66 Status ParseActivationMode(OpKernelConstruction* context,
67                            FusedBatchNormActivationMode* activation_mode) {
68   string activation_mode_str;
69   TF_RETURN_IF_ERROR(context->GetAttr("activation_mode", &activation_mode_str));
70 
71   if (activation_mode_str == "Identity") {
72     *activation_mode = FusedBatchNormActivationMode::kIdentity;
73     return OkStatus();
74   }
75   if (activation_mode_str == "Relu") {
76     *activation_mode = FusedBatchNormActivationMode::kRelu;
77     return OkStatus();
78   }
79   return errors::InvalidArgument("Unsupported activation mode: ",
80                                  activation_mode_str);
81 }
82 
83 // Functor used by FusedBatchNormOp to do the computations.
84 template <typename Device, typename T, typename U, bool is_training>
85 struct FusedBatchNorm;
86 // Functor used by FusedBatchNormGradOp to do the computations when
87 // is_training=True.
88 template <typename Device, typename T, typename U>
89 struct FusedBatchNormGrad;
90 
91 template <typename T, typename U>
92 struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> {
93   void operator()(OpKernelContext* context, const Tensor& x_input,
94                   const Tensor& scale_input, const Tensor& offset_input,
95                   const Tensor& running_mean_input,
96                   const Tensor& running_variance_input,
97                   const Tensor* side_input, U epsilon, U exponential_avg_factor,
98                   FusedBatchNormActivationMode activation_mode,
99                   Tensor* y_output, Tensor* running_mean_output,
100                   Tensor* running_var_output, Tensor* saved_batch_mean_output,
101                   Tensor* saved_batch_var_output, TensorFormat tensor_format,
102                   bool use_reserved_space) {
103     OP_REQUIRES(context, side_input == nullptr,
104                 errors::Internal(
105                     "The CPU implementation of FusedBatchNorm does not support "
106                     "side input."));
107     OP_REQUIRES(context,
108                 activation_mode == FusedBatchNormActivationMode::kIdentity,
109                 errors::Internal("The CPU implementation of FusedBatchNorm "
110                                  "does not support activations."));
111 
112     if (use_reserved_space) {
113       Tensor* dummy_reserve_space = nullptr;
114       OP_REQUIRES_OK(context,
115                      context->allocate_output(5, {}, &dummy_reserve_space));
116       // Initialize the memory, to avoid sanitizer alerts.
117       dummy_reserve_space->flat<U>()(0) = U();
118     }
119 
120     // If input is empty, return NaN mean/variance
121     if (x_input.shape().num_elements() == 0) {
122       functor::SetNanFunctor<CPUDevice, U> f;
123       f(context->eigen_device<CPUDevice>(), running_mean_output->flat<U>());
124       f(context->eigen_device<CPUDevice>(), running_var_output->flat<U>());
125       return;
126     }
127 
128     Tensor transformed_x;
129     Tensor transformed_y;
130     if (tensor_format == FORMAT_NCHW) {
131       const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N');
132       const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H');
133       const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W');
134       const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C');
135       OP_REQUIRES_OK(context, context->allocate_temp(
136                                   DataTypeToEnum<T>::value,
137                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
138                                                   in_rows, in_cols, in_depths),
139                                   &transformed_x));
140       OP_REQUIRES_OK(context, context->allocate_temp(
141                                   DataTypeToEnum<T>::value,
142                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
143                                                   in_rows, in_cols, in_depths),
144                                   &transformed_y));
145       // Perform NCHW to NHWC
146       std::vector<int32> perm = {0, 2, 3, 1};
147       OP_REQUIRES_OK(
148           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
149                                              x_input, perm, &transformed_x));
150     } else {
151       transformed_x = x_input;
152       transformed_y = *y_output;
153     }
154     typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
155     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
156     typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
157     typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>());
158     typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>());
159     typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
160     typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>());
161     typename TTypes<U>::Vec new_variance(running_var_output->vec<U>());
162     typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>());
163     typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>());
164 
165     const CPUDevice& d = context->eigen_device<CPUDevice>();
166 
167     const int depth = x.dimension(3);
168     const int size = x.size();
169     const int rest_size = size / depth;
170     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
171 
172     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
173     one_by_depth.set(1, depth);
174     Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
175     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
176     bcast_spec.set(0, rest_size);
177 
178     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
179     const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
180     U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
181     // This adjustment is for Bessel's correction
182     U rest_size_adjust =
183         static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
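    // Illustration (hypothetical shapes): for an NHWC input of shape
    // [2, 3, 4, depth], rest_size = 2 * 3 * 4 = 24 and rest_size_adjust is
    // 24 / 23; multiplying the (biased) batch variance by it below yields the
    // unbiased estimate written into the running variance.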
184 
185     Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth);
186     Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth);
187 
188     batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
189     auto x_centered = x_rest_by_depth -
190                       batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
191 
192     batch_variance.device(d) =
193         x_centered.square().sum(reduce_dims) * rest_size_inv;
194     auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale)
195                               .eval()
196                               .reshape(one_by_depth)
197                               .broadcast(bcast_spec);
198     auto x_scaled = x_centered * scaling_factor;
199     auto x_shifted =
200         (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
201             .template cast<T>();
202 
203     y.reshape(rest_by_depth).device(d) = x_shifted;
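    // Update the running statistics as an exponential moving average:
    //   new = (1 - exponential_avg_factor) * old + exponential_avg_factor * batch
    // where the batch variance is Bessel-corrected by rest_size_adjust. A
    // factor of 1.0 replaces the running values with the batch statistics.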
204     if (exponential_avg_factor == U(1.0)) {
205       saved_batch_var.device(d) = batch_variance;
206       saved_batch_mean.device(d) = batch_mean;
207       new_variance.device(d) = batch_variance * rest_size_adjust;
208       new_mean.device(d) = batch_mean;
209     } else {
210       U one_minus_factor = U(1) - exponential_avg_factor;
211       saved_batch_var.device(d) = batch_variance;
212       saved_batch_mean.device(d) = batch_mean;
213       new_variance.device(d) =
214           one_minus_factor * old_variance +
215           (exponential_avg_factor * rest_size_adjust) * batch_variance;
216       new_mean.device(d) =
217           one_minus_factor * old_mean + exponential_avg_factor * batch_mean;
218     }
219 
220     if (tensor_format == FORMAT_NCHW) {
221       // Perform NHWC to NCHW
222       const std::vector<int32> perm = {0, 3, 1, 2};
223       const Status s = ::tensorflow::DoTranspose(
224           context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
225       if (!s.ok()) {
226         context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
227       }
228     }
229   }
230 };
231 
232 template <typename T, typename U>
233 struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
234   void operator()(OpKernelContext* context, const Tensor& x_input,
235                   const Tensor& scale_input, const Tensor& offset_input,
236                   const Tensor& estimated_mean_input,
237                   const Tensor& estimated_variance_input,
238                   const Tensor* side_input, U epsilon, U exponential_avg_factor,
239                   FusedBatchNormActivationMode activation_mode,
240                   Tensor* y_output, Tensor* batch_mean_output,
241                   Tensor* batch_var_output, Tensor* saved_mean_output,
242                   Tensor* saved_var_output, TensorFormat tensor_format,
243                   bool use_reserved_space) {
244     OP_REQUIRES(context, side_input == nullptr,
245                 errors::Internal(
246                     "The CPU implementation of FusedBatchNorm does not support "
247                     "side input."));
248     OP_REQUIRES(context,
249                 activation_mode == FusedBatchNormActivationMode::kIdentity,
250                 errors::Internal("The CPU implementation of FusedBatchNorm "
251                                  "does not support activations."));
252 
253     if (use_reserved_space) {
254       Tensor* dummy_reserve_space = nullptr;
255       OP_REQUIRES_OK(context,
256                      context->allocate_output(5, {}, &dummy_reserve_space));
257       // Initialize the memory, to avoid sanitizer alerts.
258       dummy_reserve_space->flat<U>()(0) = U();
259     }
260 
261     // If input is empty, return NaN mean/variance
262     if (x_input.shape().num_elements() == 0) {
263       functor::SetNanFunctor<CPUDevice, U> f;
264       f(context->eigen_device<CPUDevice>(), batch_mean_output->flat<U>());
265       f(context->eigen_device<CPUDevice>(), batch_var_output->flat<U>());
266       return;
267     }
268 
269     Tensor transformed_x;
270     Tensor transformed_y;
271     if (tensor_format == FORMAT_NCHW) {
272       const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N');
273       const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H');
274       const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W');
275       const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C');
276       OP_REQUIRES_OK(context, context->allocate_temp(
277                                   DataTypeToEnum<T>::value,
278                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
279                                                   in_rows, in_cols, in_depths),
280                                   &transformed_x));
281       OP_REQUIRES_OK(context, context->allocate_temp(
282                                   DataTypeToEnum<T>::value,
283                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
284                                                   in_rows, in_cols, in_depths),
285                                   &transformed_y));
286       // Perform NCHW to NHWC
287       std::vector<int32> perm = {0, 2, 3, 1};
288       OP_REQUIRES_OK(
289           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
290                                              x_input, perm, &transformed_x));
291     } else {
292       transformed_x = x_input;
293       transformed_y = *y_output;
294     }
295     typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
296     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
297     typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
298     typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
299     typename TTypes<U>::ConstVec estimated_variance(
300         estimated_variance_input.vec<U>());
301     typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
302     typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
303     typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>());
304 
305     const CPUDevice& d = context->eigen_device<CPUDevice>();
306 
307     const int depth = x.dimension(3);
308     OP_REQUIRES(
309         context, depth != 0,
310         errors::Internal("The 4th element in the input shape cannot be 0."));
311     const int size = x.size();
312     const int rest_size = size / depth;
313     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
314     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
315     one_by_depth.set(1, depth);
316     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
317     bcast_spec.set(0, rest_size);
318 
319     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
320     auto x_centered =
321         x_rest_by_depth -
322         estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
323     auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
324                               .eval()
325                               .reshape(one_by_depth)
326                               .broadcast(bcast_spec);
327     auto x_scaled = x_centered * scaling_factor;
328     auto x_shifted =
329         (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
330             .template cast<T>();
331 
332     y.reshape(rest_by_depth).device(d) = x_shifted;
333     batch_mean.device(d) = estimated_mean;
334     batch_variance.device(d) = estimated_variance;
335 
336     if (tensor_format == FORMAT_NCHW) {
337       // Perform NHWC to NCHW
338       const std::vector<int32> perm = {0, 3, 1, 2};
339       const Status s = ::tensorflow::DoTranspose(
340           context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
341       if (!s.ok()) {
342         context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
343       }
344     }
345   }
346 };
347 
348 template <typename T, typename U>
349 struct FusedBatchNormGrad<CPUDevice, T, U> {
350   void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
351                   const Tensor& x_input, const Tensor& scale_input,
352                   const Tensor* offset_input, const Tensor& mean_input,
353                   const Tensor& variance_input, const Tensor* y_input,
354                   U epsilon, FusedBatchNormActivationMode activation_mode,
355                   Tensor* x_backprop_output, Tensor* scale_backprop_output,
356                   Tensor* offset_backprop_output,
357                   Tensor* side_input_backprop_output, bool use_reserved_space,
358                   TensorFormat tensor_format) {
359     OP_REQUIRES(context,
360                 y_input == nullptr &&
361                     activation_mode == FusedBatchNormActivationMode::kIdentity,
362                 errors::Internal(
363                     "The CPU implementation of FusedBatchNormGrad does not "
364                     "support activations."));
365     OP_REQUIRES(context, side_input_backprop_output == nullptr,
366                 errors::Internal("The CPU implementation of FusedBatchNormGrad "
367                                  "does not support side input."));
368 
369     Tensor transformed_y_backprop_input;
370     Tensor transformed_x_input;
371     Tensor transformed_x_backprop_output;
372     if (tensor_format == FORMAT_NCHW) {
373       const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N');
374       const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H');
375       const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W');
376       const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C');
377       OP_REQUIRES_OK(context, context->allocate_temp(
378                                   DataTypeToEnum<T>::value,
379                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
380                                                   in_rows, in_cols, in_depths),
381                                   &transformed_y_backprop_input));
382       OP_REQUIRES_OK(context, context->allocate_temp(
383                                   DataTypeToEnum<T>::value,
384                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
385                                                   in_rows, in_cols, in_depths),
386                                   &transformed_x_input));
387       OP_REQUIRES_OK(context, context->allocate_temp(
388                                   DataTypeToEnum<T>::value,
389                                   ShapeFromFormat(FORMAT_NHWC, in_batch,
390                                                   in_rows, in_cols, in_depths),
391                                   &transformed_x_backprop_output));
392       // Perform NCHW to NHWC
393       std::vector<int32> perm = {0, 2, 3, 1};
394       OP_REQUIRES_OK(
395           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
396                                              y_backprop_input, perm,
397                                              &transformed_y_backprop_input));
398       OP_REQUIRES_OK(context, ::tensorflow::DoTranspose(
399                                   context->eigen_device<CPUDevice>(), x_input,
400                                   perm, &transformed_x_input));
401     } else {
402       transformed_y_backprop_input = y_backprop_input;
403       transformed_x_input = x_input;
404       transformed_x_backprop_output = *x_backprop_output;
405     }
406     typename TTypes<T, 4>::Tensor y_backprop(
407         transformed_y_backprop_input.tensor<T, 4>());
408     typename TTypes<T, 4>::Tensor x(transformed_x_input.tensor<T, 4>());
409     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
410     typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
411     typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
412     typename TTypes<T, 4>::Tensor x_backprop(
413         transformed_x_backprop_output.tensor<T, 4>());
414     typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
415 
416     // Note: the following formulas are used to compute the gradients for
417     // back propagation.
418     // x_backprop = scale * rsqrt(variance + epsilon) *
419     //              [y_backprop - mean(y_backprop) - (x - mean(x)) *
420     //              mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
421     // scale_backprop = sum(y_backprop *
422     //                  (x - mean(x)) * rsqrt(variance + epsilon))
423     // offset_backprop = sum(y_backprop)
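    //
    // In the code below, coef0 = rsqrt(variance + epsilon),
    // coef1 = scale * coef0 (broadcast over rows), and
    // coef2 = coef0^2 * mean(y_backprop * (x - mean(x))), so that
    //   x_backprop = coef1 * (y_backprop_centered - x_centered * coef2).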
424 
425     const CPUDevice& d = context->eigen_device<CPUDevice>();
426     const int depth = x.dimension(3);
427     const int size = x.size();
428     const int rest_size = size / depth;
429     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
430     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
431     one_by_depth.set(1, depth);
432     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
433     bcast_spec.set(0, rest_size);
434 
435     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
436     U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
437 
438     // Eigen is notoriously bad at reducing the outer dimension, so we
439     // materialize all temporary tensors that require reduction and then use
440     // the Eigen redux functor, which is optimized for this particular task.
441     //
442     // All reductions are of this type: [rest_size, depth] -> [depth].
443     using ScalarSum = Eigen::internal::scalar_sum_op<U>;
444     const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
445     const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
446 
447     auto scratch_dtype = DataTypeToEnum<U>::value;
448 
449     // Allocate a temporary workspace of [depth] shape.
450     Tensor scratch_one_by_depth;
451     OP_REQUIRES_OK(context, context->allocate_temp(scratch_dtype, {depth},
452                                                    &scratch_one_by_depth));
453 
454     // Maybe allocate a temporary workspace of [rest_size, depth] shape.
455     Tensor scratch_rest_by_depth;
456     if (std::is_same<T, U>::value) {
457       OP_REQUIRES(context,
458                   scratch_rest_by_depth.CopyFrom(transformed_x_backprop_output,
459                                                  {rest_size, depth}),
460                   errors::Internal("Failed to copy a tensor"));
461     } else {
462       OP_REQUIRES_OK(context,
463                      context->allocate_temp(scratch_dtype, {rest_size, depth},
464                                             &scratch_rest_by_depth));
465     }
466 
467     typename TTypes<U, 2>::Tensor scratch_tensor(
468         scratch_rest_by_depth.tensor<U, 2>());
469     typename TTypes<U>::Vec scratch_vector(scratch_one_by_depth.vec<U>());
470 
471     auto x_mean_rest_by_depth =
472         mean.reshape(one_by_depth).broadcast(bcast_spec);
473     auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
474     auto coef0_one_by_depth =
475         (variance.reshape(one_by_depth) + epsilon).rsqrt();
476     auto coef0_rest_by_depth = coef0_one_by_depth.broadcast(bcast_spec);
477     auto x_scaled = x_centered * coef0_rest_by_depth;
478 
479     auto y_backprop_rest_by_depth =
480         y_backprop.reshape(rest_by_depth).template cast<U>();
481 
482     // Compute `scale_backprop_output`:
483     //   scale_backprop =
484     //     (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims)
485     scratch_tensor.device(d) = y_backprop_rest_by_depth * x_scaled;
486     redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, scale_backprop_output);
487 
488     // Compute 'offset_backprop_output':
489     //   offset_backprop =
490     //     y_backprop_rest_by_depth.sum(reduce_dims)
491     redux_sum_t(d, rest_by_depth, transformed_y_backprop_input,
492                 offset_backprop_output);
493     auto y_backprop_sum = offset_backprop;
494 
495     auto y_backprop_sum_one_by_depth = y_backprop_sum.reshape(one_by_depth);
496     auto y_backprop_mean_one_by_depth =
497         y_backprop_sum_one_by_depth * rest_size_inv;
498     auto y_backprop_mean_rest_by_depth =
499         y_backprop_mean_one_by_depth.broadcast(bcast_spec);
500     auto y_backprop_centered =
501         y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
502 
503     // Compute expression:
504     //   y_backprop_centered_mean =
505     //     (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
506     scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
507     redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
508     auto y_backprop_centered_mean =
509         scratch_vector.reshape(one_by_depth) / static_cast<U>(rest_size);
510 
511     auto coef1 = (scale.reshape(one_by_depth) * coef0_one_by_depth)
512                      .broadcast(bcast_spec);
513     auto coef2 = (coef0_one_by_depth.square() * y_backprop_centered_mean)
514                      .broadcast(bcast_spec);
515 
516     x_backprop.reshape(rest_by_depth).device(d) =
517         (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
518 
519     if (tensor_format == FORMAT_NCHW) {
520       // Perform NHWC to NCHW
521       std::vector<int32> perm = {0, 3, 1, 2};
522       OP_REQUIRES_OK(
523           context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
524                                              transformed_x_backprop_output,
525                                              perm, x_backprop_output));
526     }
527   }
528 };
529 
530 template <typename T, typename U>
531 struct FusedBatchNormFreezeGrad<CPUDevice, T, U> {
532   void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
533                   const Tensor& x_input, const Tensor& scale_input,
534                   const Tensor& pop_mean_input,
535                   const Tensor& pop_variance_input, U epsilon,
536                   Tensor* x_backprop_output, Tensor* scale_backprop_output,
537                   Tensor* offset_backprop_output) {
538     typename TTypes<T, 4>::ConstTensor y_backprop(
539         y_backprop_input.tensor<T, 4>());
540     typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
541     typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
542     typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
543     typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
544     typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
545     typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
546 
547     const int depth = pop_mean.dimension(0);
548     const int rest_size = input.size() / depth;
549 
550     const CPUDevice& d = context->eigen_device<CPUDevice>();
551 
552     // Allocate two temporary workspaces of [depth] shape.
553     Tensor scratch1_vec, scratch2_vec;
554     OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
555                                                    {depth}, &scratch1_vec));
556     OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
557                                                    {depth}, &scratch2_vec));
558 
559     // Maybe allocate a temporary workspace of [rest_size, depth] shape.
560     Tensor scratch3_tensor;
561     if (std::is_same<T, U>::value) {
562       OP_REQUIRES(
563           context,
564           scratch3_tensor.CopyFrom(*x_backprop_output, {rest_size, depth}),
565           errors::Internal("Failed to copy a tensor"));
566     } else {
567       OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
568                                                      {rest_size, depth},
569                                                      &scratch3_tensor));
570     }
571 
572     typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
573     typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
574     typename TTypes<U, 2>::Tensor scratch3(scratch3_tensor.tensor<U, 2>());
575 
576     Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
577     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
578     one_by_depth.set(1, depth);
579     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> rest_by_one;
580     rest_by_one.set(0, rest_size);
581 
582     // Sum reduction along the 0th dimension using custom CPU functor.
583     using ScalarSum = Eigen::internal::scalar_sum_op<U>;
584     const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
585     const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
586 
587     // offset_backprop  = sum(y_backprop)
588     // scale_backprop = sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var + epsilon)
589     // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
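    //
    // Below, scratch1 = rsqrt(pop_var + epsilon) and
    // scratch2 = sum(y_backprop * (x - pop_mean)), so that
    //   scale_backprop = scratch2 * scratch1 and
    //   x_backprop = y_backprop * (scale * scratch1).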
590 
591     // NOTE: DEFAULT DEVICE comment is added to expression assignments that
592     // we don't want to be executed in a thread pool.
593 
594     auto y_backprop_rest_by_depth =
595         y_backprop.reshape(rest_by_depth).template cast<U>();
596     auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
597 
598     // offset_backprop  = sum(y_backprop)
599     redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);
600 
601     // scratch1 = rsqrt(pop_var + epsilon)
602     scratch1 = (pop_var + pop_var.constant(epsilon)).rsqrt();  // DEFAULT DEVICE
603 
604     // scratch2 = sum(y_backprop * (x - mean))
605     scratch3.device(d) =
606         y_backprop_rest_by_depth *
607         (input_rest_by_depth -
608          pop_mean.reshape(one_by_depth).broadcast(rest_by_one));
609     redux_sum_u(d, rest_by_depth, scratch3_tensor, &scratch2_vec);
610 
611     x_backprop.reshape(rest_by_depth).device(d) =
612         (y_backprop_rest_by_depth *
613          ((scratch1.reshape(one_by_depth) * scale.reshape(one_by_depth))
614               .broadcast(rest_by_one)))
615             .template cast<T>();
616     scale_backprop = scratch2 * scratch1;  // DEFAULT DEVICE
617   }
618 };
619 
620 #if !GOOGLE_CUDA
621 namespace {
622 // See implementation under GOOGLE_CUDA #ifdef below.
623 // This is a CUDA-specific feature; do not enable it for non-CUDA builds.
624 bool BatchnormSpatialPersistentEnabled() { return false; }
625 }  // namespace
626 #endif
627 
628 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
629 
630 namespace {
631 
632 se::dnn::ActivationMode AsDnnActivationMode(
633     const FusedBatchNormActivationMode activation_mode) {
634   switch (activation_mode) {
635     case FusedBatchNormActivationMode::kIdentity:
636       return se::dnn::ActivationMode::kNone;
637     case FusedBatchNormActivationMode::kRelu:
638       return se::dnn::ActivationMode::kRelu;
639   }
640 }
641 
642 #if GOOGLE_CUDA
643 // NOTE(ezhulenev): See the `BatchnormSpatialPersistentEnabled` documentation
644 // in `cuda_dnn.cc` for details.
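// The result is computed once and cached. Usage sketch (assumption: the usual
// "1"/"true" values are accepted by ReadBoolFromEnvVar): launching the process
// with TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1 in the environment enables
// the persistent kernels, and only for builds with CUDNN_VERSION >= 7402.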
645 bool BatchnormSpatialPersistentEnabled() {
646 #if CUDNN_VERSION >= 7402
647   static bool is_enabled = [] {
648     bool is_enabled = false;
649     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
650         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
651         /*default_val=*/false, &is_enabled));
652     return is_enabled;
653   }();
654   return is_enabled;
655 #else
656   return false;
657 #endif
658 }
659 #endif
660 
661 }  // namespace
662 
663 template <typename U, typename T>
664 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
665   return DeviceMemory<U>::MakeFromByteSize(
666       tensor->template flat<T>().data(),
667       tensor->template flat<T>().size() * sizeof(T));
668 }
669 
670 // A helper to allocate temporary scratch memory for Cudnn BatchNormEx ops. It
671 // takes ownership of the underlying memory. The expectation is that the
672 // memory stays alive for the duration of the Cudnn BatchNormEx call itself.
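// Each AllocateBytes() call rounds the request up to a whole number of T
// elements, allocates a temporary tensor via OpKernelContext::allocate_temp,
// and keeps a reference in allocated_tensors_ so the memory outlives the
// cuDNN call. Usage sketch (mirroring the call sites further below): create
// the allocator on the heap and pass allocator.get() as the
// workspace_allocator argument of ThenBatchNormalizationForward/Backward.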
673 template <typename T>
674 class CudnnBatchNormAllocatorInTemp : public ScratchAllocator {
675  public:
676   ~CudnnBatchNormAllocatorInTemp() override = default;
677 
678   explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context)
679       : context_(context) {}
680 
681   int64_t GetMemoryLimitInBytes() override {
682     return std::numeric_limits<int64_t>::max();
683   }
684 
685   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
686     Tensor temporary_memory;
687     const DataType tf_data_type = DataTypeToEnum<T>::v();
688     int64_t allocate_count =
689         Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
690     Status allocation_status(context_->allocate_temp(
691         tf_data_type, TensorShape({allocate_count}), &temporary_memory));
692     if (!allocation_status.ok()) {
693       return allocation_status;
694     }
695     // Hold a reference to the allocated tensors until the allocator is
696     // destroyed.
697     allocated_tensors_.push_back(temporary_memory);
698     total_byte_size_ += byte_size;
699     return DeviceMemory<uint8>::MakeFromByteSize(
700         temporary_memory.template flat<T>().data(),
701         temporary_memory.template flat<T>().size() * sizeof(T));
702   }
703 
704   int64_t TotalByteSize() const { return total_byte_size_; }
705 
706   Tensor get_allocated_tensor(int index) const {
707     return allocated_tensors_[index];
708   }
709 
710  private:
711   int64_t total_byte_size_ = 0;
712   OpKernelContext* context_;  // not owned
713   std::vector<Tensor> allocated_tensors_;
714 };
715 
716 // A helper to allocate memory for Cudnn BatchNormEx as a kernel output. It is
717 // used by the forward-pass kernel to feed the output to the backward pass.
718 // The memory is expected to stay alive until the backward pass has
719 // finished.
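// If cuDNN never requests memory (AllocateBytes() is not called), the
// destructor still allocates an empty tensor for output_index_, so the kernel
// always produces its reserve-space output.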
720 template <typename T>
721 class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
722  public:
723   ~CudnnBatchNormAllocatorInOutput() override {
724     if (!output_allocated) {
725       Tensor* dummy_reserve_space = nullptr;
726       OP_REQUIRES_OK(context_, context_->allocate_output(output_index_, {},
727                                                          &dummy_reserve_space));
728     }
729   }
730 
731   CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index)
732       : context_(context), output_index_(output_index) {}
733 
734   int64_t GetMemoryLimitInBytes() override {
735     return std::numeric_limits<int64_t>::max();
736   }
737 
738   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
739     output_allocated = true;
740     DCHECK(total_byte_size_ == 0)
741         << "Reserve space allocator can only be called once";
742     int64_t allocate_count =
743         Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
744 
745     Tensor* temporary_memory = nullptr;
746     Status allocation_status(context_->allocate_output(
747         output_index_, TensorShape({allocate_count}), &temporary_memory));
748     if (!allocation_status.ok()) {
749       return allocation_status;
750     }
751     total_byte_size_ += byte_size;
752     auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
753         temporary_memory->template flat<T>().data(),
754         temporary_memory->template flat<T>().size() * sizeof(T));
755     return StatusOr<DeviceMemory<uint8>>(memory_uint8);
756   }
757 
758   int64_t TotalByteSize() { return total_byte_size_; }
759 
760  private:
761   int64_t total_byte_size_ = 0;
762   OpKernelContext* context_;  // not owned
763   int output_index_;
764   bool output_allocated = false;
765 };
766 
767 template <typename T, typename U, bool is_training>
768 struct FusedBatchNorm<GPUDevice, T, U, is_training> {
769   void operator()(OpKernelContext* context, const Tensor& x,
770                   const Tensor& scale, const Tensor& offset,
771                   const Tensor& estimated_mean,
772                   const Tensor& estimated_variance, const Tensor* side_input,
773                   U epsilon, U exponential_avg_factor,
774                   FusedBatchNormActivationMode activation_mode, Tensor* y,
775                   Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
776                   Tensor* saved_inv_var, TensorFormat tensor_format,
777                   bool use_reserved_space) {
778     auto* stream = context->op_device_context()->stream();
779     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
780 
781     const int64_t batch_size = GetTensorDim(x, tensor_format, 'N');
782     const int64_t channels = GetTensorDim(x, tensor_format, 'C');
783     const int64_t height = GetTensorDim(x, tensor_format, 'H');
784     const int64_t width = GetTensorDim(x, tensor_format, 'W');
785 
786     // If use_reserved_space is true, the op produces the reserve_space_3
787     // output (only in the FusedBatchNormV3 op).
788 
789 #if GOOGLE_CUDA
790     // Check if cuDNN batch normalization has a fast NHWC implementation:
791     //   (1) In inference mode it's always fast.
792     //   (2) Tensorflow enabled batchnorm spatial persistence and we are
793     //       called from FusedBatchNormV3,
794     //       i.e. use_reserved_space is true.
795     const bool fast_nhwc_batch_norm =
796         !is_training ||
797         (BatchnormSpatialPersistentEnabled() &&
798          DataTypeToEnum<T>::value == DT_HALF && use_reserved_space);
799 #else
800     // fast NHWC implementation is a CUDA only feature
801     const bool fast_nhwc_batch_norm = false;
802 #endif
803 
804     // If the input tensor is in NHWC format and we have a fast cuDNN
805     // implementation, there is no need to do a data format conversion.
806     TensorFormat compute_format =
807         fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
808                                                              : FORMAT_NCHW;
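    // When compute_format differs from tensor_format (an NHWC input without
    // the fast cuDNN path), x is transposed to NCHW below and the result is
    // transposed back to NHWC after the cuDNN call.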
809 
810     VLOG(2) << "FusedBatchNorm:"
811             << " batch_size: " << batch_size << " channels: " << channels
812             << " height: " << height << " width: " << width
813             << " x shape: " << x.shape().DebugString()
814             << " scale shape: " << scale.shape().DebugString()
815             << " offset shape: " << offset.shape().DebugString()
816             << " activation mode: " << ToString(activation_mode)
817             << " tensor format: " << ToString(tensor_format)
818             << " compute format: " << ToString(compute_format);
819 
820     auto maybe_make_dummy_output = [context, use_reserved_space]() -> Status {
821       if (use_reserved_space) {
822         Tensor* dummy_reserve_space = nullptr;
823         return context->allocate_output(5, {}, &dummy_reserve_space);
824       }
825       return OkStatus();
826     };
827 
828     // If input is empty, return NaN mean/variance
829     if (x.shape().num_elements() == 0) {
830       OP_REQUIRES_OK(context, maybe_make_dummy_output());
831       functor::SetNanFunctor<GPUDevice, U> f;
832       f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
833       f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
834       return;
835     }
836 
837     // In inference mode we use a custom CUDA kernel, because cuDNN does not
838     // support side input and activations for inference.
839     const bool has_side_input = side_input != nullptr;
840     const bool has_activation =
841         activation_mode != FusedBatchNormActivationMode::kIdentity;
842 
843     if (!is_training && (has_side_input || has_activation)) {
844       OP_REQUIRES_OK(context, maybe_make_dummy_output());
845       FusedBatchNormInferenceFunctor<GPUDevice, T, U> inference_functor;
846 
847       if (has_side_input) {
848         inference_functor(context, tensor_format, x.tensor<T, 4>(),
849                           scale.vec<U>(), offset.vec<U>(),
850                           estimated_mean.vec<U>(), estimated_variance.vec<U>(),
851                           side_input->tensor<T, 4>(), epsilon, activation_mode,
852                           y->tensor<T, 4>());
853       } else {
854         typename TTypes<T, 4>::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0);
855         inference_functor(context, tensor_format, x.tensor<T, 4>(),
856                           scale.vec<U>(), offset.vec<U>(),
857                           estimated_mean.vec<U>(), estimated_variance.vec<U>(),
858                           empty_tensor, epsilon, activation_mode,
859                           y->tensor<T, 4>());
860       }
861       return;
862     }
863 
864     Tensor x_maybe_transformed = x;
865     Tensor x_transformed;
866     Tensor y_transformed;
867     se::DeviceMemory<T> y_ptr;
868 
869     if (tensor_format == compute_format) {
870       y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
871     } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
872       OP_REQUIRES_OK(context, context->allocate_temp(
873                                   DataTypeToEnum<T>::value,
874                                   ShapeFromFormat(compute_format, batch_size,
875                                                   height, width, channels),
876                                   &x_transformed));
877       functor::NHWCToNCHW<GPUDevice, T, 4>()(
878           context->eigen_device<GPUDevice>(),
879           const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
880           x_transformed.tensor<T, 4>());
881       x_maybe_transformed = x_transformed;
882 
883       OP_REQUIRES_OK(context, context->allocate_temp(
884                                   DataTypeToEnum<T>::value,
885                                   ShapeFromFormat(compute_format, batch_size,
886                                                   height, width, channels),
887                                   &y_transformed));
888       y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
889     } else {
890       context->SetStatus(errors::Internal(
891           "Unsupported tensor format: ", ToString(tensor_format),
892           " and compute format: ", ToString(compute_format)));
893       return;
894     }
895 
896     const se::dnn::DataLayout data_layout =
897         compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
898                                       : se::dnn::DataLayout::kBatchDepthYX;
899 
900     se::dnn::BatchDescriptor x_desc;
901     x_desc.set_count(batch_size)
902         .set_feature_map_count(channels)
903         .set_height(height)
904         .set_width(width)
905         .set_layout(data_layout);
906 
907     se::dnn::BatchDescriptor scale_offset_desc;
908     scale_offset_desc.set_count(1)
909         .set_feature_map_count(channels)
910         .set_height(1)
911         .set_width(1)
912         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
913 
914     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
915     auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
916     auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
917     auto estimated_mean_ptr =
918         StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
919     auto estimated_variance_ptr =
920         StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
921     auto side_input_ptr =
922         side_input != nullptr
923             ? StreamExecutorUtil::AsDeviceMemory<T>(*side_input)
924             : se::DeviceMemory<T>();
925     auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);
926 
927     auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
928     auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
929     auto saved_inv_var_ptr =
930         StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
931 
932     std::unique_ptr<functor::CudnnBatchNormAllocatorInOutput<U>>
933         reserve_space_allocator;
934     std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
935         workspace_allocator;
936     if (use_reserved_space) {
937       reserve_space_allocator.reset(
938           new functor::CudnnBatchNormAllocatorInOutput<U>(context, 5));
939       workspace_allocator.reset(
940           new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
941     }
942     if (!batch_mean->SharesBufferWith(estimated_mean) &&
943         exponential_avg_factor != 1.0f) {
944       OP_REQUIRES(
945           context,
946           stream
947               ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr,
948                               estimated_mean.NumElements() * sizeof(U))
949               .ok(),
950           errors::Internal("FusedBatchNorm: failed to copy "
951                            "estimated_mean into batch_mean on device"));
952     }
953     if (!batch_var->SharesBufferWith(estimated_variance) &&
954         exponential_avg_factor != 1.0f) {
955       OP_REQUIRES(
956           context,
957           stream
958               ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr,
959                               estimated_variance.NumElements() * sizeof(U))
960               .ok(),
961           errors::Internal("FusedBatchNorm: failed to copy "
962                            "estimated_variance into batch_var on device"));
963     }
964     bool cudnn_launch_status =
965         stream
966             ->ThenBatchNormalizationForward(
967                 x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
968                 estimated_variance_ptr, side_input_ptr, x_desc,
969                 scale_offset_desc, static_cast<double>(epsilon),
970                 static_cast<double>(exponential_avg_factor),
971                 AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
972                 &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
973                 is_training, reserve_space_allocator.get(),
974                 workspace_allocator.get())
975             .ok();
976 
977     if (!cudnn_launch_status) {
978       context->SetStatus(
979           errors::Internal("cuDNN launch failure : input shape (",
980                            x.shape().DebugString(), ")"));
981       return;
982     }
983 
984     if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
985       functor::NCHWToNHWC<GPUDevice, T, 4>()(
986           context->eigen_device<GPUDevice>(),
987           const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
988           y->tensor<T, 4>());
989     }
990   }
991 };
992 
993 template <typename T, typename U>
994 struct FusedBatchNormGrad<GPUDevice, T, U> {
995   void operator()(OpKernelContext* context, const Tensor& y_backprop,
996                   const Tensor& x, const Tensor& scale, const Tensor* offset,
997                   const Tensor& mean, const Tensor& inv_variance,
998                   const Tensor* y, U epsilon,
999                   FusedBatchNormActivationMode activation_mode,
1000                   Tensor* x_backprop, Tensor* scale_backprop,
1001                   Tensor* offset_backprop, Tensor* side_input_backprop,
1002                   bool use_reserved_space, TensorFormat tensor_format) {
1003     auto* stream = context->op_device_context()->stream();
1004     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
1005 
1006     const int64_t batch_size = GetTensorDim(x, tensor_format, 'N');
1007     const int64_t channels = GetTensorDim(x, tensor_format, 'C');
1008     const int64_t height = GetTensorDim(x, tensor_format, 'H');
1009     const int64_t width = GetTensorDim(x, tensor_format, 'W');
1010 
1011 #if GOOGLE_CUDA
1012     // Check if cuDNN batch normalization has a fast NHWC implementation:
1013     //   (1) Tensorflow enabled batchnorm spatial persistence, and
1014     //       FusedBatchNormGradV3 passed non-null reserve space and allocator.
1015     const bool fast_nhwc_batch_norm = BatchnormSpatialPersistentEnabled() &&
1016                                       DataTypeToEnum<T>::value == DT_HALF &&
1017                                       use_reserved_space;
1018 #else
1019     // fast NHWC implementation is a CUDA only feature
1020     const bool fast_nhwc_batch_norm = false;
1021 #endif
1022 
1023     // If the input tensor is in NHWC format and we have a fast cuDNN
1024     // implementation, there is no need to do a data format conversion.
1025     TensorFormat compute_format =
1026         fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
1027                                                              : FORMAT_NCHW;
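    // As in the forward pass, an NHWC input without the fast cuDNN path is
    // transposed to NCHW here, and x_backprop is transposed back to NHWC at
    // the end.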
1028 
1029     VLOG(2) << "FusedBatchNormGrad:"
1030             << " batch_size: " << batch_size << " channels: " << channels
1031             << " height: " << height << " width: " << width
1032             << " y_backprop shape: " << y_backprop.shape().DebugString()
1033             << " x shape: " << x.shape().DebugString()
1034             << " scale shape: " << scale.shape().DebugString()
1035             << " activation mode: " << ToString(activation_mode)
1036             << " tensor format: " << ToString(tensor_format)
1037             << " compute format: " << ToString(compute_format);
1038 
1039     // Inputs
1040     Tensor y_backprop_maybe_transformed = y_backprop;
1041     Tensor x_maybe_transformed = x;
1042     Tensor y_backprop_transformed;
1043     Tensor x_transformed;
1044 
1045     // Outputs
1046     Tensor x_backprop_transformed;
1047     se::DeviceMemory<T> x_backprop_ptr;
1048 
1049     if (tensor_format == compute_format) {
1050       x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
1051     } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1052       // Transform inputs from 'NHWC' to 'NCHW'
1053       OP_REQUIRES_OK(context, context->allocate_temp(
1054                                   DataTypeToEnum<T>::value,
1055                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
1056                                                   height, width, channels),
1057                                   &y_backprop_transformed));
1058       functor::NHWCToNCHW<GPUDevice, T, 4>()(
1059           context->eigen_device<GPUDevice>(),
1060           const_cast<const Tensor&>(y_backprop_maybe_transformed)
1061               .tensor<T, 4>(),
1062           y_backprop_transformed.tensor<T, 4>());
1063       y_backprop_maybe_transformed = y_backprop_transformed;
1064 
1065       OP_REQUIRES_OK(context, context->allocate_temp(
1066                                   DataTypeToEnum<T>::value,
1067                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
1068                                                   height, width, channels),
1069                                   &x_transformed));
1070       functor::NHWCToNCHW<GPUDevice, T, 4>()(
1071           context->eigen_device<GPUDevice>(),
1072           const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
1073           x_transformed.tensor<T, 4>());
1074       x_maybe_transformed = x_transformed;
1075 
1076       // Allocate memory for transformed outputs in 'NCHW'
1077       OP_REQUIRES_OK(context, context->allocate_temp(
1078                                   DataTypeToEnum<T>::value,
1079                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
1080                                                   height, width, channels),
1081                                   &x_backprop_transformed));
1082       x_backprop_ptr =
1083           StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
1084     } else {
1085       context->SetStatus(errors::Internal(
1086           "Unsupported tensor format: ", ToString(tensor_format),
1087           " and compute format: ", ToString(compute_format)));
1088       return;
1089     }
1090 
1091     const se::dnn::DataLayout data_layout =
1092         compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
1093                                       : se::dnn::DataLayout::kBatchDepthYX;
1094 
1095     se::dnn::BatchDescriptor x_desc;
1096     x_desc.set_count(batch_size)
1097         .set_feature_map_count(channels)
1098         .set_height(height)
1099         .set_width(width)
1100         .set_layout(data_layout);
1101 
1102     se::dnn::BatchDescriptor scale_offset_desc;
1103     scale_offset_desc.set_count(1)
1104         .set_feature_map_count(channels)
1105         .set_height(1)
1106         .set_width(1)
1107         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
1108 
1109     auto y_backprop_ptr =
1110         StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
1111     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
1112     auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
1113     auto offset_ptr = offset != nullptr
1114                           ? StreamExecutorUtil::AsDeviceMemory<U>(*offset)
1115                           : se::DeviceMemory<U>();
1116     auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
1117     auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
1118     auto y_ptr = y != nullptr ? StreamExecutorUtil::AsDeviceMemory<T>(*y)
1119                               : se::DeviceMemory<T>();
1120     auto scale_backprop_ptr =
1121         StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
1122     auto offset_backprop_ptr =
1123         StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);
1124     auto side_input_backprop_ptr =
1125         side_input_backprop != nullptr
1126             ? StreamExecutorUtil::AsDeviceMemory<T>(*side_input_backprop)
1127             : se::DeviceMemory<T>();
1128 
1129     std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
1130         workspace_allocator;
1131     DeviceMemory<uint8>* reserve_space_data_ptr = nullptr;
1132     DeviceMemory<uint8> reserve_space_data;
1133 #if CUDNN_VERSION >= 7402
1134     if (use_reserved_space) {
1135       const Tensor& reserve_space = context->input(5);
1136       workspace_allocator.reset(
1137           new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
1138 
1139       // The cuDNN kernel outputs the inverse variance in the forward pass
1140       // and reuses it in the backward pass.
1141       if (reserve_space.dims() != 0) {
1142         reserve_space_data = functor::CastDeviceMemory<uint8, U>(
1143             const_cast<Tensor*>(&reserve_space));
1144         reserve_space_data_ptr = &reserve_space_data;
1145       }
1146     }
1147 #endif  // CUDNN_VERSION >= 7402
1148 
1149     bool cudnn_launch_status =
1150         stream
1151             ->ThenBatchNormalizationBackward(
1152                 y_backprop_ptr, x_ptr, scale_ptr, offset_ptr, mean_ptr,
1153                 inv_variance_ptr, y_ptr, x_desc, scale_offset_desc,
1154                 static_cast<double>(epsilon),
1155                 AsDnnActivationMode(activation_mode), &x_backprop_ptr,
1156                 &scale_backprop_ptr, &offset_backprop_ptr,
1157                 &side_input_backprop_ptr, reserve_space_data_ptr,
1158                 workspace_allocator.get())
1159             .ok();
1160 
1161     if (!cudnn_launch_status) {
1162       context->SetStatus(
1163           errors::Internal("cuDNN launch failure : input shape (",
1164                            x.shape().DebugString(), ")"));
1165     }
1166     if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1167       functor::NCHWToNHWC<GPUDevice, T, 4>()(
1168           context->eigen_device<GPUDevice>(),
1169           const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
1170           x_backprop->tensor<T, 4>());
1171     }
1172   }
1173 };
1174 
1175 // Forward declarations of the functor specializations for GPU.
1176 #define DECLARE_GPU_SPEC(T, U)                                                 \
1177   template <>                                                                  \
1178   void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(                  \
1179       OpKernelContext* context, const Tensor& y_backprop_input,                \
1180       const Tensor& x_input, const Tensor& scale_input,                        \
1181       const Tensor& mean_input, const Tensor& variance_input, U epsilon,       \
1182       Tensor* x_backprop_output, Tensor* scale_backprop_output,                \
1183       Tensor* offset_backprop_output);                                         \
1184   extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;            \
1185   template <>                                                                  \
1186   void FusedBatchNormInferenceFunctor<GPUDevice, T, U>::operator()(            \
1187       OpKernelContext* context, TensorFormat tensor_format,                    \
1188       typename TTypes<T, 4>::ConstTensor in,                                   \
1189       typename TTypes<U>::ConstVec scale, typename TTypes<U>::ConstVec offset, \
1190       typename TTypes<U>::ConstVec estimated_mean,                             \
1191       typename TTypes<U>::ConstVec estimated_variance,                         \
1192       typename TTypes<T, 4>::ConstTensor side_input, U epsilon,                \
1193       FusedBatchNormActivationMode activation_mode,                            \
1194       typename TTypes<T, 4>::Tensor out);                                      \
1195   extern template struct FusedBatchNormInferenceFunctor<GPUDevice, T, U>;
1196 
1197 DECLARE_GPU_SPEC(float, float);
1198 DECLARE_GPU_SPEC(Eigen::half, float);
1199 
1200 #undef DECLARE_GPU_SPEC
1201 
1202 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1203 }  // namespace functor
1204 
1205 template <typename Device, typename T, typename U>
1206 class FusedBatchNormOpBase : public OpKernel {
1207   using FbnActivationMode = functor::FusedBatchNormActivationMode;
1208 
1209  protected:
1210   explicit FusedBatchNormOpBase(OpKernelConstruction* context,
1211                                 bool is_batch_norm_ex = false)
1212       : OpKernel(context) {
1213     float epsilon;
1214     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1215     epsilon_ = U(epsilon);
1216     float exponential_avg_factor;
1217     OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
1218                                              &exponential_avg_factor));
1219     exponential_avg_factor_ = U(exponential_avg_factor);
1220     string tensor_format;
1221     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1222     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1223                 errors::InvalidArgument("Invalid data format"));
1224     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1225 
1226     if (!is_batch_norm_ex) {
1227       has_side_input_ = false;
1228       activation_mode_ = FbnActivationMode::kIdentity;
1229     } else {
1230       OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
1231 
1232       int num_side_inputs;
1233       OP_REQUIRES_OK(context,
1234                      context->GetAttr("num_side_inputs", &num_side_inputs));
1235       OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1,
1236                   errors::InvalidArgument(
1237                       "FusedBatchNorm accepts at most one side input."));
1238       has_side_input_ = (num_side_inputs == 1);
1239       if (has_side_input_ && is_training_) {
1240         OP_REQUIRES(
1241             context, activation_mode_ != FbnActivationMode::kIdentity,
1242             errors::InvalidArgument("Identity activation is not supported with "
1243                                     "non-empty side input"));
1244       }
1245     }
1246 
1247     if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) {
1248       // NOTE(ezhulenev): The following requirements come from implementation
1249       // details of cudnnBatchNormalizationForwardTrainingEx, which is used in
1250       // training mode. In inference mode we call a custom CUDA kernel that
1251       // supports all data formats and data types.
1252       OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF,
1253                   errors::InvalidArgument("FusedBatchNorm with activation "
1254                                           "supports only DT_HALF data type."));
1255       OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
1256                   errors::InvalidArgument("FusedBatchNorm with activation "
1257                                           "supports only NHWC tensor format."));
1258       OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(),
1259                   errors::InvalidArgument(
1260                       "FusedBatchNorm with activation must run with cuDNN "
1261                       "spatial persistence mode enabled."));
1262     }
1263   }
1264 
1265   // If use_reserved_space is true, we need to handle the 5th output (a reserved
1266   // space), and the newer cuDNN batch norm API is called when the cuDNN version
1267   // is at least 7.4.2. If use_reserved_space is false, there is no 5th output.
1268   virtual void ComputeWithReservedSpace(OpKernelContext* context,
1269                                         bool use_reserved_space) {
1270     Tensor x = context->input(0);
1271     const Tensor& scale = context->input(1);
1272     const Tensor& offset = context->input(2);
1273     const Tensor& estimated_mean = context->input(3);
1274     const Tensor& estimated_variance = context->input(4);
1275     const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
1276 
1277     OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
1278                 errors::InvalidArgument("input must be 4 or 5-dimensional",
1279                                         x.shape().DebugString()));
1280     OP_REQUIRES(context, scale.dims() == 1,
1281                 errors::InvalidArgument("scale must be 1-dimensional",
1282                                         scale.shape().DebugString()));
1283     OP_REQUIRES(context, offset.dims() == 1,
1284                 errors::InvalidArgument("offset must be 1-dimensional",
1285                                         offset.shape().DebugString()));
1286     OP_REQUIRES(context, estimated_mean.dims() == 1,
1287                 errors::InvalidArgument("estimated_mean must be 1-dimensional",
1288                                         estimated_mean.shape().DebugString()));
1289     OP_REQUIRES(
1290         context, estimated_variance.dims() == 1,
1291         errors::InvalidArgument("estimated_variance must be 1-dimensional",
1292                                 estimated_variance.shape().DebugString()));
1293     bool use_reshape = (x.dims() == 5);
1294     auto x_shape = x.shape();
1295     TensorShape dest_shape;
1296     if (use_reshape) {
1297       const int64_t in_batch = GetTensorDim(x, tensor_format_, 'N');
1298       int64_t in_planes = GetTensorDim(x, tensor_format_, '0');
1299       int64_t in_rows = GetTensorDim(x, tensor_format_, '1');
1300       int64_t in_cols = GetTensorDim(x, tensor_format_, '2');
1301       const int64_t in_depth = GetTensorDim(x, tensor_format_, 'C');
1302       dest_shape = ShapeFromFormat(tensor_format_, in_batch,
1303                                    {{in_planes, in_rows * in_cols}}, in_depth);
1304       OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
1305                   errors::InvalidArgument("Error during tensor copy."));
1306     }
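    // Example: an NDHWC input of shape [N, D, H, W, C] is viewed above as a 4D
    // tensor of shape [N, D, H * W, C], so the same 4D kernels can handle 5D
    // inputs (the per-channel statistics are unaffected by this reshape).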
1307 
1308     const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
1309     OP_REQUIRES(
1310         context, scale.NumElements() == num_channels,
1311         errors::InvalidArgument("scale must have the same number of elements "
1312                                 "as the channels of x, got ",
1313                                 scale.NumElements(), " and ", num_channels));
1314     OP_REQUIRES(
1315         context, offset.NumElements() == num_channels,
1316         errors::InvalidArgument("offset must have the same number of elements "
1317                                 "as the channels of x, got ",
1318                                 offset.NumElements(), " and ", num_channels));
1319     if (!is_training_ || exponential_avg_factor_ != 1.) {
1320       std::string prefix_msg = is_training_ ? "When exponential_avg_factor != 1"
1321                                             : "When is_training=false";
1322       OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
1323                   errors::InvalidArgument(
1324                       prefix_msg,
1325                       ", mean must have the same number "
1326                       "of elements as the channels of x, got ",
1327                       estimated_mean.NumElements(), " and ", num_channels));
1328       OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
1329                   errors::InvalidArgument(
1330                       prefix_msg,
1331                       ", variance must have the same "
1332                       "number of elements as the channels of x, got ",
1333                       estimated_variance.NumElements(), " and ", num_channels));
1334     }
1335 
1336     if (has_side_input_) {
1337       OP_REQUIRES(context, side_input->shape() == x.shape(),
1338                   errors::InvalidArgument(
1339                       "side_input shape must be equal to input shape: ",
1340                       side_input->shape().DebugString(),
1341                       " != ", x.shape().DebugString()));
1342     }
1343 
1344     if (activation_mode_ != FbnActivationMode::kIdentity) {
1345       // NOTE(ezhulenev): This requirement comes from implementation details
1346       // of cudnnBatchNormalizationForwardTrainingEx.
1347       OP_REQUIRES(
1348           context, !is_training_ || num_channels % 4 == 0,
1349           errors::InvalidArgument("FusedBatchNorm with activation requires "
1350                                   "channel dimension to be a multiple of 4."));
1351     }
1352 
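    // Outputs: 0 = y, 1 = batch_mean, 2 = batch_var, 3 = saved_mean,
    // 4 = saved (possibly inverted) variance. Where possible, x,
    // estimated_mean and estimated_variance are forwarded into outputs 0-2 to
    // avoid extra allocations.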
1353     Tensor* y = nullptr;
1354     auto alloc_shape = use_reshape ? dest_shape : x_shape;
1355     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1356                                 {0}, 0, alloc_shape, &y));
1357 
1358     Tensor* batch_mean = nullptr;
1359     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1360                                 {3}, 1, scale.shape(), &batch_mean));
1361     Tensor* batch_var = nullptr;
1362     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1363                                 {4}, 2, scale.shape(), &batch_var));
1364     Tensor* saved_mean = nullptr;
1365     OP_REQUIRES_OK(context,
1366                    context->allocate_output(3, scale.shape(), &saved_mean));
1367     Tensor* saved_maybe_inv_var = nullptr;
1368     OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
1369                                                      &saved_maybe_inv_var));
1370 
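    // Both functors apply the standard batch-norm transformation per channel:
    //   y = scale * (x - mean) * rsqrt(variance + epsilon) + offset,
    // optionally adding side_input and applying the requested activation.
    // In training mode the batch statistics are computed here (and the running
    // statistics are blended using exponential_avg_factor); in inference mode
    // the estimated_mean/estimated_variance inputs are used directly.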
1371     if (is_training_) {
1372       functor::FusedBatchNorm<Device, T, U, true>()(
1373           context, x, scale, offset, estimated_mean, estimated_variance,
1374           side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1375           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1376           tensor_format_, use_reserved_space);
1377     } else {
1378       functor::FusedBatchNorm<Device, T, U, false>()(
1379           context, x, scale, offset, estimated_mean, estimated_variance,
1380           side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1381           batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1382           tensor_format_, use_reserved_space);
1383     }
1384     if (use_reshape) {
1385       OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
1386                   errors::InvalidArgument("Error during tensor copy."));
1387     }
1388   }
1389 
1390  private:
1391   U epsilon_;
1392   U exponential_avg_factor_;
1393   TensorFormat tensor_format_;
1394   bool is_training_;
1395   bool has_side_input_;
1396   FbnActivationMode activation_mode_;
1397 };
1398 
1399 template <typename Device, typename T, typename U>
1400 class FusedBatchNormOp : public FusedBatchNormOpBase<Device, T, U> {
1401  public:
1402   explicit FusedBatchNormOp(OpKernelConstruction* context)
1403       : FusedBatchNormOpBase<Device, T, U>(context) {}
1404 
1405   void Compute(OpKernelContext* context) override {
1406     FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1407                                                                  false);
1408   }
1409 };
1410 
1411 template <typename Device, typename T, typename U>
1412 class FusedBatchNormOpV3 : public FusedBatchNormOpBase<Device, T, U> {
1413  public:
1414   explicit FusedBatchNormOpV3(OpKernelConstruction* context)
1415       : FusedBatchNormOpBase<Device, T, U>(context) {}
1416 
1417   void Compute(OpKernelContext* context) override {
1418     FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
1419   }
1420 };
1421 
1422 template <typename Device, typename T, typename U>
1423 class FusedBatchNormOpEx : public FusedBatchNormOpBase<Device, T, U> {
1424   static constexpr bool kWithSideInputAndActivation = true;
1425 
1426  public:
1427   explicit FusedBatchNormOpEx(OpKernelConstruction* context)
1428       : FusedBatchNormOpBase<Device, T, U>(context,
1429                                            kWithSideInputAndActivation) {}
1430 
1431   void Compute(OpKernelContext* context) override {
1432     FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
1433   }
1434 };
1435 
1436 template <typename Device, typename T, typename U>
1437 class FusedBatchNormGradOpBase : public OpKernel {
1438   using FbnActivationMode = functor::FusedBatchNormActivationMode;
1439 
1440  protected:
1441   explicit FusedBatchNormGradOpBase(OpKernelConstruction* context,
1442                                     bool is_batch_norm_grad_ex = false)
1443       : OpKernel(context) {
1444     float epsilon;
1445     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1446     epsilon_ = U(epsilon);
1447     string tensor_format;
1448     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1449     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1450                 errors::InvalidArgument("Invalid data format"));
1451     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1452     if (!is_batch_norm_grad_ex) {
1453       has_side_input_ = false;
1454       activation_mode_ = FbnActivationMode::kIdentity;
1455     } else {
1456       OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
1457 
1458       int num_side_inputs;
1459       OP_REQUIRES_OK(context,
1460                      context->GetAttr("num_side_inputs", &num_side_inputs));
1461       OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1,
1462                   errors::InvalidArgument(
1463                       "FusedBatchNormGrad accepts at most one side input."));
1464       has_side_input_ = (num_side_inputs == 1);
1465       if (has_side_input_ && is_training_) {
1466         OP_REQUIRES(
1467             context, activation_mode_ != FbnActivationMode::kIdentity,
1468             errors::InvalidArgument("Identity activation is not supported with "
1469                                     "non-empty side input"));
1470       }
1471     }
1472 
1473     if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) {
1474       // NOTE(kaixih@nvidia): The following requirements come from
1475       // implementation details of cudnnBatchNormalizationBackwardEx, which is
1476       // used in training mode.
1477       OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF,
1478                   errors::InvalidArgument("FusedBatchNormGrad with activation "
1479                                           "supports only DT_HALF data type."));
1480       OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
1481                   errors::InvalidArgument("FusedBatchNormGrad with activation "
1482                                           "supports only NHWC tensor format."));
1483       OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(),
1484                   errors::InvalidArgument(
1485                       "FusedBatchNormGrad with activation must run with cuDNN "
1486                       "spatial persistence mode enabled."));
1487     }
1488   }
1489 
1490   virtual void ComputeWithReservedSpace(OpKernelContext* context,
1491                                         bool use_reserved_space) {
1492     Tensor y_backprop = context->input(0);
1493     Tensor x = context->input(1);
1494     const Tensor& scale = context->input(2);
1495     // When is_training=True, batch mean and variance/inverted variance are
1496     // saved in the forward pass to be reused here. When is_training=False,
1497     // population mean and variance need to be forwarded here to compute the
1498     // gradients.
1499     const Tensor& saved_mean_or_pop_mean = context->input(3);
1500     // The Eigen implementation saves variance in the forward pass, while cuDNN
1501     // saves inverted variance.
1502     const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
1503     bool use_activation = activation_mode_ != FbnActivationMode::kIdentity;
1504     const Tensor* offset = use_activation ? &context->input(6) : nullptr;
1505     const Tensor* y = use_activation ? &context->input(7) : nullptr;
1506 
1507     OP_REQUIRES(context, y_backprop.dims() == 4 || y_backprop.dims() == 5,
1508                 errors::InvalidArgument("input must be 4 or 5-dimensional",
1509                                         y_backprop.shape().DebugString()));
1510     OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
1511                 errors::InvalidArgument("input must be 4 or 5-dimensional",
1512                                         x.shape().DebugString()));
1513     OP_REQUIRES(context, scale.dims() == 1,
1514                 errors::InvalidArgument("scale must be 1-dimensional",
1515                                         scale.shape().DebugString()));
1516     OP_REQUIRES(
1517         context, saved_mean_or_pop_mean.dims() == 1,
1518         errors::InvalidArgument("saved mean must be 1-dimensional",
1519                                 saved_mean_or_pop_mean.shape().DebugString()));
1520     OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
1521                 errors::InvalidArgument(
1522                     "saved variance must be 1-dimensional",
1523                     saved_maybe_inv_var_or_pop_var.shape().DebugString()));
1524     OP_REQUIRES(
1525         context, x.shape() == y_backprop.shape(),
1526         errors::InvalidArgument(
1527             "x and y_backprop must have same shape, but x has shape ",
1528             x.shape(), " and y_backprop has shape ", y_backprop.shape()));
1529     if (use_activation) {
1530       OP_REQUIRES(
1531           context, x.dim_size(3) % 4 == 0,
1532           errors::InvalidArgument("FusedBatchNormGrad with activation requires "
1533                                   "channel dimension to be a multiple of 4."));
1534       OP_REQUIRES(context, offset->dims() == 1,
1535                   errors::InvalidArgument("offset must be 1-dimensional",
1536                                           offset->shape().DebugString()));
1537     }
1538     bool use_reshape = (x.dims() == 5);
1539     auto x_shape = x.shape();
1540     TensorShape dest_shape;
1541     if (use_reshape) {
1542       const int64_t in_batch = GetTensorDim(x, tensor_format_, 'N');
1543       int64_t in_planes = GetTensorDim(x, tensor_format_, '0');
1544       int64_t in_rows = GetTensorDim(x, tensor_format_, '1');
1545       int64_t in_cols = GetTensorDim(x, tensor_format_, '2');
1546       const int64_t in_depth = GetTensorDim(x, tensor_format_, 'C');
1547       dest_shape = ShapeFromFormat(tensor_format_, in_batch,
1548                                    {{in_planes, in_rows * in_cols}}, in_depth);
1549       OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
1550                   errors::InvalidArgument("Error during tensor copy."));
1551       OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
1552                   errors::InvalidArgument("Error during tensor copy."));
1553     }
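    // As in the forward op, 5D inputs are collapsed to an equivalent 4D view
    // (e.g. [N, D, H, W, C] -> [N, D, H * W, C]) before invoking the 4D
    // gradient kernels.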
1554 
1555     const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
1556     OP_REQUIRES(
1557         context, scale.NumElements() == num_channels,
1558         errors::InvalidArgument("scale must have the same number of elements "
1559                                 "as the channels of x, got ",
1560                                 scale.NumElements(), " and ", num_channels));
1561     OP_REQUIRES(
1562         context, saved_mean_or_pop_mean.NumElements() == num_channels,
1563         errors::InvalidArgument("reserve_space_1 must have the same number of "
1564                                 "elements as the channels of x, got ",
1565                                 saved_mean_or_pop_mean.NumElements(), " and ",
1566                                 num_channels));
1567     OP_REQUIRES(
1568         context, saved_maybe_inv_var_or_pop_var.NumElements() == num_channels,
1569         errors::InvalidArgument("reserve_space_2 must have the same number of "
1570                                 "elements as the channels of x, got ",
1571                                 saved_maybe_inv_var_or_pop_var.NumElements(),
1572                                 " and ", num_channels));
1573 
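    // Outputs: 0 = x_backprop, 1 = scale_backprop, 2 = offset_backprop,
    // 3 and 4 = empty placeholders for the unused mean/variance outputs, and
    // 5 = side_input_backprop (only when a side input is present).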
1574     Tensor* x_backprop = nullptr;
1575     auto alloc_shape = use_reshape ? dest_shape : x_shape;
1576     OP_REQUIRES_OK(context,
1577                    context->allocate_output(0, alloc_shape, &x_backprop));
1578 
1579     const TensorShape& scale_offset_shape = scale.shape();
1580     Tensor* scale_backprop = nullptr;
1581     OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
1582                                                      &scale_backprop));
1583     Tensor* offset_backprop = nullptr;
1584     OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
1585                                                      &offset_backprop));
1586     // Two placeholders for estimated_mean and estimated_variance, which are
1587     // used for inference and thus not needed here for gradient computation.
1588     // They are filled with zeros so as to avoid NaN outputs.
1589     Tensor* placeholder_1 = nullptr;
1590     OP_REQUIRES_OK(
1591         context, context->allocate_output(3, TensorShape({0}), &placeholder_1));
1592     Tensor* placeholder_2 = nullptr;
1593     OP_REQUIRES_OK(
1594         context, context->allocate_output(4, TensorShape({0}), &placeholder_2));
1595 
1596     Tensor* side_input_backprop = nullptr;
1597     if (has_side_input_) {
1598       OP_REQUIRES_OK(context, context->allocate_output(5, alloc_shape,
1599                                                        &side_input_backprop));
1600     }
1601 
1602     // If the input is empty, set the gradients w.r.t. scale/offset to zero.
1603     if (x.shape().num_elements() == 0) {
1604       functor::SetZeroFunctor<Device, U> f;
1605       f(context->eigen_device<Device>(), scale_backprop->flat<U>());
1606       f(context->eigen_device<Device>(), offset_backprop->flat<U>());
1607       return;
1608     }
1609 
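    // In training mode the functor computes the standard batch-norm gradients:
    // per channel, offset_backprop = sum(y_backprop) and scale_backprop =
    // sum(y_backprop * (x - mean) * rsqrt(variance + epsilon)), while
    // x_backprop also accounts for the dependence of the batch statistics on x.
    // With frozen statistics (is_training=false) the statistics are treated as
    // constants, so x_backprop reduces to
    // y_backprop * scale * rsqrt(pop_variance + epsilon).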
1610     if (is_training_) {
1611       functor::FusedBatchNormGrad<Device, T, U>()(
1612           context, y_backprop, x, scale, offset, saved_mean_or_pop_mean,
1613           saved_maybe_inv_var_or_pop_var, y, epsilon_, activation_mode_,
1614           x_backprop, scale_backprop, offset_backprop, side_input_backprop,
1615           use_reserved_space, tensor_format_);
1616     } else {
1617       OP_REQUIRES(
1618           context,
1619           activation_mode_ == FbnActivationMode::kIdentity && !has_side_input_,
1620           errors::InvalidArgument(
1621               "FusedBatchNormGrad with activation is only supported "
1622               "when is_training=True."));
1623       // The necessary layout conversion is currently done in Python.
1624       OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
1625                   errors::InvalidArgument(
1626                       "The implementation of "
1627                       "FusedBatchNormGrad with is_training=False only support "
1628                       "NHWC tensor format for now."));
1629       functor::FusedBatchNormFreezeGrad<Device, T, U>()(
1630           context, y_backprop, x, scale, saved_mean_or_pop_mean,
1631           saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
1632           offset_backprop);
1633     }
1634     if (use_reshape) {
1635       OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
1636                   errors::InvalidArgument("Error during tensor copy."));
1637     }
1638   }
1639 
1640  private:
1641   U epsilon_;
1642   TensorFormat tensor_format_;
1643   bool is_training_;
1644   bool has_side_input_;
1645   FbnActivationMode activation_mode_;
1646 };
1647 
1648 template <typename Device, typename T, typename U>
1649 class FusedBatchNormGradOp : public FusedBatchNormGradOpBase<Device, T, U> {
1650  public:
1651   explicit FusedBatchNormGradOp(OpKernelConstruction* context)
1652       : FusedBatchNormGradOpBase<Device, T, U>(context) {}
1653 
1654   void Compute(OpKernelContext* context) override {
1655     FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1656                                                                      false);
1657   }
1658 };
1659 
1660 template <typename Device, typename T, typename U>
1661 class FusedBatchNormGradOpV3 : public FusedBatchNormGradOpBase<Device, T, U> {
1662  public:
1663   explicit FusedBatchNormGradOpV3(OpKernelConstruction* context)
1664       : FusedBatchNormGradOpBase<Device, T, U>(context) {}
1665 
1666   void Compute(OpKernelContext* context) override {
1667     FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1668                                                                      true);
1669   }
1670 };
1671 
1672 template <typename Device, typename T, typename U>
1673 class FusedBatchNormGradOpEx : public FusedBatchNormGradOpBase<Device, T, U> {
1674   static constexpr bool kWithSideInputAndActivation = true;
1675 
1676  public:
1677   explicit FusedBatchNormGradOpEx(OpKernelConstruction* context)
1678       : FusedBatchNormGradOpBase<Device, T, U>(context,
1679                                                kWithSideInputAndActivation) {}
1680 
1681   void Compute(OpKernelContext* context) override {
1682     FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
1683                                                                      true);
1684   }
1685 };
1686 
1687 REGISTER_KERNEL_BUILDER(
1688     Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1689     FusedBatchNormOp<CPUDevice, float, float>);
1690 
1691 REGISTER_KERNEL_BUILDER(
1692     Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
1693     FusedBatchNormGradOp<CPUDevice, float, float>);
1694 
1695 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1696                             .Device(DEVICE_CPU)
1697                             .TypeConstraint<float>("T")
1698                             .TypeConstraint<float>("U"),
1699                         FusedBatchNormOp<CPUDevice, float, float>);
1700 
1701 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1702                             .Device(DEVICE_CPU)
1703                             .TypeConstraint<float>("T")
1704                             .TypeConstraint<float>("U"),
1705                         FusedBatchNormGradOp<CPUDevice, float, float>);
1706 
1707 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1708                             .Device(DEVICE_CPU)
1709                             .TypeConstraint<Eigen::half>("T")
1710                             .TypeConstraint<float>("U"),
1711                         FusedBatchNormOp<CPUDevice, Eigen::half, float>);
1712 
1713 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1714                             .Device(DEVICE_CPU)
1715                             .TypeConstraint<Eigen::half>("T")
1716                             .TypeConstraint<float>("U"),
1717                         FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);
1718 
1719 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1720                             .Device(DEVICE_CPU)
1721                             .TypeConstraint<float>("T")
1722                             .TypeConstraint<float>("U"),
1723                         FusedBatchNormOpV3<CPUDevice, float, float>);
1724 
1725 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1726                             .Device(DEVICE_CPU)
1727                             .TypeConstraint<float>("T")
1728                             .TypeConstraint<float>("U"),
1729                         FusedBatchNormGradOpV3<CPUDevice, float, float>);
1730 
1731 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1732                             .Device(DEVICE_CPU)
1733                             .TypeConstraint<Eigen::half>("T")
1734                             .TypeConstraint<float>("U"),
1735                         FusedBatchNormOpV3<CPUDevice, Eigen::half, float>);
1736 
1737 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1738                             .Device(DEVICE_CPU)
1739                             .TypeConstraint<Eigen::half>("T")
1740                             .TypeConstraint<float>("U"),
1741                         FusedBatchNormGradOpV3<CPUDevice, Eigen::half, float>);
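// Note that the CPU kernels above are registered only for float and Eigen::half
// inputs (T), always with float statistics (U); in this file the fused
// _FusedBatchNormEx and _FusedBatchNormGradEx variants are registered only for
// the GPU below.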
1742 
1743 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1744 
1745 REGISTER_KERNEL_BUILDER(
1746     Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
1747     FusedBatchNormOp<GPUDevice, float, float>);
1748 
1749 REGISTER_KERNEL_BUILDER(
1750     Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
1751     FusedBatchNormGradOp<GPUDevice, float, float>);
1752 
1753 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1754                             .Device(DEVICE_GPU)
1755                             .TypeConstraint<float>("T")
1756                             .TypeConstraint<float>("U"),
1757                         FusedBatchNormOp<GPUDevice, float, float>);
1758 
1759 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1760                             .Device(DEVICE_GPU)
1761                             .TypeConstraint<float>("T")
1762                             .TypeConstraint<float>("U"),
1763                         FusedBatchNormGradOp<GPUDevice, float, float>);
1764 
1765 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
1766                             .Device(DEVICE_GPU)
1767                             .TypeConstraint<Eigen::half>("T")
1768                             .TypeConstraint<float>("U"),
1769                         FusedBatchNormOp<GPUDevice, Eigen::half, float>);
1770 
1771 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
1772                             .Device(DEVICE_GPU)
1773                             .TypeConstraint<Eigen::half>("T")
1774                             .TypeConstraint<float>("U"),
1775                         FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);
1776 
1777 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1778                             .Device(DEVICE_GPU)
1779                             .TypeConstraint<float>("T")
1780                             .TypeConstraint<float>("U"),
1781                         FusedBatchNormOpV3<GPUDevice, float, float>);
1782 
1783 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
1784                             .Device(DEVICE_GPU)
1785                             .TypeConstraint<float>("T")
1786                             .TypeConstraint<float>("U"),
1787                         FusedBatchNormOpEx<GPUDevice, float, float>);
1788 
1789 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1790                             .Device(DEVICE_GPU)
1791                             .TypeConstraint<float>("T")
1792                             .TypeConstraint<float>("U"),
1793                         FusedBatchNormGradOpV3<GPUDevice, float, float>);
1794 
1795 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormGradEx")
1796                             .Device(DEVICE_GPU)
1797                             .TypeConstraint<float>("T")
1798                             .TypeConstraint<float>("U"),
1799                         FusedBatchNormGradOpEx<GPUDevice, float, float>);
1800 
1801 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
1802                             .Device(DEVICE_GPU)
1803                             .TypeConstraint<Eigen::half>("T")
1804                             .TypeConstraint<float>("U"),
1805                         FusedBatchNormOpV3<GPUDevice, Eigen::half, float>);
1806 
1807 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
1808                             .Device(DEVICE_GPU)
1809                             .TypeConstraint<Eigen::half>("T")
1810                             .TypeConstraint<float>("U"),
1811                         FusedBatchNormOpEx<GPUDevice, Eigen::half, float>);
1812 
1813 REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
1814                             .Device(DEVICE_GPU)
1815                             .TypeConstraint<Eigen::half>("T")
1816                             .TypeConstraint<float>("U"),
1817                         FusedBatchNormGradOpV3<GPUDevice, Eigen::half, float>);
1818 
1819 REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormGradEx")
1820                             .Device(DEVICE_GPU)
1821                             .TypeConstraint<Eigen::half>("T")
1822                             .TypeConstraint<float>("U"),
1823                         FusedBatchNormGradOpEx<GPUDevice, Eigen::half, float>);
1824 
1825 #endif
1826 
1827 }  // namespace tensorflow
1828