/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_fp16.h"
typedef __half2 half2;
#endif
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// TODO(ezhulenev): Use CUB reductions on GPU.
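// Gradient of fused batch normalization with frozen (inference-mode)
// statistics: pop_mean and pop_variance are constants, so no gradient flows
// through them and the backprop reduces to a per-channel rescaling of
// y_backprop plus two per-channel reductions.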
template <typename T, typename U>
struct FusedBatchNormFreezeGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output) {
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    const int depth = pop_mean.dimension(0);
    const int rest_size = input.size() / depth;

    // Allocate two temporary workspaces of [depth] shape.
    Tensor scratch1_vec, scratch2_vec;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch1_vec));
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch2_vec));

    typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
    typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());

    const GPUDevice& d = context->eigen_device<GPUDevice>();

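    // View the 4D input as a [rest_size, depth] matrix so that per-channel
    // reductions sum over axis 0 and per-channel vectors broadcast along it.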
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);

    OP_REQUIRES(
        context, !OpDeterminismRequired(),
        errors::Unimplemented(
            "A deterministic GPU implementation of fused batch-norm backprop,"
            " when training is disabled, is not currently available."));

    // offset_backprop = sum(y_backprop)
    // scale_backprop = sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var + epsilon)
    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))

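    // With frozen statistics d(x_hat)/dx is just the per-channel constant
    // rsqrt(pop_var + epsilon), so x_backprop needs no reduction terms; the
    // Eigen expressions below evaluate the three formulas above on the GPU.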
    auto y_backprop_rest_by_depth =
        y_backprop.reshape(rest_by_depth).template cast<U>();
    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();

    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);

    // scratch1 = rsqrt(pop_var + epsilon)
    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();

    // scratch2 = sum(y_backprop * (x - mean))
    scratch2.device(d) =
        (y_backprop_rest_by_depth *
         (input_rest_by_depth -
          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
            .sum(reduction_axis);

    x_backprop.reshape(rest_by_depth).device(d) =
        (y_backprop_rest_by_depth *
         ((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
            .template cast<T>();
    scale_backprop.device(d) = scratch2 * scratch1;
  }
};

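// Explicit template instantiations for the supported (T, U) combinations.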
template struct FusedBatchNormFreezeGrad<GPUDevice, float, float>;
template struct FusedBatchNormFreezeGrad<GPUDevice, Eigen::half, float>;

// -------------------------------------------------------------------------- //
// FusedBatchNormInferenceFunctor implementation.                              //
// -------------------------------------------------------------------------- //

// Generic kernel that does all computations by converting the input to the U
// data type. We use it when the CUDA architecture doesn't have fast arithmetic
// for the T data type (e.g. no fp16 support on older GPU generations).
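// For every element this computes
//   out = activation((in - mean[c]) * rsqrt(var[c] + epsilon) * scale[c] +
//                    offset[c] [+ side_input]),
// where c is the element's channel and the activation is identity or ReLU.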
template <typename T, typename U, TensorFormat tensor_format,
          bool add_side_input, FusedBatchNormActivationMode activation_mode,
          bool is_generic_kernel>
struct FusedBatchNormInferenceKernel {
  static_assert(tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW,
                "Unsupported data format");

  __device__ static void run(int32 count, int32 channels_size,
                             int32 inner_dim_size, const T* __restrict__ in,
                             const U* __restrict__ scale,
                             const U* __restrict__ offset,
                             const U* __restrict__ mean,
                             const U* __restrict__ var,
                             const T* __restrict__ side_input, float epsilon,
                             T* __restrict__ out) {
    int32 index = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_device_threads = gridDim.x * blockDim.x;

    while (index < count) {
      const int channel = (tensor_format == FORMAT_NHWC)
                              ? index % channels_size
                              : (index / inner_dim_size) % channels_size;

      U in_v = U(in[index]);
      U scale_v = scale[channel];
      U offset_v = offset[channel];
      U mean_v = mean[channel];
      U var_v = var[channel];

      U scaling_factor_v = rsqrt(var_v + epsilon) * scale_v;
      static_assert(std::is_same<U, float>::value,
                    "U data type must be float");
      U shifted_v = fmaf(in_v - mean_v, scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v += U(side_input[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        out[index] = T(shifted_v);
      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
        out[index] = T(shifted_v < U(0) ? U(0) : shifted_v);
      }

      index += total_device_threads;
    }
  }
};

// Specialization for T=Eigen::half and U=float.
template <TensorFormat tensor_format, bool add_side_input,
          FusedBatchNormActivationMode activation_mode>
struct FusedBatchNormInferenceKernel<Eigen::half, float, tensor_format,
                                     add_side_input, activation_mode,
                                     /*is_generic_kernel=*/false> {
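  // The vectorized path below reinterprets the Eigen::half buffers through IT
  // so it can use the native fp16/half2 device intrinsics: IT is __half on
  // ROCm and Eigen::half itself on CUDA.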
#if TENSORFLOW_USE_ROCM
  using IT = __half;
#else
  using IT = Eigen::half;
#endif
  using T = Eigen::half;
  using U = float;

  // If the CUDA architecture doesn't support fast fp16 computation, we fall
  // back to the generic kernel defined above.
  using GenericKernel =
      FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
                                    activation_mode,
                                    /*is_generic_kernel=*/true>;

  __device__ static void run(int32 count, int32 channels_size,
                             int32 inner_dim_size, const T* __restrict__ _in,
                             const U* __restrict__ scale,
                             const U* __restrict__ offset,
                             const U* __restrict__ mean,
                             const U* __restrict__ var,
                             const T* __restrict__ _side_input, float epsilon,
                             T* __restrict__ _out) {
    // Old GPUs do not have (or have very slow) fp16 arithmetic.
#if (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
    const IT* in = reinterpret_cast<const IT*>(_in);
    const IT* side_input = reinterpret_cast<const IT*>(_side_input);
    IT* out = reinterpret_cast<IT*>(_out);

    int32 index = blockIdx.x * blockDim.x + threadIdx.x;
    const int32 total_device_threads = gridDim.x * blockDim.x;

    int32 half2_count = count >> 1;

    half epsilon_h = __float2half(epsilon);
    half2 epsilon_h2 = __float2half2_rn(epsilon);

    const int32 max_channel_size = channels_size - 1;

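    // Main vectorized loop: each iteration loads two adjacent elements as a
    // half2 and applies the batch-norm transform with paired fp16 intrinsics.
    // For NHWC adjacent elements have adjacent channels (wrapping after the
    // last channel); for NCHW both channels are derived from the linear index.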
    while (index < half2_count) {
      int32 channel[2];
      if (tensor_format == FORMAT_NHWC) {
        channel[0] = (2 * index) % channels_size;
        channel[1] = channel[0] == max_channel_size ? 0 : channel[0] + 1;
      } else {
        channel[0] = ((2 * index) / inner_dim_size) % channels_size;
        channel[1] = ((2 * index + 1) / inner_dim_size) % channels_size;
      }

      half2 in_v = reinterpret_cast<const half2*>(in)[index];
      half2 scale_v = __floats2half2_rn(scale[channel[0]], scale[channel[1]]);
      half2 offset_v =
          __floats2half2_rn(offset[channel[0]], offset[channel[1]]);
      half2 mean_v = __floats2half2_rn(mean[channel[0]], mean[channel[1]]);
      half2 var_v = __floats2half2_rn(var[channel[0]], var[channel[1]]);

      half2 scaling_factor_v =
          __hmul2(h2rsqrt(__hadd2(var_v, epsilon_h2)), scale_v);
      half2 shifted_v =
          __hfma2(__hsub2(in_v, mean_v), scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v = __hadd2(shifted_v,
                            reinterpret_cast<const half2*>(side_input)[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        reinterpret_cast<half2*>(out)[index] = shifted_v;

      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
        const half2 kZeroH = __float2half2_rn(0.f);
        const half2 mask_h = __hgt2(shifted_v, kZeroH);
        reinterpret_cast<half2*>(out)[index] = __hmul2(mask_h, shifted_v);
      }

      index += total_device_threads;
    }

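    // If the total element count is odd, the last element does not fit into a
    // half2; the thread that processed the final pair handles it with scalar
    // fp16 intrinsics.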
    if ((count & 0x1) == 1 && index == half2_count) {
      index = count - 1;

      const int32 channel = (tensor_format == FORMAT_NHWC)
                                ? index % channels_size
                                : (index / inner_dim_size) % channels_size;

      half in_v = in[index];
      half scale_v = __float2half(scale[channel]);
      half offset_v = __float2half(offset[channel]);
      half mean_v = __float2half(mean[channel]);
      half var_v = __float2half(var[channel]);

      half scaling_factor_v = __hmul(hrsqrt(__hadd(var_v, epsilon_h)), scale_v);
      half shifted_v = __hfma(__hsub(in_v, mean_v), scaling_factor_v, offset_v);

      if (add_side_input) {
        shifted_v = __hadd(shifted_v, side_input[index]);
      }

      if (activation_mode == FusedBatchNormActivationMode::kIdentity) {
        out[index] = shifted_v;

      } else if (activation_mode == FusedBatchNormActivationMode::kRelu) {
        const half kZeroH = __float2half(0.f);
        const half mask_h = __hgt(shifted_v, kZeroH);
        out[index] = __hmul(mask_h, shifted_v);
      }
    }

#else
    GenericKernel::run(count, channels_size, inner_dim_size, _in, scale, offset,
                       mean, var, _side_input, epsilon, _out);
#endif  // (__CUDA_ARCH__ >= 610) || TENSORFLOW_USE_ROCM
  }
};

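// Thin __global__ wrapper that forwards to the device-side kernel above. The
// template parameters fix the data format, side input, and activation at
// compile time, so every launched kernel is fully specialized.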
template <typename T, typename U, TensorFormat tensor_format,
          bool add_side_input, FusedBatchNormActivationMode activation_mode>
__global__ void FusedBatchNormInferenceMetaKernel(
    int32 count, int32 channels_size, int32 inner_dim_size, const T* in,
    const U* scale, const U* offset, const U* mean, const U* var,
    const T* side_input, float epsilon, T* out) {
  // We prefer to run the non-generic specialization for the given T and U.
  FusedBatchNormInferenceKernel<T, U, tensor_format, add_side_input,
                                activation_mode,
#if TENSORFLOW_USE_ROCM
                                false
#else
                                // TODO(b/135435976): Temporary disable
                                // non-generic kernel implementation.
                                /*is_generic_kernel=*/true
#endif
                                >::run(count, channels_size, inner_dim_size, in,
                                       scale, offset, mean, var, side_input,
                                       epsilon, out);
}

template <typename T, typename U>
struct FusedBatchNormInferenceFunctor<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, TensorFormat tensor_format,
                  typename TTypes<T, 4>::ConstTensor in,
                  typename TTypes<U>::ConstVec scale,
                  typename TTypes<U>::ConstVec offset,
                  typename TTypes<U>::ConstVec estimated_mean,
                  typename TTypes<U>::ConstVec estimated_variance,
                  typename TTypes<T, 4>::ConstTensor side_input, U epsilon,
                  FusedBatchNormActivationMode activation_mode,
                  typename TTypes<T, 4>::Tensor out) {
    const auto& d = context->eigen_device<GPUDevice>();

    const int32 count = out.size();
    if (count == 0) return;

    bool launched = false;
#if TENSORFLOW_USE_ROCM
    constexpr int32 kThreadInBlock = 1024;
#else
    constexpr int32 kThreadInBlock = 512;
#endif

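// LAUNCH computes a fixed-block-size launch configuration and launches the
// fully specialized meta kernel. For Eigen::half the work-item count is halved
// because the specialized kernel processes two elements (one half2) per
// thread.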
#define LAUNCH(DATA_FORMAT, ADD_SIDE_INPUT, ACTIVATION, CHANNEL_SIZE,          \
               INNER_DIM_SIZE)                                                 \
  launched = true;                                                             \
                                                                               \
  GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(                   \
      std::is_same<T, Eigen::half>::value ? Eigen::divup(count, 2) : count, d, \
      FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT,     \
                                        ACTIVATION>,                           \
      0, kThreadInBlock);                                                      \
                                                                               \
  TF_CHECK_OK(GpuLaunchKernel(                                                 \
      FusedBatchNormInferenceMetaKernel<T, U, DATA_FORMAT, ADD_SIDE_INPUT,     \
                                        ACTIVATION>,                           \
      config.block_count, config.thread_per_block, 0, d.stream(), count,       \
      CHANNEL_SIZE, INNER_DIM_SIZE, in.data(), scale.data(), offset.data(),    \
      estimated_mean.data(), estimated_variance.data(), side_input.data(),     \
      epsilon, out.data()));

    const bool no_side_input = side_input.dimensions().TotalSize() == 0;
    const bool add_side_input = side_input.dimensions().TotalSize() != 0;

    using Activation = FusedBatchNormActivationMode;
    const bool no_activation = activation_mode == Activation::kIdentity;
    const bool relu_activation = activation_mode == Activation::kRelu;

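    // Statically dispatch on (data format, side input, activation) so each
    // branch launches a fully specialized kernel. For NHWC the channel is the
    // fastest-varying index; for NCHW every channel spans an H*W inner block,
    // which is passed as INNER_DIM_SIZE.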
    if (tensor_format == FORMAT_NHWC) {
      const int c = in.dimensions()[3];

      if (no_activation && no_side_input) {
        LAUNCH(FORMAT_NHWC, false, Activation::kIdentity, c, 1);
      } else if (relu_activation && no_side_input) {
        LAUNCH(FORMAT_NHWC, false, Activation::kRelu, c, 1);
      } else if (no_activation && add_side_input) {
        LAUNCH(FORMAT_NHWC, true, Activation::kIdentity, c, 1);
      } else if (relu_activation && add_side_input) {
        LAUNCH(FORMAT_NHWC, true, Activation::kRelu, c, 1);
      }

    } else if (tensor_format == FORMAT_NCHW) {
      const int c = in.dimensions()[1];
      const int inner = in.dimensions()[2] * in.dimensions()[3];

      if (no_activation && no_side_input) {
        LAUNCH(FORMAT_NCHW, false, Activation::kIdentity, c, inner);
      } else if (relu_activation && no_side_input) {
        LAUNCH(FORMAT_NCHW, false, Activation::kRelu, c, inner);
      } else if (no_activation && add_side_input) {
        LAUNCH(FORMAT_NCHW, true, Activation::kIdentity, c, inner);
      } else if (relu_activation && add_side_input) {
        LAUNCH(FORMAT_NCHW, true, Activation::kRelu, c, inner);
      }
    }
#undef LAUNCH

    OP_REQUIRES(context, launched,
                errors::InvalidArgument("Unsupported launch configuration"));
  }
};

template struct FusedBatchNormInferenceFunctor<GPUDevice, float, float>;
template struct FusedBatchNormInferenceFunctor<GPUDevice, Eigen::half, float>;

}  // namespace functor
}  // namespace tensorflow

#else

#include "tensorflow/core/kernels/fused_batch_norm_op.h"

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM