/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA implementation of BatchNorm operations.
#include <algorithm>
#include <numeric>
#include <string>

#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

class FusedBatchNormOp : public XlaOpKernel {
 public:
  explicit FusedBatchNormOp(OpKernelConstruction* ctx)
      : FusedBatchNormOp(ctx, false) {}

  FusedBatchNormOp(OpKernelConstruction* ctx, bool is_batch_norm_ex)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("exponential_avg_factor", &exponential_avg_factor_));
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(
        ctx, FormatFromString(data_format_str, &data_format_),
        errors::InvalidArgument("Invalid data format: ", data_format_str));

    if (is_batch_norm_ex) {
      int num_side_inputs;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("num_side_inputs", &num_side_inputs));
      OP_REQUIRES(ctx, num_side_inputs >= 0 && num_side_inputs <= 1,
                  errors::InvalidArgument(
                      "FusedBatchNormEx supports at most 1 side input."));
      add_side_input_ = (num_side_inputs == 1);
      string activation_mode;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
      OP_REQUIRES(ctx,
                  activation_mode == "Identity" || activation_mode == "Relu",
                  errors::InvalidArgument(
                      "Unsupported FusedBatchNormEx activation mode: ",
                      activation_mode));
      apply_relu_ = (activation_mode == "Relu");
    } else {
      add_side_input_ = false;
      apply_relu_ = false;
    }
    is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
  }

  void Compile(XlaOpKernelContext* ctx) override { CompileImpl(ctx); }

 protected:
  virtual void CompileImpl(XlaOpKernelContext* ctx) {
    xla::XlaBuilder* const b = ctx->builder();
    xla::PrimitiveType input_type;
    OP_REQUIRES_OK(ctx,
                   DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
    xla::PrimitiveType scale_type;
    OP_REQUIRES_OK(ctx,
                   DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));
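
    // The FusedBatchNorm* ops produce five outputs: y, batch_mean,
    // batch_variance, and two "reserve space" tensors consumed by the
    // gradient op. The code below maps the results of xla::BatchNormTraining
    // and xla::BatchNormInference onto that layout.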
    xla::XlaOp input = ctx->Input(0);
    TensorShape input_shape = ctx->InputShape(0);

    int feature_index =
        GetTensorFeatureDimIndex(input_shape.dims(), data_format_);

    // TODO(b/69928690): support mixed precision in the XLA batch
    // normalization operators. As a workaround, cast everything to the
    // statistics type (which may be more precise than the input type).
    input = xla::ConvertElementType(input, scale_type);

    if (is_training_) {
      xla::XlaOp output = xla::BatchNormTraining(
          input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);

      // In training mode, output the normalized value as well as the
      // calculated mean and variance. Optionally add the side input and apply
      // the relu activation.
      xla::XlaOp converted =
          xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type);
      if (add_side_input_ && apply_relu_) {
        ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
      } else if (apply_relu_) {
        ctx->SetOutput(0, xla::Relu(converted));
      } else {
        ctx->SetOutput(0, converted);
      }

      xla::XlaOp variance = xla::GetTupleElement(output, 2);
      // Apply Bessel's correction.
      int total_input_size = ctx->InputShape(0).num_elements();
      int total_scale_size = ctx->InputShape(1).num_elements();
      int sample_size =
          total_scale_size > 0 ? total_input_size / total_scale_size : 0;
      int sample_size_minus_one = std::max(1, sample_size - 1);
      double factor = static_cast<double>(sample_size) /
                      static_cast<double>(sample_size_minus_one);

      constexpr int kVarianceOutputIndex = 2;
      xla::XlaOp corrected =
          xla::Mul(variance, xla::ScalarLike(variance, factor));
      if (input_shape.num_elements() == 0) {
        auto status_or_output_shape = b->GetShape(corrected);
        OP_REQUIRES_OK(ctx, status_or_output_shape.status());
        ctx->SetOutput(1, xla::GetTupleElement(output, 1));
        ctx->SetOutput(
            kVarianceOutputIndex,
            xla::Broadcast(
                xla::NanValue(b, ctx->output_xla_type(kVarianceOutputIndex)),
                status_or_output_shape.ValueOrDie().dimensions()));
      } else {
        if (exponential_avg_factor_ == 1.0f) {
          ctx->SetOutput(1, xla::GetTupleElement(output, 1));
          ctx->SetOutput(2, corrected);
        } else {
          xla::XlaOp old_mean = ctx->Input(3);
          xla::XlaOp alpha =
              xla::ScalarLike(old_mean, 1.0f - exponential_avg_factor_);
          xla::XlaOp beta = xla::ScalarLike(old_mean, exponential_avg_factor_);
          // new_running_mean = alpha * old_mean + beta * batch_mean.
          xla::XlaOp new_running_mean =
              xla::Add(xla::Mul(old_mean, alpha),
                       xla::Mul(xla::GetTupleElement(output, 1), beta));
          ctx->SetOutput(1, new_running_mean);

          // new_running_variance = alpha * old_variance + beta *
          // batch_variance.
          xla::XlaOp old_variance = ctx->Input(4);
          xla::XlaOp new_running_variance = xla::Add(
              xla::Mul(old_variance, alpha), xla::Mul(corrected, beta));
          ctx->SetOutput(2, new_running_variance);
        }
      }

      // Outputs 3 and 4 of "FusedBatchNorm" are currently marked as "reserve
      // space 1 & 2". They are used to pass the per-batch mean and variance to
      // the gradient. Here we maintain the same behavior by setting them to
      // the mean and variance calculated by BatchNormTraining.
      ctx->SetOutput(3, xla::GetTupleElement(output, 1));
      if (is_on_gpu_) {
        // The last two outputs from the FusedBatchNorm training TensorFlow GPU
        // op are implementation defined.
        // For now we rely on the in-practice behavior of the op:
        //   output 3 is the mean
        //   output 4 is rsqrt(variance + epsilon)
        ctx->SetOutput(4, xla::Rsqrt(xla::Add(
                              variance, xla::ScalarLike(variance, epsilon_))));
      } else {
        ctx->SetOutput(4, variance);
      }
    } else {
      xla::XlaOp output = xla::BatchNormInference(
          input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
          epsilon_, feature_index);

      xla::XlaOp converted = xla::ConvertElementType(output, input_type);
      if (add_side_input_ && apply_relu_) {
        ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
      } else if (apply_relu_) {
        ctx->SetOutput(0, xla::Relu(converted));
      } else {
        ctx->SetOutput(0, converted);
      }

      // Directly send input to output as mean and variance in inference mode.
      ctx->SetOutput(1, ctx->Input(3));
      ctx->SetOutput(2, ctx->Input(4));
      ctx->SetOutput(3, ctx->Input(3));
      ctx->SetOutput(4, ctx->Input(4));
    }
  }

 private:
  float epsilon_;
  TensorFormat data_format_;
  bool is_training_;
  float exponential_avg_factor_;
  bool add_side_input_;
  bool apply_relu_;
  bool is_on_gpu_;
};

class FusedBatchNormOpV3 : public FusedBatchNormOp {
 public:
  explicit FusedBatchNormOpV3(OpKernelConstruction* ctx)
      : FusedBatchNormOp(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    FusedBatchNormOp::CompileImpl(ctx);
    if (!ctx->status().ok()) {
      return;
    }
    ctx->SetConstantOutput(5, Tensor());
  }
};

class FusedBatchNormOpEx : public FusedBatchNormOp {
 public:
  explicit FusedBatchNormOpEx(OpKernelConstruction* ctx)
      : FusedBatchNormOp(ctx, /*is_batch_norm_ex=*/true) {}

  void Compile(XlaOpKernelContext* ctx) override {
    FusedBatchNormOp::CompileImpl(ctx);
    if (!ctx->status().ok()) {
      return;
    }
    ctx->SetConstantOutput(5, Tensor());
  }
};

REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
REGISTER_XLA_OP(Name("_FusedBatchNormEx"), FusedBatchNormOpEx);

class FusedBatchNormGradOp : public XlaOpKernel {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(
        ctx, FormatFromString(data_format_str, &data_format_),
        errors::InvalidArgument("Invalid data format: ", data_format_str));
    is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* const b = ctx->builder();
    DataType input_dtype = ctx->input_type(0);
    DataType scale_dtype = ctx->input_type(2);

    // TODO(b/69928690): support mixed precision in the XLA batch
    // normalization operators. For now, cast everything to the statistics
    // type (which may be more precise than the input type).
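    //
    // FusedBatchNormGrad* inputs: 0 is y_backprop, 1 is x, 2 is scale, and
    // 3/4 are the mean and variance saved by the forward pass (the population
    // statistics when is_training is false).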
    auto grad_backprop =
        XlaHelpers::ConvertElementType(ctx->Input(0), scale_dtype);
    auto activations =
        XlaHelpers::ConvertElementType(ctx->Input(1), scale_dtype);
    auto scale = ctx->Input(2);
    auto mean = ctx->Input(3);
    auto var = ctx->Input(4);

    const int input_dims = ctx->InputShape(0).dims();
    const int feature_index =
        GetTensorFeatureDimIndex(input_dims, data_format_);

    xla::XlaOp x_backprop;
    xla::XlaOp scale_backprop;
    xla::XlaOp offset_backprop;
    if (is_training_) {
      if (is_on_gpu_) {
        // The last two inputs to the FusedBatchNormGrad training TensorFlow
        // GPU op are implementation defined. For now we rely on the
        // in-practice behavior of the op:
        //   input 3 is the mean
        //   input 4 is rsqrt(variance + epsilon)
        //
        // The XLA op expects:
        //   input 3 is the mean
        //   input 4 is the variance
        //
        // so we adjust input 4 here.
        xla::XlaOp one = xla::ScalarLike(var, 1.0f);
        xla::XlaOp epsilon = xla::ScalarLike(var, epsilon_);
        var = xla::Sub(one / (var * var), epsilon);
      }

      xla::XlaOp output =
          xla::BatchNormGrad(activations, scale, mean, var, grad_backprop,
                             epsilon_, feature_index);

      x_backprop = xla::GetTupleElement(output, 0);
      scale_backprop = xla::GetTupleElement(output, 1);
      offset_backprop = xla::GetTupleElement(output, 2);
    } else {
      // Reduce over all dimensions except the feature dim.
      std::vector<int64_t> reduction_dims(input_dims - 1);
      std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index,
                0);
      std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(),
                feature_index + 1);
      // offset_backprop = sum(y_backprop)
      // scale_backprop = sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var +
      //                  epsilon)
      // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
      const DataType accumulation_type =
          XlaHelpers::SumAccumulationType(scale_dtype);
      auto converted =
          XlaHelpers::ConvertElementType(grad_backprop, accumulation_type);
      auto reduce =
          xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                      *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
      offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype);

      // scratch1 = rsqrt(pop_var + epsilon)
      auto epsilon = XlaHelpers::FloatLiteral(b, scale_dtype, epsilon_);
      auto scratch1 = xla::Rsqrt(xla::Add(var, epsilon));

      // scratch2 = sum(y_backprop * (x - mean))
      auto mul = xla::Mul(grad_backprop,
                          xla::Sub(activations, mean, {feature_index}));
      converted = XlaHelpers::ConvertElementType(mul, accumulation_type);
      reduce =
          xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                      *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
      auto scratch2 = XlaHelpers::ConvertElementType(reduce, scale_dtype);

      x_backprop =
          xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index});
      scale_backprop = xla::Mul(scratch1, scratch2);
    }

    ctx->SetOutput(0, XlaHelpers::ConvertElementType(x_backprop, input_dtype));
    ctx->SetOutput(1, scale_backprop);
    ctx->SetOutput(2, offset_backprop);
    ctx->SetConstantOutput(3, Tensor());
    ctx->SetConstantOutput(4, Tensor());
  }

 private:
  TensorFormat data_format_;
  float epsilon_;
  bool is_training_;
  bool is_on_gpu_;
};

REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV3"), FusedBatchNormGradOp);

}  // namespace
}  // namespace tensorflow