xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/batch_normalization_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <memory>
18 #include <vector>
19 
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/array4d.h"
23 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
24 #include "tensorflow/compiler/xla/client/lib/math.h"
25 #include "tensorflow/compiler/xla/client/local_client.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/reference_util.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_module.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/test.h"
36 #include "tensorflow/compiler/xla/test_helpers.h"
37 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
38 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
39 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
40 #include "tensorflow/compiler/xla/tests/test_macros.h"
41 #include "tensorflow/compiler/xla/tests/test_utils.h"
42 #include "tensorflow/compiler/xla/types.h"
43 #include "tensorflow/compiler/xla/util.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/lib/math/math_util.h"
46 #include "tensorflow/core/platform/logging.h"
47 #include "tensorflow/core/platform/test.h"
48 
49 namespace xla {
50 namespace {
51 
52 class BatchNormalizationTest : public ClientLibraryTestBase {
53  protected:
BatchNormalizationTest()54   BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
55     Array2D<float> pz({
56         // z0 z1
57         {-1.0f, 4.1f},  // p0
58         {2.0f, 4.1f},   // p1
59         {5.0f, 4.4f},   // p2
60     });
61     input_array_.FillWithPZ(pz);
62     input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
63     CHECK_EQ(kSamples, input_array_.planes());
64     CHECK_EQ(kZ, input_array_.depth());
65     CHECK_EQ(kY, input_array_.height());
66     CHECK_EQ(kY, input_array_.width());
67   }
68 
CheckShape(XlaBuilder * b,const XlaOp operand,const Shape & expected_shape) const69   XlaOp CheckShape(XlaBuilder* b, const XlaOp operand,
70                    const Shape& expected_shape) const {
71     Shape actual_shape = b->GetShape(operand).value();
72     CHECK(ShapeUtil::Equal(expected_shape, actual_shape))
73         << "want " << ShapeUtil::HumanString(expected_shape) << " got "
74         << ShapeUtil::HumanString(actual_shape);
75     return operand;
76   }
77 
78   static constexpr int64_t kSamples = 3;
79   static constexpr int64_t kX = 1;
80   static constexpr int64_t kY = 1;
81   static constexpr int64_t kZ = 2;
82 
83   Array4D<float> input_array_;
84   Literal input_literal_;
85   const ErrorSpec error_spec_{0.001, 0.001};
86 };
87 
// Subtracts a per-feature R1 vector from the fixture input, broadcasting the
// vector along dimension 1 (the Z / feature dimension).
XLA_TEST_F(BatchNormalizationTest, SubtractInZ) {
  XlaBuilder builder("subtract_in_z_one_sample");
  XlaOp activations = ConstantLiteral(&builder, input_literal_);
  XlaOp per_feature = ConstantR1<float>(&builder, {3.14, 4.25});
  Sub(activations, per_feature, /*broadcast_dimensions=*/{1});

  Array2D<float> want_pz({
      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
  });
  Array4D<float> want(kSamples, kZ, kY, kX);
  want.FillWithPZ(want_pz);
  ComputeAndCompareR4<float>(&builder, want, {}, error_spec_);
}
103 
// Squares every element of the rank-4 fixture input.
XLA_TEST_F(BatchNormalizationTest, SquareTesseractElementwise) {
  using tensorflow::MathUtil;

  XlaBuilder builder("square_tesseract_elementwise");
  Square(ConstantLiteral(&builder, input_literal_));

  // Each fixture value raised to the second power, laid out per (p, z).
  Array2D<float> squared_pz({
      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
  });
  Array4D<float> expected(kSamples, kZ, kY, kX);
  expected.FillWithPZ(squared_pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
120 
// Sums the fixture input over every dimension except Z (dimension 1).
XLA_TEST_F(BatchNormalizationTest, SumToZ) {
  XlaBuilder builder("sum_to_z");
  XlaOp activations = ConstantLiteral(&builder, input_literal_);
  const XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  XlaOp zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(activations, zero, add_f32, {0, 2, 3});

  // z0: -1 + 2 + 5, z1: 4.1 + 4.1 + 4.4.
  const std::vector<float> expected_sums = {6, 12.6};
  ComputeAndCompareR1<float>(&builder, expected_sums, {}, error_spec_);
}
131 
// Subtracts per-feature means, squares the deviations, and sums them per
// feature (the numerator of the per-feature variance).
XLA_TEST_F(BatchNormalizationTest, SquareAndReduce) {
  XlaBuilder builder("square_and_reduce");
  XlaOp activations = ConstantLiteral(&builder, input_literal_);
  XlaOp means = ConstantR1<float>(&builder, {2.f, 4.2f});
  XlaOp deviations = Sub(activations, means, /*broadcast_dimensions=*/{1});
  const XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  XlaOp squared = Square(deviations);
  Reduce(squared, ConstantR0<float>(&builder, 0.0f), add_f32, {0, 2, 3});

  // z0: 9 + 0 + 9, z1: 0.01 + 0.01 + 0.04.
  const std::vector<float> expected_sums = {18, 0.06};
  ComputeAndCompareR1<float>(&builder, expected_sums, {}, error_spec_);
}
145 
// Converts a per-feature variance into a standard deviation with Sqrt.
XLA_TEST_F(BatchNormalizationTest, VarianceToStddev) {
  XlaBuilder builder("variance_to_stddev");
  Sqrt(ConstantR1<float>(&builder, {6.f, .02f}));

  const std::vector<float> expected_stddevs = {2.44948974f, 0.14142136f};
  ComputeAndCompareR1<float>(&builder, expected_stddevs, {}, error_spec_);
}
154 
155 // Compare against a forward batch normalization example in the NN spec
156 // reference.
XLA_TEST_F(BatchNormalizationTest,SpecComparisonForward)157 XLA_TEST_F(BatchNormalizationTest, SpecComparisonForward) {
158   XlaBuilder builder("batch_normalize_per_spec");
159   auto input_activations =
160       CheckShape(&builder, ConstantLiteral(&builder, input_literal_),
161                  ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
162   auto gamma = ConstantR1<float>(&builder, {1.0, 1.0});
163   auto beta = ConstantR1<float>(&builder, {0.0, 0.0});
164   XlaComputation add = CreateScalarAddComputation(F32, &builder);
165   // Reduce all dimensions except dimension 1.
166   Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
167   auto sum = CheckShape(
168       &builder,
169       Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
170              /*dimensions_to_reduce=*/{0, 2, 3}),
171       TwoElementVectorF32);
172   auto input_shape = builder.GetShape(input_activations).value();
173   auto sum_shape = builder.GetShape(sum).value();
174   auto count =
175       ConstantR0<float>(&builder, ShapeUtil::ElementsIn(input_shape) /
176                                       ShapeUtil::ElementsIn(sum_shape));
177   auto set_means = Div(sum, count);
178 
179   const float kEpsilon = 1e-9f;
180   auto epsilon = ConstantR0<float>(&builder, kEpsilon);
181   auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
182   auto activation_deviations = Sub(input_activations, set_means,
183                                    /*broadcast_dimensions=*/{1});
184   auto dev_squares = Square(activation_deviations);
185   auto sum_of_squares =
186       CheckShape(&builder,
187                  Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
188                         /*dimensions_to_reduce=*/{0, 2, 3}),
189                  TwoElementVectorF32);
190   auto variance = Div(sum_of_squares, count);
191   auto standard_deviation = Sqrt(variance);
192   auto standard_deviation_above_epsilon =
193       CheckShape(&builder, Gt(standard_deviation, epsilon),
194                  ShapeUtil::MakeShape(PRED, {2}));
195   auto gt_eps =
196       Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
197   auto normalization_factors = Reciprocal(gt_eps);
198   auto normalized_input_activations =
199       Mul(activation_deviations, normalization_factors,
200           /*broadcast_dimensions=*/{1});
201   /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma,
202                                           /*broadcast_dimensions=*/{1}),
203                                       beta, /*broadcast_dimensions=*/{1});
204 
205   Array4D<float> expected(kSamples, kZ, kY, kX);
206   Array2D<float> pz({
207       {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
208       {0.f, -.1f / std::sqrt(.02f)},
209       {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
210   });
211   expected.FillWithPZ(pz);
212 
213   ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
214 }
215 
// BatchNormTraining on a [2, 2, 1, 2] input with the feature dimension last.
// The expected tuple is (normalized output, per-feature mean, per-feature
// variance): feature 0 holds {1, 3, 5, 7} (mean 4, var 5) and feature 1 holds
// {2, 4, 6, 8} (mean 5, var 5).
XLA_TEST_F(BatchNormalizationTest, BasicTraining) {
  const int kFeatureIndex = 3;
  XlaBuilder builder(TestName());

  XlaOp operand = ConstantR4FromArray4D<float>(
      &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});
  XlaOp scale = ConstantR1<float>(&builder, {2.0f, 3.0f});
  XlaOp offset = ConstantR1<float>(&builder, {1.0f, 2.0f});
  BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);

  auto expected_output =
      LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                    {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}});
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_output, LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
238 
// Like BasicTraining, but the operand is F16: it is converted to F32 for
// BatchNormTraining and the normalized output is converted back to F16, while
// the mean and variance outputs stay F32.
XLA_TEST_F(BatchNormalizationTest, BasicTraining_fp16) {
  const int kFeatureIndex = 3;
  XlaBuilder builder(TestName());
  Array4D<Eigen::half> input_f16 = {{{{1.f, 2.f}}, {{3.f, 4.f}}},
                                    {{{5.f, 6.f}}, {{7.f, 8.f}}}};
  XlaOp operand = ConstantR4FromArray4D<Eigen::half>(&builder, input_f16);
  XlaOp scale = ConstantR1<float>(&builder, {2.0f, 3.0f});
  XlaOp offset = ConstantR1<float>(&builder, {1.0f, 2.0f});
  XlaOp operand_f32 = ConvertElementType(operand, F32);

  XlaOp bn = BatchNormTraining(operand_f32, scale, offset,
                               /*epsilon=*/0.001, kFeatureIndex);
  XlaOp output_f16 = ConvertElementType(GetTupleElement(bn, 0), F16);
  Tuple(&builder,
        {output_f16, GetTupleElement(bn, 1), GetTupleElement(bn, 2)});

  Array4D<Eigen::half> expected_output = {{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                          {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}};
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateFromArray(expected_output),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
268 
// Same data as BasicTraining, laid out as [2, 2, 2, 1] so the feature
// dimension is 2. The expected tuple is (normalized output, mean, variance).
XLA_TEST_F(BatchNormalizationTest, BasicTrainingOnDimension2) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  XlaOp operand = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
  XlaOp scale = ConstantR1<float>(&builder, {2.0f, 3.0f});
  XlaOp offset = ConstantR1<float>(&builder, {1.0f, 2.0f});
  BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);

  auto expected_output =
      LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                    {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}});
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_output, LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
292 
// F16 variant of BasicTrainingOnDimension2: the operand is converted to F32
// for BatchNormTraining and only the normalized output is converted back to
// F16.
XLA_TEST_F(BatchNormalizationTest, BasicTrainingOnDimension2_fp16) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());
  Array4D<Eigen::half> input_f16 = {{{{1.f}, {2.f}}, {{3.f}, {4.f}}},
                                    {{{5.f}, {6.f}}, {{7.f}, {8.f}}}};
  XlaOp operand = ConstantR4FromArray4D<Eigen::half>(&builder, input_f16);
  XlaOp scale = ConstantR1<float>(&builder, {2.0f, 3.0f});
  XlaOp offset = ConstantR1<float>(&builder, {1.0f, 2.0f});
  XlaOp operand_f32 = ConvertElementType(operand, F32);

  XlaOp bn = BatchNormTraining(operand_f32, scale, offset,
                               /*epsilon=*/0.001, kFeatureIndex);
  XlaOp output_f16 = ConvertElementType(GetTupleElement(bn, 0), F16);
  Tuple(&builder,
        {output_f16, GetTupleElement(bn, 1), GetTupleElement(bn, 2)});

  Array4D<Eigen::half> expected_output = {
      {{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
      {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}};
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateFromArray(expected_output),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
322 
// BatchNormTraining with the feature on dimension 0 (260 features) to exercise
// the layout analyzer. Every input value is 1, so each feature has mean 1 and
// variance 0, and the output equals the offset (1).
XLA_TEST_F(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
  const int kFeatureIndex = 0;
  XlaBuilder builder(TestName());

  XlaOp operand_param;
  auto operand_data = CreateR3Parameter<float>(
      Array3D<float>(260, 2, 2, 1.0f), /*parameter_number=*/0, "operand",
      &builder, &operand_param);
  XlaOp scale_param;
  auto scale_data = CreateR1Parameter<float>(
      std::vector<float>(260, 1.0f), /*parameter_number=*/1, "scale", &builder,
      &scale_param);
  XlaOp offset_param;
  auto offset_data = CreateR1Parameter<float>(
      std::vector<float>(260, 1.0f), /*parameter_number=*/2, "offset",
      &builder, &offset_param);

  BatchNormTraining(operand_param, scale_param, offset_param,
                    /*epsilon=*/1, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});

  ComputeAndCompareTuple(
      &builder, expected,
      {operand_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.1));
}
353 
// Checks that epsilon is added to the variance as given, even when it is
// large and negative: var = 125, mean = 15, epsilon = -100, so the values are
// divided by sqrt(125 - 100) = 5.
XLA_TEST_F(BatchNormalizationTest, LargeEpsilonTest) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  XlaOp operand_param;
  auto operand_data = CreateR3Parameter<float>(
      {{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, /*parameter_number=*/0, "operand",
      &builder, &operand_param);
  XlaOp scale_param;
  auto scale_data = CreateR1Parameter<float>(
      std::vector<float>(1, 1.0f), /*parameter_number=*/1, "scale", &builder,
      &scale_param);
  XlaOp offset_param;
  auto offset_data = CreateR1Parameter<float>(
      std::vector<float>(1, 0.0f), /*parameter_number=*/2, "offset", &builder,
      &offset_param);

  BatchNormTraining(operand_param, scale_param, offset_param,
                    /*epsilon=*/-100, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(
           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});

  ComputeAndCompareTuple(
      &builder, expected,
      {operand_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.1));
}
386 
// BatchNormGrad on an all-zero operand with mean 0 and variance 1, feature
// dimension 2. The expected tuple is (grad operand, grad scale, grad offset).
// The grad offset is the per-feature sum of grad_output: {1,3,5,7} -> 16 and
// {2,4,6,8} -> 20.
XLA_TEST_F(BatchNormalizationTest, BatchNormGradBasic) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  XlaOp operand =
      ConstantR4FromArray4D<float>(&builder, Array4D<float>(2, 2, 2, 1, 0.0f));
  XlaOp scale = ConstantR1<float>(&builder, {1.0f, 1.0f});
  XlaOp mean = ConstantR1<float>(&builder, {0.0f, 0.0f});
  XlaOp var = ConstantR1<float>(&builder, {1.0f, 1.0f});
  XlaOp grad_output = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0,
                kFeatureIndex);

  auto expected_grad_operand =
      LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                    {{{1.f}, {1.f}}, {{3.f}, {3.f}}}});
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_grad_operand, LiteralUtil::CreateR1<float>({0, 0}),
       LiteralUtil::CreateR1<float>({16, 20})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
415 
// F16 variant of BatchNormGradBasic: the operand and grad_output are F16
// constants converted to F32 for BatchNormGrad. Only the operand gradient is
// converted back to F16.
XLA_TEST_F(BatchNormalizationTest, BatchNormGradBasic_fp16) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());
  XlaOp operand_f16 = ConstantR4FromArray4D<Eigen::half>(
      &builder,
      Array4D<Eigen::half>(2, 2, 2, 1, static_cast<Eigen::half>(0.0f)));
  XlaOp operand_f32 = ConvertElementType(operand_f16, F32);
  XlaOp scale = ConstantR1<float>(&builder, {1.0f, 1.0f});
  XlaOp mean = ConstantR1<float>(&builder, {0.0f, 0.0f});
  XlaOp var = ConstantR1<float>(&builder, {1.0f, 1.0f});
  XlaOp grad_output_f16 = ConstantR4FromArray4D<Eigen::half>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
  XlaOp grad_output_f32 = ConvertElementType(grad_output_f16, F32);

  XlaOp grads = BatchNormGrad(operand_f32, scale, mean, var, grad_output_f32,
                              /*epsilon=*/0.001, kFeatureIndex);
  XlaOp grad_operand_f16 = ConvertElementType(GetTupleElement(grads, 0), F16);
  Tuple(&builder, {grad_operand_f16, GetTupleElement(grads, 1),
                   GetTupleElement(grads, 2)});

  Array4D<Eigen::half> expected_grad_operand = {
      {{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, {{{1.f}, {1.f}}, {{3.f}, {3.f}}}};
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateFromArray(expected_grad_operand),
       LiteralUtil::CreateR1<float>({0, 0}),
       LiteralUtil::CreateR1<float>({16, 20})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}
452 
453 struct BatchNormTestParam {
454   std::vector<int64_t> bounds;
455   int64_t feature_index;
456   float random_value_mean;
457   float random_value_var;
458 
operator <<(::std::ostream & os,const BatchNormTestParam & p)459   friend ::std::ostream& operator<<(::std::ostream& os,
460                                     const BatchNormTestParam& p) {
461     os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, ";
462     os << "feature_index=" << p.feature_index << ", ";
463     os << "random_value_mean=" << p.random_value_mean << ", ";
464     os << "random_value_var=" << p.random_value_var;
465     return os;
466   }
467 };
468 
469 // Tests to test the fused operation of BatchNorm.
470 class BatchNormTestManySizes
471     : public ClientLibraryTestBase,
472       public ::testing::WithParamInterface<BatchNormTestParam> {};
473 
BuildBatchNormTestParams()474 std::vector<BatchNormTestParam> BuildBatchNormTestParams() {
475   std::vector<BatchNormTestParam> params;
476 
477   auto add_testcase = [&](std::vector<int64_t> bounds, int64_t feature_index,
478                           float random_value_mean, float random_value_var) {
479     BatchNormTestParam p{bounds, feature_index, random_value_mean,
480                          random_value_var};
481     params.push_back(p);
482   };
483 
484   add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
485   add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);
486 
487   add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
488   add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
489   add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
490   add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
491   add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
492   add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
493   add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
494   add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);
495 
496   add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
497   add_testcase({24, 129, 1, 2}, 3, 10000, 10000);
498 
499   // Feature on low dimension to trigger relayout, check that internal logical
500   // to physical dimension calculation is correct after relayout.
501   add_testcase({1, 2, 3, 4}, 0, 100, 100);
502 
503   // Zero-sized tensor.
504   add_testcase({1, 0, 100, 42}, 0, 100, 100);
505 
506   return params;
507 }
508 
509 INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
510                         ::testing::ValuesIn(BuildBatchNormTestParams()));
511 
XLA_TEST_P(BatchNormTestManySizes,RandomizedTrainingTests)512 XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
513   float epsilon = 0.001;
514   XlaBuilder builder(TestName());
515   const std::vector<int64_t>& bounds = GetParam().bounds;
516   Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
517   input_array.FillRandom(GetParam().random_value_var,
518                          GetParam().random_value_mean);
519 
520   const int64_t feature_index = GetParam().feature_index;
521   const int64_t num_elements_per_feature =
522       Product(bounds) / bounds[feature_index];
523   const int64_t feature_bound = bounds[feature_index];
524   std::vector<float> offset(feature_bound, 1);
525   std::vector<float> scale(feature_bound, 2);
526 
527   auto input_squared =
528       ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
529   std::vector<int64_t> reduce_dims;
530   for (int64_t i = 0; i < static_cast<int64_t>(bounds.size()); ++i) {
531     if (i != feature_index) {
532       reduce_dims.push_back(i);
533     }
534   }
535 
536   auto sum =
537       ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
538                                   [](float a, float b) { return a + b; });
539 
540   auto sum_squared =
541       ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
542                                   [](float a, float b) { return a + b; });
543 
544   std::vector<float> mean(feature_bound);
545 
546   for (int64_t i = 0; i < feature_bound; ++i) {
547     mean[i] = sum[i] / num_elements_per_feature;
548   }
549 
550   std::vector<float> mean_square(feature_bound);
551   for (int64_t i = 0; i < feature_bound; ++i) {
552     mean_square[i] = mean[i] * mean[i];
553   }
554 
555   std::vector<float> square_mean(feature_bound);
556   for (int64_t i = 0; i < feature_bound; ++i) {
557     square_mean[i] = sum_squared[i] / num_elements_per_feature;
558   }
559 
560   std::vector<float> var(feature_bound);
561   for (int64_t i = 0; i < feature_bound; ++i) {
562     var[i] = square_mean[i] - mean_square[i];
563   }
564 
565   Array4D<float> mean4D =
566       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
567   auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
568   auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
569   auto offset4D =
570       *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
571 
572   auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
573                                                 scale4D, offset4D, epsilon);
574 
575   auto expected_normalized =
576       LiteralUtil::CreateR4FromArray4D<float>(normalized);
577 
578   auto offset_literal = LiteralUtil::CreateR1<float>(offset);
579   auto scale_literal = LiteralUtil::CreateR1<float>(scale);
580   auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
581 
582   auto input_activations =
583       Parameter(&builder, 0, input_literal.shape(), "input");
584   auto scale_activations =
585       Parameter(&builder, 1, scale_literal.shape(), "offset");
586   auto offset_activations =
587       Parameter(&builder, 2, offset_literal.shape(), "scale");
588 
589   auto expected = LiteralUtil::MakeTupleFromSlices(
590       {expected_normalized, LiteralUtil::CreateR1<float>(mean),
591        LiteralUtil::CreateR1<float>(var)});
592 
593   std::unique_ptr<GlobalData> input_data =
594       client_->TransferToServer(input_literal).value();
595   std::unique_ptr<GlobalData> scale_data =
596       client_->TransferToServer(scale_literal).value();
597   std::unique_ptr<GlobalData> offset_data =
598       client_->TransferToServer(offset_literal).value();
599 
600   BatchNormTraining(input_activations, scale_activations, offset_activations,
601                     epsilon, feature_index);
602 
603   // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
604   // disables constant folding, but we want it enabled for our zero-sized tensor
605   // testcase.
606   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
607   ComputeAndCompareTuple(
608       &builder, expected,
609       {input_data.get(), scale_data.get(), offset_data.get()},
610       ErrorSpec(0.01, 1));
611 }
612 
XLA_TEST_P(BatchNormTestManySizes,RandomizedInferencingTests)613 XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
614   float epsilon = 0.001;
615   XlaBuilder builder(TestName());
616   const std::vector<int64_t>& bounds = GetParam().bounds;
617   Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
618   input_array.FillRandom(GetParam().random_value_var,
619                          GetParam().random_value_mean);
620 
621   const int64_t feature_index = GetParam().feature_index;
622   const int64_t num_elements_per_feature =
623       Product(bounds) / bounds[feature_index];
624   const int64_t feature_bound = bounds[feature_index];
625   std::vector<float> offset(feature_bound, 1);
626   std::vector<float> scale(feature_bound, 2);
627 
628   auto input_squared =
629       ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
630   std::vector<int64_t> reduce_dims;
631   for (int64_t i = 0; i < static_cast<int64_t>(bounds.size()); ++i) {
632     if (i != feature_index) {
633       reduce_dims.push_back(i);
634     }
635   }
636 
637   auto sum =
638       ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
639                                   [](float a, float b) { return a + b; });
640 
641   auto sum_squared =
642       ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
643                                   [](float a, float b) { return a + b; });
644 
645   std::vector<float> mean(feature_bound);
646 
647   for (int64_t i = 0; i < feature_bound; ++i) {
648     mean[i] = sum[i] / num_elements_per_feature;
649   }
650 
651   std::vector<float> mean_square(feature_bound);
652   for (int64_t i = 0; i < feature_bound; ++i) {
653     mean_square[i] = mean[i] * mean[i];
654   }
655 
656   std::vector<float> square_mean(feature_bound);
657   for (int64_t i = 0; i < feature_bound; ++i) {
658     square_mean[i] = sum_squared[i] / num_elements_per_feature;
659   }
660 
661   std::vector<float> var(feature_bound);
662   for (int64_t i = 0; i < feature_bound; ++i) {
663     var[i] = square_mean[i] - mean_square[i];
664   }
665 
666   Array4D<float> mean4D =
667       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
668   auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
669   auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
670   auto offset4D =
671       *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
672 
673   auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
674                                                 scale4D, offset4D, epsilon);
675 
676   auto offset_literal = LiteralUtil::CreateR1<float>(offset);
677   auto scale_literal = LiteralUtil::CreateR1<float>(scale);
678   auto mean_literal = LiteralUtil::CreateR1<float>(mean);
679   auto var_literal = LiteralUtil::CreateR1<float>(var);
680   auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
681 
682   auto input_activations =
683       Parameter(&builder, 0, input_literal.shape(), "input");
684   auto scale_activations =
685       Parameter(&builder, 1, scale_literal.shape(), "offset");
686   auto offset_activations =
687       Parameter(&builder, 2, offset_literal.shape(), "scale");
688   auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
689   auto variance_activations =
690       Parameter(&builder, 4, var_literal.shape(), "variance");
691 
692   Array4D<float> expected = normalized;
693 
694   std::unique_ptr<GlobalData> input_data =
695       client_->TransferToServer(input_literal).value();
696   std::unique_ptr<GlobalData> scale_data =
697       client_->TransferToServer(scale_literal).value();
698   std::unique_ptr<GlobalData> offset_data =
699       client_->TransferToServer(offset_literal).value();
700   std::unique_ptr<GlobalData> mean_data =
701       client_->TransferToServer(mean_literal).value();
702   std::unique_ptr<GlobalData> variance_data =
703       client_->TransferToServer(var_literal).value();
704 
705   BatchNormInference(input_activations, scale_activations, offset_activations,
706                      mean_activations, variance_activations, epsilon,
707                      feature_index);
708 
709   // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
710   // disables constant folding, but we want it enabled for our zero-sized tensor
711   // testcase.
712   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
713 
714   ComputeAndCompareR4<float>(
715       &builder, expected,
716       {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
717        variance_data.get()},
718       ErrorSpec(0.01, 1));
719 }
720 
XLA_TEST_P(BatchNormTestManySizes,RandomizedGradTests)721 XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
722   float epsilon = 0.001;
723   XlaBuilder builder(TestName());
724   const std::vector<int64_t>& bounds = GetParam().bounds;
725   Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
726   input_array.FillRandom(GetParam().random_value_var,
727                          GetParam().random_value_mean);
728 
729   Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
730   grad_output_array.FillRandom(GetParam().random_value_var,
731                                GetParam().random_value_mean);
732 
733   const int64_t feature_index = GetParam().feature_index;
734   const int64_t num_elements_per_feature =
735       Product(bounds) / bounds[feature_index];
736   const int64_t feature_bound = bounds[feature_index];
737   std::vector<float> scale(feature_bound, 2);
738 
739   auto input_squared =
740       ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
741   std::vector<int64_t> reduce_dims;
742   for (int64_t i = 0; i < static_cast<int64_t>(bounds.size()); ++i) {
743     if (i != feature_index) {
744       reduce_dims.push_back(i);
745     }
746   }
747 
748   auto sum =
749       ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
750                                   [](float a, float b) { return a + b; });
751 
752   auto sum_squared =
753       ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
754                                   [](float a, float b) { return a + b; });
755 
756   std::vector<float> mean(feature_bound);
757 
758   for (int64_t i = 0; i < feature_bound; ++i) {
759     if (num_elements_per_feature > 0) {
760       mean[i] = sum[i] / num_elements_per_feature;
761     } else {
762       mean[i] = 0;
763     }
764   }
765 
766   std::vector<float> mean_square(feature_bound);
767   for (int64_t i = 0; i < feature_bound; ++i) {
768     mean_square[i] = mean[i] * mean[i];
769   }
770 
771   std::vector<float> square_mean(feature_bound);
772   for (int64_t i = 0; i < feature_bound; ++i) {
773     if (num_elements_per_feature > 0) {
774       square_mean[i] = sum_squared[i] / num_elements_per_feature;
775     } else {
776       square_mean[i] = 0;
777     }
778   }
779 
780   std::vector<float> var(feature_bound);
781   for (int64_t i = 0; i < feature_bound; ++i) {
782     var[i] = square_mean[i] - mean_square[i];
783   }
784 
785   Array4D<float> mean4D =
786       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
787   auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
788   auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
789 
790   auto var_add_epsilon = *ReferenceUtil::MapArray4D(
791       var4D, [epsilon](float a) { return a + epsilon; });
792 
793   auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
794       var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });
795 
796   auto grad_output_times_var =
797       *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
798                                  [](float a, float b) { return a * b; });
799 
800   auto activation_shifted = *ReferenceUtil::MapArray4D(
801       input_array, mean4D, [](float a, float b) { return a - b; });
802 
803   auto activation_shifted_times_grad_output =
804       *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
805                                  [](float a, float b) { return a * b; });
806 
807   auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
808       activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
809       [](float a, float b) { return a * b; });
810 
811   auto grad_scale = ReferenceUtil::Reduce4DTo1D(
812       grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
813       [](float a, float b) { return a + b; });
814 
815   auto grad_offset =
816       ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
817                                   [](float a, float b) { return a + b; });
818 
819   auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
820       scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });
821 
822   auto I1 = *ReferenceUtil::MapArray4D(
823       grad_output_array, [&](float a) { return num_elements_per_feature * a; });
824 
825   auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);
826 
827   // I3 = sum(output_grad * (activation - mean(activation)))
828   auto I3 = *ReferenceUtil::Broadcast1DTo4D(
829       ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
830                                   /*init=*/0.0f, reduce_dims,
831                                   [](float a, float b) { return a + b; }),
832       bounds, feature_index);
833 
834   // I4 = (activation - mean(activation)) *
835   //   sum(output_grad * (activation - mean(activation)))
836   auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
837                                        [](float a, float b) { return a * b; });
838 
839   // I5 = (activation - mean(activation)) *
840   //   sum(output_grad * (activation - mean(activation))) / (variance +
841   //   epsilon))
842   auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
843                                        [](float a, float b) { return a / b; });
844 
845   auto grad_activation = *ReferenceUtil::MapArray4D(
846       I1, I2, [](float a, float b) { return a - b; });
847 
848   grad_activation = *ReferenceUtil::MapArray4D(
849       grad_activation, I5, [](float a, float b) { return a - b; });
850 
851   grad_activation = *ReferenceUtil::MapArray4D(
852       grad_activation, scale4D, [](float a, float b) { return a * b; });
853 
854   grad_activation = *ReferenceUtil::MapArray4D(
855       grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
856         if (num_elements_per_feature > 0) {
857           return a * b / num_elements_per_feature;
858         }
859         return 0.f;
860       });
861 
862   auto expected_grad_activation =
863       LiteralUtil::CreateR4FromArray4D<float>(grad_activation);
864 
865   auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
866   auto scale_literal = LiteralUtil::CreateR1<float>(scale);
867   auto mean_literal = LiteralUtil::CreateR1<float>(mean);
868   auto var_literal = LiteralUtil::CreateR1<float>(var);
869   auto grad_output_literal =
870       LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
871 
872   auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
873   auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
874   auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
875   auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
876   auto grad_output_parameter =
877       Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
878 
879   std::unique_ptr<GlobalData> input_data =
880       client_->TransferToServer(input_literal).value();
881   std::unique_ptr<GlobalData> scale_data =
882       client_->TransferToServer(scale_literal).value();
883   std::unique_ptr<GlobalData> mean_data =
884       client_->TransferToServer(mean_literal).value();
885   std::unique_ptr<GlobalData> var_data =
886       client_->TransferToServer(var_literal).value();
887   std::unique_ptr<GlobalData> grad_output_data =
888       client_->TransferToServer(grad_output_literal).value();
889 
890   BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
891                 grad_output_parameter, epsilon, feature_index);
892 
893   auto expected = LiteralUtil::MakeTupleFromSlices(
894       {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
895        LiteralUtil::CreateR1<float>(grad_offset)});
896 
897   // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
898   // disables constant folding, but we want it enabled for our zero-sized tensor
899   // testcase.
900   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
901 
902   ComputeAndCompareTuple(&builder, expected,
903                          {input_data.get(), scale_data.get(), mean_data.get(),
904                           var_data.get(), grad_output_data.get()},
905                          ErrorSpec(0.01, 1));
906 }
907 
908 }  // namespace
909 }  // namespace xla
910