/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

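// Fixture providing a [kSamples, kZ, kY, kX] = [3, 2, 1, 1] activation array
// (two features along Z, three samples each) whose statistics the tests below
// recompute by hand to cross-check the batch-normalization primitives.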
class BatchNormalizationTest : public ClientLibraryTestBase {
 protected:
  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
    Array2D<float> pz({
        // z0 z1
        {-1.0f, 4.1f},  // p0
        {2.0f, 4.1f},   // p1
        {5.0f, 4.4f},   // p2
    });
    input_array_.FillWithPZ(pz);
    input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
    CHECK_EQ(kSamples, input_array_.planes());
    CHECK_EQ(kZ, input_array_.depth());
    CHECK_EQ(kY, input_array_.height());
    CHECK_EQ(kX, input_array_.width());
  }

  XlaOp CheckShape(XlaBuilder* b, const XlaOp operand,
                   const Shape& expected_shape) const {
    Shape actual_shape = b->GetShape(operand).value();
    CHECK(ShapeUtil::Equal(expected_shape, actual_shape))
        << "want " << ShapeUtil::HumanString(expected_shape) << " got "
        << ShapeUtil::HumanString(actual_shape);
    return operand;
  }

  static constexpr int64_t kSamples = 3;
  static constexpr int64_t kX = 1;
  static constexpr int64_t kY = 1;
  static constexpr int64_t kZ = 2;

  Array4D<float> input_array_;
  Literal input_literal_;
  const ErrorSpec error_spec_{0.001, 0.001};
};

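// Subtracting a two-element vector with broadcast_dimensions={1} maps the
// vector onto the Z (feature) dimension, so each feature column gets its own
// subtrahend.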
XLA_TEST_F(BatchNormalizationTest, SubtractInZ) {
  XlaBuilder builder("subtract_in_z_one_sample");
  auto x = ConstantLiteral(&builder, input_literal_);
  auto y = ConstantR1<float>(&builder, {3.14, 4.25});
  Sub(x, y, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
  });
  expected.FillWithPZ(pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(BatchNormalizationTest, SquareTesseractElementwise) {
  XlaBuilder builder("square_tesseract_elementwise");
  auto x = ConstantLiteral(&builder, input_literal_);
  Square(x);

  using tensorflow::MathUtil;

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> expected_pz({
      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
  });
  expected.FillWithPZ(expected_pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

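// Reducing over dimensions {0, 2, 3} leaves one partial sum per Z feature:
// z0 = -1 + 2 + 5 = 6 and z1 = 4.1 + 4.1 + 4.4 = 12.6.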
XLA_TEST_F(BatchNormalizationTest, SumToZ) {
  XlaBuilder builder("sum_to_z");
  auto input_activations = ConstantLiteral(&builder, input_literal_);
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all but the Z dimension.
  Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});

  std::vector<float> expected = {6, 12.6};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

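// With per-feature means {2, 4.2}, the deviations are {-3, 0, 3} for z0 and
// {-0.1, -0.1, 0.2} for z1, so the reduced sums of squares are 18 and 0.06.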
XLA_TEST_F(BatchNormalizationTest, SquareAndReduce) {
  XlaBuilder builder("square_and_reduce");
  auto input_activations = ConstantLiteral(&builder, input_literal_);
  auto set_means = ConstantR1<float>(&builder, {2.f, 4.2f});
  auto activation_deviations = Sub(input_activations, set_means,
                                   /*broadcast_dimensions=*/{1});
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  auto dev_squares = Square(activation_deviations);
  Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});

  std::vector<float> expected = {18, 0.06};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

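// sqrt(6) ~= 2.44948974 and sqrt(0.02) ~= 0.14142136.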
XLA_TEST_F(BatchNormalizationTest, VarianceToStddev) {
  XlaBuilder builder("variance_to_stddev");
  auto variance = ConstantR1<float>(&builder, {6.f, .02f});
  Sqrt(variance);

  std::vector<float> expected = {2.44948974f, 0.14142136f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Compare against a forward batch normalization example in the NN spec
// reference.
XLA_TEST_F(BatchNormalizationTest, SpecComparisonForward) {
  XlaBuilder builder("batch_normalize_per_spec");
  auto input_activations =
      CheckShape(&builder, ConstantLiteral(&builder, input_literal_),
                 ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
  auto gamma = ConstantR1<float>(&builder, {1.0, 1.0});
  auto beta = ConstantR1<float>(&builder, {0.0, 0.0});
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all dimensions except dimension 1.
  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
  auto sum = CheckShape(
      &builder,
      Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
             /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto input_shape = builder.GetShape(input_activations).value();
  auto sum_shape = builder.GetShape(sum).value();
  auto count =
      ConstantR0<float>(&builder, ShapeUtil::ElementsIn(input_shape) /
                                      ShapeUtil::ElementsIn(sum_shape));
  auto set_means = Div(sum, count);

  const float kEpsilon = 1e-9f;
  auto epsilon = ConstantR0<float>(&builder, kEpsilon);
  auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
  auto activation_deviations = Sub(input_activations, set_means,
                                   /*broadcast_dimensions=*/{1});
  auto dev_squares = Square(activation_deviations);
  auto sum_of_squares =
      CheckShape(&builder,
                 Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
                        /*dimensions_to_reduce=*/{0, 2, 3}),
                 TwoElementVectorF32);
  auto variance = Div(sum_of_squares, count);
  auto standard_deviation = Sqrt(variance);
  auto standard_deviation_above_epsilon =
      CheckShape(&builder, Gt(standard_deviation, epsilon),
                 ShapeUtil::MakeShape(PRED, {2}));
  auto gt_eps =
      Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
  auto normalization_factors = Reciprocal(gt_eps);
  auto normalized_input_activations =
      Mul(activation_deviations, normalization_factors,
          /*broadcast_dimensions=*/{1});
  /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma,
                                          /*broadcast_dimensions=*/{1}),
                                      beta, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
      {0.f, -.1f / std::sqrt(.02f)},
      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
  });
  expected.FillWithPZ(pz);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

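// For feature_index=3 the per-feature batches are {1, 3, 5, 7} and
// {2, 4, 6, 8}, giving means {4, 5} and variances {5, 5}. The normalized
// output is scale * (x - mean) / sqrt(var + epsilon) + offset.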
XLA_TEST_F(BatchNormalizationTest, BasicTraining) {
  const int kFeatureIndex = 3;
  XlaBuilder builder(TestName());

  auto operand = ConstantR4FromArray4D<float>(
      &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  BatchNormTraining(operand, scale, offset,
                    /*epsilon=*/0.001, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

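// Same statistics as BasicTraining, but with an F16 operand: the input is
// upcast to F32 for BatchNormTraining and the normalized result is converted
// back to F16.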
XLA_TEST_F(BatchNormalizationTest, BasicTraining_fp16) {
  const int kFeatureIndex = 3;
  XlaBuilder builder(TestName());
  Array4D<Eigen::half> input = {{{{1.f, 2.f}}, {{3.f, 4.f}}},
                                {{{5.f, 6.f}}, {{7.f, 8.f}}}};
  auto operand = ConstantR4FromArray4D<Eigen::half>(&builder, input);

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  auto input_f32 = ConvertElementType(operand, F32);

  auto output = BatchNormTraining(input_f32, scale, offset,
                                  /*epsilon=*/0.001, kFeatureIndex);

  auto converted = ConvertElementType(GetTupleElement(output, 0), F16);
  Tuple(&builder,
        {converted, GetTupleElement(output, 1), GetTupleElement(output, 2)});

  Array4D<Eigen::half> out = {{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                              {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}};

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateFromArray(out), LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

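// Identical statistics to BasicTraining, with the feature dimension moved to
// index 2; the per-feature batches and the expected tuple are unchanged.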
XLA_TEST_F(BatchNormalizationTest, BasicTrainingOnDimension2) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  auto operand = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  BatchNormTraining(operand, scale, offset,
                    /*epsilon=*/0.001, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

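// F16 variant of BasicTrainingOnDimension2, again computing in F32.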
XLA_TEST_F(BatchNormalizationTest, BasicTrainingOnDimension2_fp16) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());
  Array4D<Eigen::half> input = {{{{1.f}, {2.f}}, {{3.f}, {4.f}}},
                                {{{5.f}, {6.f}}, {{7.f}, {8.f}}}};
  auto operand = ConstantR4FromArray4D<Eigen::half>(&builder, input);

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  auto input_f32 = ConvertElementType(operand, F32);

  auto output = BatchNormTraining(input_f32, scale, offset,
                                  /*epsilon=*/0.001, kFeatureIndex);

  auto converted = ConvertElementType(GetTupleElement(output, 0), F16);
  Tuple(&builder,
        {converted, GetTupleElement(output, 1), GetTupleElement(output, 2)});

  Array4D<Eigen::half> out = {{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                              {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}};

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateFromArray(out), LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

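// With a constant input of 1.0, every feature has mean 1 and variance 0, and
// epsilon = 1 keeps the denominator at 1, so the output equals the offset.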
XLA_TEST_F(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
  // Use dimension 0 as the feature dimension; this exercises the layout
  // analyzer.
  const int kFeatureIndex = 0;
  XlaBuilder builder(TestName());

  XlaOp h0;
  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  XlaOp h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  XlaOp h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  BatchNormTraining(h0, h1, h2,
                    /*epsilon=*/1, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});

  ComputeAndCompareTuple(&builder, expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_F(BatchNormalizationTest, LargeEpsilonTest) {
  // Test the correctness of choosing a large epsilon value.
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  XlaOp h0;
  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  XlaOp h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  XlaOp h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  // var = 125, mean = 15, epsilon = -100, so var + epsilon = 25 and the
  // normalized output is (x - 15) / 5 = {-3, -1, 1, 3}.
  BatchNormTraining(h0, h1, h2,
                    /*epsilon=*/-100, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(
           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});

  ComputeAndCompareTuple(&builder, expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

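// With a zero operand, mean 0, and var 1: grad_scale sums a product with the
// all-zero deviations, giving {0, 0}; grad_offset is the per-feature sum of
// grad_output, {1+3+5+7, 2+4+6+8} = {16, 20}; and grad_activation reduces to
// scale / sqrt(var + epsilon) * (grad_output - mean(grad_output)), i.e.
// grad_output minus the per-feature means {4, 5}.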
XLA_TEST_F(BatchNormalizationTest, BatchNormGradBasic) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  auto operand =
      ConstantR4FromArray4D<float>(&builder, Array4D<float>(2, 2, 2, 1, 0.0f));

  auto scale = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto mean = ConstantR1<float>(&builder, {0.0f, 0.0f});

  auto var = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto grad_output = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  BatchNormGrad(operand, scale, mean, var, grad_output,
                /*epsilon=*/0.0, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
       LiteralUtil::CreateR1<float>({0, 0}),
       LiteralUtil::CreateR1<float>({16, 20})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

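// F16 variant of BatchNormGradBasic: operand and grad_output are upcast to
// F32 and the activation gradient is converted back to F16.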
XLA_TEST_F(BatchNormalizationTest, BatchNormGradBasic_fp16) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());
  auto operand = ConstantR4FromArray4D<Eigen::half>(
      &builder,
      Array4D<Eigen::half>(2, 2, 2, 1, static_cast<Eigen::half>(0.0f)));

  auto operand_f32 = ConvertElementType(operand, F32);

  auto scale = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto mean = ConstantR1<float>(&builder, {0.0f, 0.0f});

  auto var = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto grad_output = ConstantR4FromArray4D<Eigen::half>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto grad_output_f32 = ConvertElementType(grad_output, F32);

  auto output = BatchNormGrad(operand_f32, scale, mean, var, grad_output_f32,
                              /*epsilon=*/0.001, kFeatureIndex);

  auto converted_output = ConvertElementType(GetTupleElement(output, 0), F16);
  Tuple(&builder, {converted_output, GetTupleElement(output, 1),
                   GetTupleElement(output, 2)});

  Array4D<Eigen::half> out = {{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                              {{{1.f}, {1.f}}, {{3.f}, {3.f}}}};
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateFromArray(out), LiteralUtil::CreateR1<float>({0, 0}),
       LiteralUtil::CreateR1<float>({16, 20})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

struct BatchNormTestParam {
  std::vector<int64_t> bounds;
  int64_t feature_index;
  float random_value_mean;
  float random_value_var;

  friend ::std::ostream& operator<<(::std::ostream& os,
                                    const BatchNormTestParam& p) {
    os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, ";
    os << "feature_index=" << p.feature_index << ", ";
    os << "random_value_mean=" << p.random_value_mean << ", ";
    os << "random_value_var=" << p.random_value_var;
    return os;
  }
};

// Tests for the fused BatchNorm operation across many shapes and feature
// dimensions.
class BatchNormTestManySizes
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<BatchNormTestParam> {};

std::vector<BatchNormTestParam> BuildBatchNormTestParams() {
  std::vector<BatchNormTestParam> params;

  auto add_testcase = [&](std::vector<int64_t> bounds, int64_t feature_index,
                          float random_value_mean, float random_value_var) {
    BatchNormTestParam p{bounds, feature_index, random_value_mean,
                         random_value_var};
    params.push_back(p);
  };

  add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
  add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);

  add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
  add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
  add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);

  add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
  add_testcase({24, 129, 1, 2}, 3, 10000, 10000);

  // Feature on low dimension to trigger relayout, check that internal logical
  // to physical dimension calculation is correct after relayout.
  add_testcase({1, 2, 3, 4}, 0, 100, 100);

  // Zero-sized tensor.
  add_testcase({1, 0, 100, 42}, 0, 100, 100);

  return params;
}

INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
                        ::testing::ValuesIn(BuildBatchNormTestParams()));

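// Computes reference per-feature statistics with reductions over the
// non-feature dimensions, using Var[X] = E[X^2] - E[X]^2, and compares them
// against the tuple produced by BatchNormTraining.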
XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64_t>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64_t feature_index = GetParam().feature_index;
  const int64_t num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64_t feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64_t> reduce_dims;
  for (int64_t i = 0; i < static_cast<int64_t>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64_t i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  std::vector<float> var(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto expected_normalized =
      LiteralUtil::CreateR4FromArray4D<float>(normalized);

  auto offset_literal = LiteralUtil::CreateR1<float>(offset);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_activations =
      Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto offset_activations =
      Parameter(&builder, 2, offset_literal.shape(), "offset");

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_normalized, LiteralUtil::CreateR1<float>(mean),
       LiteralUtil::CreateR1<float>(var)});

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).value();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).value();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(offset_literal).value();

  BatchNormTraining(input_activations, scale_activations, offset_activations,
                    epsilon, feature_index);

  // Run all HLO passes during this test. In particular,
  // ClientLibraryTestBase disables constant folding, but we want it enabled
  // for our zero-sized tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompareTuple(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.01, 1));
}

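// Feeds the reference mean and variance to BatchNormInference and checks the
// normalized output against the reference implementation.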
XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64_t>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64_t feature_index = GetParam().feature_index;
  const int64_t num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64_t feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64_t> reduce_dims;
  for (int64_t i = 0; i < static_cast<int64_t>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64_t i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  std::vector<float> var(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto offset_literal = LiteralUtil::CreateR1<float>(offset);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto mean_literal = LiteralUtil::CreateR1<float>(mean);
  auto var_literal = LiteralUtil::CreateR1<float>(var);
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_activations =
      Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto offset_activations =
      Parameter(&builder, 2, offset_literal.shape(), "offset");
  auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
  auto variance_activations =
      Parameter(&builder, 4, var_literal.shape(), "variance");

  Array4D<float> expected = normalized;

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).value();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).value();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(offset_literal).value();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(mean_literal).value();
  std::unique_ptr<GlobalData> variance_data =
      client_->TransferToServer(var_literal).value();

  BatchNormInference(input_activations, scale_activations, offset_activations,
                     mean_activations, variance_activations, epsilon,
                     feature_index);

  // Run all HLO passes during this test. In particular,
  // ClientLibraryTestBase disables constant folding, but we want it enabled
  // for our zero-sized tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareR4<float>(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
       variance_data.get()},
      ErrorSpec(0.01, 1));
}

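// Builds the reference gradient for BatchNormGrad from the intermediate
// terms I1..I5 below; the net effect is
//   grad_activation = scale * rsqrt(var + epsilon) *
//       (grad_output - mean(grad_output)
//        - (x - mean) * mean(grad_output * (x - mean)) / (var + epsilon)).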
XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64_t>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  grad_output_array.FillRandom(GetParam().random_value_var,
                               GetParam().random_value_mean);

  const int64_t feature_index = GetParam().feature_index;
  const int64_t num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64_t feature_bound = bounds[feature_index];
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64_t> reduce_dims;
  for (int64_t i = 0; i < static_cast<int64_t>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64_t i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      mean[i] = sum[i] / num_elements_per_feature;
    } else {
      mean[i] = 0;
    }
  }

  std::vector<float> mean_square(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      square_mean[i] = sum_squared[i] / num_elements_per_feature;
    } else {
      square_mean[i] = 0;
    }
  }

  std::vector<float> var(feature_bound);
  for (int64_t i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
      var4D, [epsilon](float a) { return a + epsilon; });

  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });

  auto grad_output_times_var =
      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                 [](float a, float b) { return a * b; });

  auto activation_shifted = *ReferenceUtil::MapArray4D(
      input_array, mean4D, [](float a, float b) { return a - b; });

  auto activation_shifted_times_grad_output =
      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
                                 [](float a, float b) { return a * b; });

  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
      [](float a, float b) { return a * b; });

  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
      [](float a, float b) { return a + b; });

  auto grad_offset =
      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });

  auto I1 = *ReferenceUtil::MapArray4D(
      grad_output_array, [&](float a) { return num_elements_per_feature * a; });

  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);

  // I3 = sum(output_grad * (activation - mean(activation)))
  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
                                  /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; }),
      bounds, feature_index);

  // I4 = (activation - mean(activation)) *
  //      sum(output_grad * (activation - mean(activation)))
  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
                                       [](float a, float b) { return a * b; });

  // I5 = (activation - mean(activation)) *
  //      sum(output_grad * (activation - mean(activation))) /
  //      (variance + epsilon)
  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
                                       [](float a, float b) { return a / b; });

  auto grad_activation = *ReferenceUtil::MapArray4D(
      I1, I2, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, I5, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, scale4D, [](float a, float b) { return a * b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
        if (num_elements_per_feature > 0) {
          return a * b / num_elements_per_feature;
        }
        return 0.f;
      });

  auto expected_grad_activation =
      LiteralUtil::CreateR4FromArray4D<float>(grad_activation);

  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto mean_literal = LiteralUtil::CreateR1<float>(mean);
  auto var_literal = LiteralUtil::CreateR1<float>(var);
  auto grad_output_literal =
      LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);

  auto input_parameter =
      Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_parameter =
      Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
  auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
  auto grad_output_parameter =
      Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).value();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).value();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(mean_literal).value();
  std::unique_ptr<GlobalData> var_data =
      client_->TransferToServer(var_literal).value();
  std::unique_ptr<GlobalData> grad_output_data =
      client_->TransferToServer(grad_output_literal).value();

  BatchNormGrad(input_parameter, scale_parameter, mean_parameter,
                var_parameter, grad_output_parameter, epsilon, feature_index);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
       LiteralUtil::CreateR1<float>(grad_offset)});

  // Run all HLO passes during this test. In particular,
  // ClientLibraryTestBase disables constant folding, but we want it enabled
  // for our zero-sized tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareTuple(&builder, expected,
                         {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                         ErrorSpec(0.01, 1));
}

}  // namespace
}  // namespace xla