/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA implementation of BatchNorm operations.
#include <algorithm>
#include <numeric>
#include <string>

#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

class FusedBatchNormOp : public XlaOpKernel {
 public:
  explicit FusedBatchNormOp(OpKernelConstruction* ctx)
      : FusedBatchNormOp(ctx, false) {}

  FusedBatchNormOp(OpKernelConstruction* ctx, bool is_batch_norm_ex)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
    OP_REQUIRES_OK(
        ctx, ctx->GetAttr("exponential_avg_factor", &exponential_avg_factor_));
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(
        ctx, FormatFromString(data_format_str, &data_format_),
        errors::InvalidArgument("Invalid data format: ", data_format_str));

    if (is_batch_norm_ex) {
      int num_side_inputs;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("num_side_inputs", &num_side_inputs));
      OP_REQUIRES(ctx, num_side_inputs >= 0 && num_side_inputs <= 1,
                  errors::InvalidArgument(
                      "FusedBatchNormEx supports at most 1 side input."));
      add_side_input_ = (num_side_inputs == 1);
      string activation_mode;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
      OP_REQUIRES(ctx,
                  activation_mode == "Identity" || activation_mode == "Relu",
                  errors::InvalidArgument(
                      "Unsupported FusedBatchNormEx activation mode: ",
                      activation_mode));
      apply_relu_ = (activation_mode == "Relu");
    } else {
      add_side_input_ = false;
      apply_relu_ = false;
    }
    is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
  }
  void Compile(XlaOpKernelContext* ctx) override { CompileImpl(ctx); }

 protected:
  virtual void CompileImpl(XlaOpKernelContext* ctx) {
    xla::XlaBuilder* const b = ctx->builder();
    xla::PrimitiveType input_type;
    OP_REQUIRES_OK(ctx,
                   DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
    xla::PrimitiveType scale_type;
    OP_REQUIRES_OK(ctx,
                   DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));

    xla::XlaOp input = ctx->Input(0);
    TensorShape input_shape = ctx->InputShape(0);

    int feature_index =
        GetTensorFeatureDimIndex(input_shape.dims(), data_format_);

    // TODO(b/69928690): support mixed precision in the XLA batch normalization
    // operators. As a workaround, cast everything to the statistics type (which
    // may be more precise than the input type).
    input = xla::ConvertElementType(input, scale_type);

    if (is_training_) {
      xla::XlaOp output = xla::BatchNormTraining(
          input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);

      // In training mode, outputs the normalized value as well as the
      // calculated mean and variance. Optionally we add side input and apply
      // relu activation.
      xla::XlaOp converted =
          xla::ConvertElementType(xla::GetTupleElement(output, 0), input_type);
      if (add_side_input_ && apply_relu_) {
        ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
      } else if (apply_relu_) {
        ctx->SetOutput(0, xla::Relu(converted));
      } else {
        ctx->SetOutput(0, converted);
      }

      xla::XlaOp variance = xla::GetTupleElement(output, 2);
      // Apply Bessel's correction.
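      // BatchNormTraining returns the biased (population) variance over the
      // batch, while the moving-variance output expects the unbiased estimate,
      // so rescale by sample_size / (sample_size - 1), where sample_size is
      // the number of elements contributing to each feature channel
      // (total input elements / number of channels).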
      int total_input_size = ctx->InputShape(0).num_elements();
      int total_scale_size = ctx->InputShape(1).num_elements();
      int sample_size =
          total_scale_size > 0 ? total_input_size / total_scale_size : 0;
      int sample_size_minus_one = std::max(1, sample_size - 1);
      double factor = static_cast<double>(sample_size) /
                      static_cast<double>(sample_size_minus_one);

      constexpr int kVarianceOutputIndex = 2;
      xla::XlaOp corrected =
          xla::Mul(variance, xla::ScalarLike(variance, factor));
      if (input_shape.num_elements() == 0) {
        auto status_or_output_shape = b->GetShape(corrected);
        OP_REQUIRES_OK(ctx, status_or_output_shape.status());
        ctx->SetOutput(1, xla::GetTupleElement(output, 1));
        ctx->SetOutput(
            kVarianceOutputIndex,
            xla::Broadcast(
                xla::NanValue(b, ctx->output_xla_type(kVarianceOutputIndex)),
                status_or_output_shape.ValueOrDie().dimensions()));

      } else {
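        // With exponential_avg_factor == 1.0 the running statistics are
        // simply the batch statistics; otherwise blend them into the previous
        // running values: new = (1 - factor) * old + factor * batch.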
        if (exponential_avg_factor_ == 1.0f) {
          ctx->SetOutput(1, xla::GetTupleElement(output, 1));
          ctx->SetOutput(2, corrected);
        } else {
          xla::XlaOp old_mean = ctx->Input(3);
          xla::XlaOp alpha =
              xla::ScalarLike(old_mean, 1.0f - exponential_avg_factor_);
          xla::XlaOp beta = xla::ScalarLike(old_mean, exponential_avg_factor_);
          // new_running_mean = alpha * old_mean + beta * batch_mean.
          xla::XlaOp new_running_mean =
              xla::Add(xla::Mul(old_mean, alpha),
                       xla::Mul(xla::GetTupleElement(output, 1), beta));
          ctx->SetOutput(1, new_running_mean);

          xla::XlaOp old_variance = ctx->Input(4);
          // new_running_variance = alpha * old_variance + beta *
          // batch_variance.
          xla::XlaOp new_running_variance = xla::Add(
              xla::Mul(old_variance, alpha), xla::Mul(corrected, beta));
          ctx->SetOutput(2, new_running_variance);
        }
      }

      // Outputs 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
      // space 1 & 2". They are used to pass the per-batch mean and
      // variance to the gradient. Here we maintain the same behavior by setting
      // them to the mean and variance calculated by BatchNormTraining.
      ctx->SetOutput(3, xla::GetTupleElement(output, 1));
      if (is_on_gpu_) {
        // The last two outputs from the FusedBatchNorm training TensorFlow GPU
        // op are implementation defined.  For now we rely on the in-practice
        // behavior of the op:
        //   output 3 is the mean
        //   output 4 is rsqrt(variance + epsilon)
        ctx->SetOutput(4, xla::Rsqrt(xla::Add(
                              variance, xla::ScalarLike(variance, epsilon_))));
      } else {
        ctx->SetOutput(4, variance);
      }
    } else {
      xla::XlaOp output = xla::BatchNormInference(
          input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
          epsilon_, feature_index);

      xla::XlaOp converted = xla::ConvertElementType(output, input_type);
      if (add_side_input_ && apply_relu_) {
        ctx->SetOutput(0, xla::Relu(xla::Add(ctx->Input(5), converted)));
      } else if (apply_relu_) {
        ctx->SetOutput(0, xla::Relu(converted));
      } else {
        ctx->SetOutput(0, converted);
      }

      // Directly send input to output as mean and variance in inference mode.
      ctx->SetOutput(1, ctx->Input(3));
      ctx->SetOutput(2, ctx->Input(4));
      ctx->SetOutput(3, ctx->Input(3));
      ctx->SetOutput(4, ctx->Input(4));
    }
  }

 private:
  float epsilon_;
  TensorFormat data_format_;
  bool is_training_;
  float exponential_avg_factor_;
  bool add_side_input_;
  bool apply_relu_;
  bool is_on_gpu_;
};

class FusedBatchNormOpV3 : public FusedBatchNormOp {
 public:
  explicit FusedBatchNormOpV3(OpKernelConstruction* ctx)
      : FusedBatchNormOp(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    FusedBatchNormOp::CompileImpl(ctx);
    if (!ctx->status().ok()) {
      return;
    }
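    // Output 5 is reserve space used by the TensorFlow op to pass
    // implementation-defined data to the gradient; the XLA lowering does not
    // need it, so emit an empty constant.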
    ctx->SetConstantOutput(5, Tensor());
  }
};

class FusedBatchNormOpEx : public FusedBatchNormOp {
 public:
  explicit FusedBatchNormOpEx(OpKernelConstruction* ctx)
      : FusedBatchNormOp(ctx, /*is_batch_norm_ex=*/true) {}

  void Compile(XlaOpKernelContext* ctx) override {
    FusedBatchNormOp::CompileImpl(ctx);
    if (!ctx->status().ok()) {
      return;
    }
    ctx->SetConstantOutput(5, Tensor());
  }
};

REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
REGISTER_XLA_OP(Name("FusedBatchNormV3"), FusedBatchNormOpV3);
REGISTER_XLA_OP(Name("_FusedBatchNormEx"), FusedBatchNormOpEx);

class FusedBatchNormGradOp : public XlaOpKernel {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
    string data_format_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(
        ctx, FormatFromString(data_format_str, &data_format_),
        errors::InvalidArgument("Invalid data format: ", data_format_str));
    is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* const b = ctx->builder();
    DataType input_dtype = ctx->input_type(0);
    DataType scale_dtype = ctx->input_type(2);

    // TODO(b/69928690): support mixed precision in the XLA batch normalization
    // operators. For now, cast everything to the statistics type (which
    // may be more precise than the input type).
    auto grad_backprop =
        XlaHelpers::ConvertElementType(ctx->Input(0), scale_dtype);
    auto activations =
        XlaHelpers::ConvertElementType(ctx->Input(1), scale_dtype);
    auto scale = ctx->Input(2);
    auto mean = ctx->Input(3);
    auto var = ctx->Input(4);

    const int input_dims = ctx->InputShape(0).dims();
    const int feature_index =
        GetTensorFeatureDimIndex(input_dims, data_format_);

    xla::XlaOp x_backprop;
    xla::XlaOp scale_backprop;
    xla::XlaOp offset_backprop;
    if (is_training_) {
      if (is_on_gpu_) {
        // The last two inputs to the FusedBatchNormGrad training TensorFlow GPU
        // op are implementation defined.  For now we rely on the in-practice
        // behavior of the op:
        //   input 3 is the mean
        //   input 4 is rsqrt(variance + epsilon)
        //
        // The XLA op expects:
        //   input 3 is the mean
        //   input 4 is the variance
        //
        // so we adjust input 4 here.
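        // Since input 4 holds rsqrt(variance + epsilon), the variance is
        // recovered as 1 / (input 4)^2 - epsilon.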
        xla::XlaOp one = xla::ScalarLike(var, 1.0f);
        xla::XlaOp epsilon = xla::ScalarLike(var, epsilon_);
        var = xla::Sub(one / (var * var), epsilon);
      }

      xla::XlaOp output =
          xla::BatchNormGrad(activations, scale, mean, var, grad_backprop,
                             epsilon_, feature_index);

      x_backprop = xla::GetTupleElement(output, 0);
      scale_backprop = xla::GetTupleElement(output, 1);
      offset_backprop = xla::GetTupleElement(output, 2);
    } else {
      // Reduce over all dimensions except the feature dim.
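      // For example, a 4-D NHWC input with feature_index == 3 reduces over
      // dimensions {0, 1, 2}.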
      std::vector<int64_t> reduction_dims(input_dims - 1);
      std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index,
                0);
      std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(),
                feature_index + 1);
      // offset_backprop = sum(y_backprop)
      // scale_backprop = sum(y_backprop * (x - pop_mean)) * rsqrt(pop_var +
      // epsilon)
      // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
      const DataType accumulation_type =
          XlaHelpers::SumAccumulationType(scale_dtype);
      auto converted =
          XlaHelpers::ConvertElementType(grad_backprop, accumulation_type);
      auto reduce =
          xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                      *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
      offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype);

      // scratch1 = rsqrt(pop_var + epsilon)
      auto epsilon = XlaHelpers::FloatLiteral(b, scale_dtype, epsilon_);
      auto scratch1 = xla::Rsqrt(xla::Add(var, epsilon));

      // scratch2 = sum(y_backprop * (x - mean))
      auto mul =
          xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index}));
      converted = XlaHelpers::ConvertElementType(mul, accumulation_type);
      reduce =
          xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                      *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
      auto scratch2 = XlaHelpers::ConvertElementType(reduce, scale_dtype);

      x_backprop =
          xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index});
      scale_backprop = xla::Mul(scratch1, scratch2);
    }

    ctx->SetOutput(0, XlaHelpers::ConvertElementType(x_backprop, input_dtype));
    ctx->SetOutput(1, scale_backprop);
    ctx->SetOutput(2, offset_backprop);
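    // Outputs 3 and 4 are reserve-space placeholders used only by the
    // TensorFlow op's implementation; the XLA lowering emits empty constants.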
    ctx->SetConstantOutput(3, Tensor());
    ctx->SetConstantOutput(4, Tensor());
  }

 private:
  TensorFormat data_format_;
  float epsilon_;
  bool is_training_;
  bool is_on_gpu_;
};

REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
REGISTER_XLA_OP(Name("FusedBatchNormGradV3"), FusedBatchNormGradOp);

}  // namespace
}  // namespace tensorflow