xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/fused_batch_norm_op.cu.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_fp16.h"
typedef __half2 half2;
#endif
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// TODO(ezhulenev): Use CUB reductions on GPU.
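// Gradient of fused batch norm when training is disabled, i.e. when the
// frozen population mean and variance are used instead of batch statistics.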
template <typename T, typename U>
struct FusedBatchNormFreezeGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output) {
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    const int depth = pop_mean.dimension(0);
    const int rest_size = input.size() / depth;

    // Allocate two temporary workspaces of [depth] shape.
    Tensor scratch1_vec, scratch2_vec;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch1_vec));
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch2_vec));

    typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
    typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());

    const GPUDevice& d = context->eigen_device<GPUDevice>();

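    // View the 4D input as a [rest_size, depth] matrix: reducing over axis 0
    // collapses all batch/spatial positions into one value per channel, while
    // one_by_depth / rest_by_one broadcast per-channel vectors back to that
    // shape.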
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);

    OP_REQUIRES(
        context, !OpDeterminismRequired(),
        errors::Unimplemented(
            "A deterministic GPU implementation of fused batch-norm backprop,"
            " when training is disabled, is not currently available."));

    // offset_backprop = sum(y_backprop)
    // scale_backprop = sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var + epsilon)
    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))

    auto y_backprop_rest_by_depth =
        y_backprop.reshape(rest_by_depth).template cast<U>();
    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();

    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);

    // scratch1 = rsqrt(pop_var + epsilon)
    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();

    // scratch2 = sum(y_backprop * (x - mean))
    scratch2.device(d) =
        (y_backprop_rest_by_depth *
         (input_rest_by_depth -
          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
            .sum(reduction_axis);

    x_backprop.reshape(rest_by_depth).device(d) =
        (y_backprop_rest_by_depth *
         ((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
            .template cast<T>();
    scale_backprop.device(d) = scratch2 * scratch1;
  }
};

template struct FusedBatchNormFreezeGrad<GPUDevice, float, float>;
template struct FusedBatchNormFreezeGrad<GPUDevice, Eigen::half, float>;

// -------------------------------------------------------------------------- //
// FusedBatchNormInferenceFunctor implementation.                             //
// -------------------------------------------------------------------------- //

// Generic kernel that does all computations by converting the input to the U
// data type. We use it when the CUDA architecture doesn't have fast arithmetic
// for the T data type (e.g. no fp16 support in older GPU generations).
template <typename T, typename U, TensorFormat tensor_format,
          bool add_side_input, FusedBatchNormActivationMode activation_mode,
          bool is_generic_kernel>
struct FusedBatchNormInferenceKernel {
  static_assert(tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW,
                "Unsupported data format");

  __device__ static void run(int32 count, int32 channels_size,
                             int32 inner_dim_size, const T* __restrict__ in,
                             const U* __restrict__ scale,
                             const U* __restrict__ offset,
                             const U* __restrict__ mean,
                             const U* __restrict__ var,
                             const T* __restrict__ side_input, float epsilon,
                             T* __restrict__ out) {
    int32 index = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_device_threads = gridDim.x * blockDim.x;

    while (index < count) {
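      // NHWC stores channels innermost, so the channel is index % C; in NCHW
      // each run of inner_dim_size (H * W) elements shares one channel.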
      const int channel = (tensor_format == FORMAT_NHWC)
                              ? index % channels_size
                              : (index / inner_dim_size) % channels_size;

      U in_v = U(in[index]);
      U scale_v = scale[channel];
      U offset_v = offset[channel];
      U mean_v = mean[channel];
      U var_v = var[channel];

      U scaling_factor_v = rsqrt(var_v + epsilon) * scale_v;
      static_assert(std::is_same<U, float>::value, "U data type must be float");
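      // Fused multiply-add: shifted = (in - mean) * scaling_factor + offset.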
      U shifted_v = fmaf(in_v - mean_v, scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v += U(side_input[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        out[index] = T(shifted_v);
      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
        out[index] = T(shifted_v < U(0) ? U(0) : shifted_v);
      }

      index += total_device_threads;
    }
  }
};

// Specialization for T=Eigen::half and U=float.
template <TensorFormat tensor_format, bool add_side_input,
          FusedBatchNormActivationMode activation_mode>
struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
                                     add_side_input, activation_mode,
                                     /*is_generic_kernel=*/false> {
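  // On ROCm the hip fp16 intrinsics operate on __half, so the Eigen::half
  // buffers are reinterpreted to that type before the vectorized math below.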
#if TENSORFLOW_USE_ROCM
  using IT = __half;
#else
  using IT = Eigen::half;
#endif
  using T = Eigen::half;
  using U = float;

  // If the CUDA architecture doesn't support fast fp16 computation, we fall
  // back on the generic kernel defined above.
  using GenericKernel =
      FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
                                    activation_mode,
                                    /*is_generic_kernel=*/true>;

  __device__ static void run(int32 count, int32 channels_size,
                             int32 inner_dim_size, const T* __restrict__ _in,
                             const U* __restrict__ scale,
                             const U* __restrict__ offset,
                             const U* __restrict__ mean,
                             const U* __restrict__ var,
                             const T* __restrict__ _side_input, float epsilon,
                             T* __restrict__ _out) {
    // Old GPUs do not have (or have very slow) fp16 arithmetic.
#if (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
    const IT* in = reinterpret_cast<const IT*>(_in);
    const IT* side_input = reinterpret_cast<const IT*>(_side_input);
    IT* out = reinterpret_cast<IT*>(_out);

    int32 index = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_device_threads = gridDim.x * blockDim.x;

    int32 half2_count = count >> 1;

    half epsilon_h = __float2half(epsilon);
    half2 epsilon_h2 = __float2half2_rn(epsilon);

    const int32 max_channel_size = channels_size - 1;

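    // Vectorized main loop: each thread loads a half2 (two consecutive fp16
    // values) per iteration. An odd trailing element, if any, is handled after
    // the loop in scalar half precision.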
    while (index < half2_count) {
      int32 channel[2];
      if (tensor_format == FORMAT_NHWC) {
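        // In NHWC consecutive elements belong to consecutive channels,
        // wrapping back to channel 0 after the last one.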
        channel[0] = (2 * index) % channels_size;
        channel[1] = channel[0] == max_channel_size ? 0 : channel[0] + 1;
      } else {
        channel[0] = ((2 * index) / inner_dim_size) % channels_size;
        channel[1] = ((2 * index + 1) / inner_dim_size) % channels_size;
      }

      half2 in_v = reinterpret_cast<const half2*>(in)[index];
      half2 scale_v = __floats2half2_rn(scale[channel[0]], scale[channel[1]]);
      half2 offset_v =
          __floats2half2_rn(offset[channel[0]], offset[channel[1]]);
      half2 mean_v = __floats2half2_rn(mean[channel[0]], mean[channel[1]]);
      half2 var_v = __floats2half2_rn(var[channel[0]], var[channel[1]]);

      half2 scaling_factor_v =
          __hmul2(h2rsqrt(__hadd2(var_v, epsilon_h2)), scale_v);
      half2 shifted_v =
          __hfma2(__hsub2(in_v, mean_v), scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v = __hadd2(shifted_v,
                            reinterpret_cast<const half2*>(side_input)[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        reinterpret_cast<half2*>(out)[index] = shifted_v;

      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
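        // __hgt2 yields 1.0 in each lane where shifted_v > 0 and 0.0 otherwise,
        // so multiplying by the mask applies ReLU to both halves at once.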
        const half2 kZeroH = __float2half2_rn(0.f);
        const half2 mask_h = __hgt2(shifted_v, kZeroH);
        reinterpret_cast<half2*>(out)[index] = __hmul2(mask_h, shifted_v);
      }

      index += total_device_threads;
    }

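    // If count is odd, the last element was not covered by the half2 loop; the
    // single thread whose index landed on half2_count handles it with scalar
    // half arithmetic.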
    if ((count & 0x1) == 1 && index == half2_count) {
      index = count - 1;

      const int32 channel = (tensor_format == FORMAT_NHWC)
                                ? index % channels_size
                                : (index / inner_dim_size) % channels_size;

      half in_v = in[index];
      half scale_v = __float2half(scale[channel]);
      half offset_v = __float2half(offset[channel]);
      half mean_v = __float2half(mean[channel]);
      half var_v = __float2half(var[channel]);

      half scaling_factor_v = __hmul(hrsqrt(__hadd(var_v, epsilon_h)), scale_v);
      half shifted_v = __hfma(__hsub(in_v, mean_v), scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v = __hadd(shifted_v, side_input[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        out[index] = shifted_v;

      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
        const half kZeroH = __float2half(0.f);
        const half mask_h = __hgt(shifted_v, kZeroH);
        out[index] = __hmul(mask_h, shifted_v);
      }
    }

#else
    GenericKernel::run(count, channels_size, inner_dim_size, _in, scale, offset,
                       mean, var, _side_input, epsilon, _out);
#endif  // (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
  }
};

template <typename T, typename U, TensorFormat tensor_format,
          bool add_side_input, FusedBatchNormActivationMode activation_mode>
__global__ void FusedBatchNormInferenceMetaKernel(
    int32 count, int32 channels_size, int32 inner_dim_size, const T* in,
    const U* scale, const U* offset, const U* mean, const U* var,
    const T* side_input, float epsilon, T* out) {
  // We prefer to run the non-generic specialization for the given types T and U.
  FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
                                activation_mode,
#if TENSORFLOW_USE_ROCM
                                /*is_generic_kernel=*/false
#else
                                // TODO(b/135435976): Temporarily disable the
                                // non-generic kernel implementation.
                                /*is_generic_kernel=*/true
#endif
                                >::run(count, channels_size, inner_dim_size, in,
                                       scale, offset, mean, var, side_input,
                                       epsilon, out);
}

template <typename T, typename U>
struct FusedBatchNormInferenceFunctor<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, TensorFormat tensor_format,
                  typename TTypes<T, 4>::ConstTensor in,
                  typename TTypes<U>::ConstVec scale,
                  typename TTypes<U>::ConstVec offset,
                  typename TTypes<U>::ConstVec estimated_mean,
                  typename TTypes<U>::ConstVec estimated_variance,
                  typename TTypes<T, 4>::ConstTensor side_input, U epsilon,
                  FusedBatchNormActivationMode activation_mode,
                  typename TTypes<T, 4>::Tensor out) {
    const auto& d = context->eigen_device<GPUDevice>();

    const int32 count = out.size();
    if (count == 0) return;

    bool launched = false;
#if TENSORFLOW_USE_ROCM
    constexpr int32 kThreadInBlock = 1024;
#else
    constexpr int32 kThreadInBlock = 512;
#endif

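    // LAUNCH builds a fixed-block-size launch configuration for the meta
    // kernel and launches it on the device stream. For Eigen::half the element
    // count is halved because the kernel processes two values per thread via
    // half2.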
#define LAUNCH(DATA_FORMAT, ADD_SIDE_INPUT, ACTIVATION, CHANNEL_SIZE,          \
               INNER_DIM_SIZE)                                                 \
  launched = true;                                                             \
                                                                               \
  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(                   \
      std::is_same<T, Eigen::half>::value ? Eigen::divup(count, 2) : count, d, \
      FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT,     \
                                        ACTIVATION>,                           \
      0, kThreadInBlock);                                                      \
                                                                               \
  TF_CHECK_OK(GpuLaunchKernel(                                                 \
      FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT,     \
                                        ACTIVATION>,                           \
      config.block_count, config.thread_per_block, 0, d.stream(), count,       \
      CHANNEL_SIZE, INNER_DIM_SIZE, in.data(), scale.data(), offset.data(),    \
      estimated_mean.data(), estimated_variance.data(), side_input.data(),     \
      epsilon, out.data()));

    const bool no_side_input = side_input.dimensions().TotalSize() == 0;
    const bool add_side_input = side_input.dimensions().TotalSize() != 0;

    using Activation = FusedBatchNormActivationMode;
    const bool no_activation = activation_mode == Activation::kIdentity;
    const bool relu_activation = activation_mode == Activation::kRelu;

    if (tensor_format == FORMAT_NHWC) {
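      // For NHWC the channel dimension is innermost, so INNER_DIM_SIZE is 1.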
      const int c = in.dimensions()[3];

      if (no_activation && no_side_input) {
        LAUNCH(FORMAT_NHWC, false, Activation::kIdentity, c, 1);
      } else if (relu_activation && no_side_input) {
        LAUNCH(FORMAT_NHWC, false, Activation::kRelu, c, 1);
      } else if (no_activation && add_side_input) {
        LAUNCH(FORMAT_NHWC, true, Activation::kIdentity, c, 1);
      } else if (relu_activation && add_side_input) {
        LAUNCH(FORMAT_NHWC, true, Activation::kRelu, c, 1);
      }

    } else if (tensor_format == FORMAT_NCHW) {
      const int c = in.dimensions()[1];
      const int inner = in.dimensions()[2] * in.dimensions()[3];

      if (no_activation && no_side_input) {
        LAUNCH(FORMAT_NCHW, false, Activation::kIdentity, c, inner);
      } else if (relu_activation && no_side_input) {
        LAUNCH(FORMAT_NCHW, false, Activation::kRelu, c, inner);
      } else if (no_activation && add_side_input) {
        LAUNCH(FORMAT_NCHW, true, Activation::kIdentity, c, inner);
      } else if (relu_activation && add_side_input) {
        LAUNCH(FORMAT_NCHW, true, Activation::kRelu, c, inner);
      }
    }
#undef LAUNCH

    OP_REQUIRES(context, launched,
                errors::InvalidArgument("Unsupported launch configuration"));
  }
};

template struct FusedBatchNormInferenceFunctor<GPUDevice, float, float>;
template struct FusedBatchNormInferenceFunctor<GPUDevice, Eigen::half, float>;

}  // namespace functor
}  // namespace tensorflow

#else

#include "tensorflow/core/kernels/fused_batch_norm_op.h"

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM