xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/reduce_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests that multi-dimensional arrays can be reduced among various
17 // user-provided dimensions.
18 //
19 // Note that comments for these tests are white-box in that they talk about the
20 // default data layout.
21 //
22 // The test space for reductions is the cartesian product of:
23 //
24 //    <possible ranks> x
25 //    <possible layouts for chosen rank> x
26 //    <possible subsets of dimensions in chosen rank>
27 
#include <stdlib.h>

#include <algorithm>
#include <array>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
64 
65 namespace xla {
66 namespace {
67 
68 using FuncGeneratorForType = XlaComputation (*)(PrimitiveType, XlaBuilder*);
69 
70 using FuncGenerator = XlaComputation (*)(XlaBuilder*);
71 
72 class ReduceTest : public ClientLibraryTestBase {
73  protected:
ReduceTest()74   ReduceTest() {
75     // Implementation note: laid out z >> y >> x by default.
76     // clang-format off
77     literal_2d_ = LiteralUtil::CreateR2<float>({
78       // x0   x1   x2
79       { 1.f, 2.f, 3.f},  // y0
80       { 4.f, 5.f, 6.f},  // y1
81     });
82     literal_3d_ = LiteralUtil::CreateR3Projected<float>({
83       // x0   x1   x2
84       { 1.f, 2.f, 3.f},  // y0
85       { 4.f, 5.f, 6.f},  // y1
86     }, 4);
87     // clang-format on
88     CHECK(ShapeUtil::Equal(
89         literal_3d_.shape(),
90         ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
91         << literal_3d_.shape().ShortDebugString();
92   }
93 
94   // Runs an R1 => R0 reduction test with the given number of elements.
RunR1ToR0Test(int64_t element_count)95   void RunR1ToR0Test(int64_t element_count) {
96     XlaBuilder builder(TestName());
97     XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
98     const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
99     auto input = Parameter(&builder, 0, input_shape, "input");
100     auto zero = ConstantR0<float>(&builder, 0.0);
101     Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
102     std::minstd_rand rng(seed_);
103 
104     std::vector<float> input_data(element_count);
105     for (int64_t i = 0; i < element_count; ++i) {
106       input_data[i] = rng() % 3;
107       if (rng() % 2 == 0) {
108         input_data[i] *= -1;
109       }
110     }
111     Literal input_literal =
112         LiteralUtil::CreateR1(absl::MakeConstSpan(input_data));
113     std::unique_ptr<GlobalData> input_global_data =
114         client_->TransferToServer(input_literal).value();
115 
116     float expected = absl::c_accumulate(input_data, 0.0f);
117     ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
118                                ErrorSpec(0.001));
119   }
120 
RunR1ToR0PredTest(bool and_reduce,absl::Span<const int> input_data)121   void RunR1ToR0PredTest(bool and_reduce, absl::Span<const int> input_data) {
122     const int element_count = input_data.size();
123     XlaBuilder builder(TestName());
124     const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
125     auto input_par = Parameter(&builder, 0, input_shape, "input");
126     auto pred_values =
127         Eq(input_par, ConstantR1<int>(&builder, element_count, 1));
128     XlaOp init_value;
129     XlaComputation reduce;
130     if (and_reduce) {
131       init_value = ConstantR0<bool>(&builder, true);
132       reduce = CreateScalarAndComputation(PRED, &builder);
133     } else {
134       init_value = ConstantR0<bool>(&builder, false);
135       reduce = CreateScalarOrComputation(PRED, &builder);
136     }
137     Reduce(pred_values, init_value, reduce,
138            /*dimensions_to_reduce=*/{0});
139 
140     Literal input_literal = LiteralUtil::CreateR1(input_data);
141     std::unique_ptr<GlobalData> input_global_data =
142         client_->TransferToServer(input_literal).value();
143 
144     bool expected = and_reduce;
145     for (bool item : input_data) {
146       if (and_reduce) {
147         expected = expected && item;
148       } else {
149         expected = expected || item;
150       }
151     }
152     ComputeAndCompareR0<bool>(&builder, expected, {input_global_data.get()});
153   }
154 
155   // Reduce predicate tensor with dimension rows * cols to dimension cols, to
156   // test the implementation of atomic operations on misaligned small data
157   // types.
158   template <int64_t cols>
RunR2ToR1PredTest(bool and_reduce,int64_t rows,int64_t minor=1,int64_t major=0)159   void RunR2ToR1PredTest(bool and_reduce, int64_t rows, int64_t minor = 1,
160                          int64_t major = 0) {
161     XlaBuilder builder(TestName());
162     const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols});
163     auto input = Parameter(&builder, 0, input_shape, "input");
164     auto input_pred = Eq(input, ConstantR0<uint8_t>(&builder, 1));
165 
166     XlaOp init_value;
167     XlaComputation reduce_op;
168     if (and_reduce) {
169       init_value = ConstantR0<bool>(&builder, true);
170       reduce_op = CreateScalarAndComputation(PRED, &builder);
171     } else {
172       init_value = ConstantR0<bool>(&builder, false);
173       reduce_op = CreateScalarOrComputation(PRED, &builder);
174     }
175 
176     Reduce(input_pred, init_value, reduce_op,
177            /*dimensions_to_reduce=*/{0});
178 
179     Array2D<uint8_t> input_data(rows, cols);
180     input_data.FillRandom(0, 1);
181     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
182     input_literal =
183         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
184     std::unique_ptr<GlobalData> input_global_data =
185         client_->TransferToServer(input_literal).value();
186 
187     std::array<bool, cols> expected;
188     for (int64_t colno = 0; colno < cols; ++colno) {
189       bool column_sum = and_reduce ? true : false;
190       for (int64_t rowno = 0; rowno < rows; ++rowno) {
191         if (and_reduce) {
192           column_sum = column_sum && input_data(rowno, colno);
193         } else {
194           column_sum = column_sum || input_data(rowno, colno);
195         }
196       }
197       expected[colno] = column_sum;
198     }
199 
200     ComputeAndCompareR1<bool>(&builder, expected, {input_global_data.get()});
201   }
202 
203   // Runs an R2 => R0 reduction test with the given number of (rows, cols).
RunR2ToR0Test(int64_t rows,int64_t cols,int64_t minor=1,int64_t major=0)204   void RunR2ToR0Test(int64_t rows, int64_t cols, int64_t minor = 1,
205                      int64_t major = 0) {
206     XlaBuilder builder(TestName());
207     XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
208     const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
209     auto input = Parameter(&builder, 0, input_shape, "input");
210     auto zero = ConstantR0<float>(&builder, 0.0);
211     Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
212 
213     Array2D<float> input_data(rows, cols);
214     input_data.FillRandom(3.14f, 0.04);
215     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
216     input_literal =
217         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
218     std::unique_ptr<GlobalData> input_global_data =
219         client_->TransferToServer(input_literal).value();
220 
221     float expected = 0.0;
222     for (int64_t rowno = 0; rowno < rows; ++rowno) {
223       for (int64_t colno = 0; colno < cols; ++colno) {
224         expected += input_data(rowno, colno);
225       }
226     }
227     ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
228                                ErrorSpec(0.01, 1e-4));
229   }
230 
231   // Runs an R2 => R1 reduction test with the given number of (rows, cols).
RunR2ToR1Test(int64_t rows,int64_t cols,int64_t minor=1,int64_t major=0)232   void RunR2ToR1Test(int64_t rows, int64_t cols, int64_t minor = 1,
233                      int64_t major = 0) {
234     XlaBuilder builder(TestName());
235     XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
236     const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
237     auto input = Parameter(&builder, 0, input_shape, "input");
238     auto zero = ConstantR0<float>(&builder, 0.0);
239     Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
240 
241     Array2D<float> input_data(rows, cols);
242     input_data.FillRandom(3.14f, 0.04);
243     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
244     input_literal =
245         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
246     std::unique_ptr<GlobalData> input_global_data =
247         client_->TransferToServer(input_literal).value();
248 
249     std::vector<float> expected;
250     expected.reserve(cols);
251     for (int64_t colno = 0; colno < cols; ++colno) {
252       float column_sum = 0;
253       for (int64_t rowno = 0; rowno < rows; ++rowno) {
254         column_sum += input_data(rowno, colno);
255       }
256       expected.push_back(column_sum);
257     }
258     ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
259                                ErrorSpec(0.01, 1e-4));
260   }
261 
262   template <typename NativeT>
ComputeAndCompareGeneric(typename std::enable_if<std::is_floating_point<NativeT>::value,XlaBuilder>::type * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)263   void ComputeAndCompareGeneric(
264       typename std::enable_if<std::is_floating_point<NativeT>::value,
265                               XlaBuilder>::type* builder,
266       absl::Span<const NativeT> expected,
267       absl::Span<GlobalData* const> arguments) {
268     ComputeAndCompareR1<NativeT>(builder, expected, arguments,
269                                  ErrorSpec(0.01, 1e-4));
270   }
271 
272   template <typename NativeT>
ComputeAndCompareGeneric(typename std::enable_if<std::is_integral<NativeT>::value,XlaBuilder>::type * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)273   void ComputeAndCompareGeneric(
274       typename std::enable_if<std::is_integral<NativeT>::value,
275                               XlaBuilder>::type* builder,
276       absl::Span<const NativeT> expected,
277       absl::Span<GlobalData* const> arguments) {
278     ComputeAndCompareR1<NativeT>(builder, expected, arguments);
279   }
280 
281   template <typename NativeT>
RunVectorizedReduceTestForType(const std::function<XlaComputation (XlaBuilder *)> & reduction_function_generator,const std::function<NativeT (NativeT,NativeT)> & reference_reduction_function,const NativeT & initial_value)282   void RunVectorizedReduceTestForType(
283       const std::function<XlaComputation(XlaBuilder*)>&
284           reduction_function_generator,
285       const std::function<NativeT(NativeT, NativeT)>&
286           reference_reduction_function,
287       const NativeT& initial_value) {
288     const int rows = 64, cols = 128;
289     const int minor = 1, major = 0;
290     XlaBuilder builder(TestName());
291     XlaComputation reduction_function = reduction_function_generator(&builder);
292     const Shape input_shape = ShapeUtil::MakeShape(
293         xla::primitive_util::NativeToPrimitiveType<NativeT>(), {rows, cols});
294     auto input = Parameter(&builder, 0, input_shape, "input");
295     auto zero = ConstantR0<NativeT>(&builder, initial_value);
296     Reduce(input, zero, reduction_function,
297            /*dimensions_to_reduce=*/{0});
298 
299     Array2D<NativeT> input_data(rows, cols);
300     input_data.FillUnique(initial_value);
301     Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
302     input_literal =
303         input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
304     std::unique_ptr<GlobalData> input_global_data =
305         client_->TransferToServer(input_literal).value();
306 
307     // NativeT can be bool, and std::vector<bool> does not convert to
308     // Span.
309     std::unique_ptr<NativeT[]> expected(new NativeT[cols]);
310     for (int64_t colno = 0; colno < cols; ++colno) {
311       NativeT column_result = initial_value;
312       for (int64_t rowno = 0; rowno < rows; ++rowno) {
313         column_result = reference_reduction_function(column_result,
314                                                      input_data(rowno, colno));
315       }
316       expected[colno] = column_result;
317     }
318 
319     ComputeAndCompareGeneric<NativeT>(
320         &builder, absl::Span<const NativeT>(expected.get(), cols),
321         {input_global_data.get()});
322   }
323 
RunVectorizedReduceTest(const std::function<XlaComputation (PrimitiveType,XlaBuilder *)> & reduction_function_generator_for_type,const std::function<float (float,float)> & reference_reduction_function_for_floats,const std::function<int32_t (int32_t,int32_t)> & reference_reduction_function_for_ints,const std::function<uint32_t (uint32_t,uint32_t)> & reference_reduction_function_for_uints,float floating_point_identity,int32_t signed_int_identity,uint32_t unsigned_int_identity)324   void RunVectorizedReduceTest(
325       const std::function<XlaComputation(PrimitiveType, XlaBuilder*)>&
326           reduction_function_generator_for_type,
327       const std::function<float(float, float)>&
328           reference_reduction_function_for_floats,
329       const std::function<int32_t(int32_t, int32_t)>&
330           reference_reduction_function_for_ints,
331       const std::function<uint32_t(uint32_t, uint32_t)>&
332           reference_reduction_function_for_uints,
333       float floating_point_identity, int32_t signed_int_identity,
334       uint32_t unsigned_int_identity) {
335     // Float version
336     RunVectorizedReduceTestForType<float>(
337         [&](XlaBuilder* builder) {
338           return reduction_function_generator_for_type(F32, builder);
339         },
340         reference_reduction_function_for_floats, floating_point_identity);
341 
342     // Signed int version
343     RunVectorizedReduceTestForType<int32_t>(
344         [&](XlaBuilder* builder) {
345           return reduction_function_generator_for_type(S32, builder);
346         },
347         reference_reduction_function_for_ints, signed_int_identity);
348 
349     // Unsigned int version
350     RunVectorizedReduceTestForType<uint32_t>(
351         [&](XlaBuilder* builder) {
352           return reduction_function_generator_for_type(U32, builder);
353         },
354         reference_reduction_function_for_uints, unsigned_int_identity);
355   }
356 
357   Literal literal_2d_;
358   Literal literal_3d_;
359   uint32_t seed_ = 0xdeadbeef;
360 };
361 
// R1 => R0 add-reductions over F32 vectors. The number in each test name is
// the element count handed to RunR1ToR0Test.
XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16);
}
XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/128);
}
XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/129);
}
XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/240);
}
XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/256);
}
XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/2048);
}
XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/16 * 1024 + 1);
}
XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/64 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/1024 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) {
  RunR1ToR0Test(/*element_count=*/4096 * 4096);
}
379 
// R2 => R0 add-reductions. Test names encode rows x cols; the "_01_" variant
// additionally passes minor=0, major=1 to select the other input layout.
XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) {
  RunR2ToR0Test(/*rows=*/0, /*cols=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) {
  RunR2ToR0Test(/*rows=*/0, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) {
  RunR2ToR0Test(/*rows=*/1, /*cols=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) {
  RunR2ToR0Test(/*rows=*/2, /*cols=*/0);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R0) {
  RunR2ToR0Test(/*rows=*/2, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) {
  RunR2ToR0Test(/*rows=*/8, /*cols=*/8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) {
  RunR2ToR0Test(/*rows=*/9, /*cols=*/9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) {
  RunR2ToR0Test(/*rows=*/50, /*cols=*/111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) {
  RunR2ToR0Test(/*rows=*/111, /*cols=*/50);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) {
  RunR2ToR0Test(/*rows=*/111, /*cols=*/50, /*minor=*/0, /*major=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) {
  RunR2ToR0Test(/*rows=*/1024, /*cols=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) {
  RunR2ToR0Test(/*rows=*/1000, /*cols=*/1500);
}
394 
// R2 => R1 add-reductions over dimension 0. Test names encode rows x cols; the
// "_01_" variant additionally passes minor=0, major=1 to select the other
// input layout.

// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) {
  RunR2ToR1Test(/*rows=*/0, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) {
  RunR2ToR1Test(/*rows=*/1, /*cols=*/1);
}
// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) {
  RunR2ToR1Test(/*rows=*/2, /*cols=*/2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) {
  RunR2ToR1Test(/*rows=*/8, /*cols=*/8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) {
  RunR2ToR1Test(/*rows=*/9, /*cols=*/9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) {
  RunR2ToR1Test(/*rows=*/50, /*cols=*/111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) {
  RunR2ToR1Test(/*rows=*/111, /*cols=*/50);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
  RunR2ToR1Test(/*rows=*/111, /*cols=*/50, /*minor=*/0, /*major=*/1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) {
  RunR2ToR1Test(/*rows=*/1024, /*cols=*/1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) {
  RunR2ToR1Test(/*rows=*/1000, /*cols=*/1500);
}
411 
// AND-reduces ten elements that are all 1.
XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) {
  constexpr int kNumElements = 10;
  const std::vector<int> all_ones(kNumElements, 1);
  RunR1ToR0PredTest(/*and_reduce=*/true, all_ones);
}
417 
// AND-reduces ten elements alternating 0, 1, 0, 1, ...
XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) {
  constexpr int kNumElements = 10;
  std::vector<int> alternating;
  alternating.reserve(kNumElements);
  for (int index = 0; index < kNumElements; ++index) {
    alternating.push_back(index % 2);
  }
  RunR1ToR0PredTest(/*and_reduce=*/true, alternating);
}
426 
// OR-reduces ten elements that are all 1.
XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) {
  constexpr int kNumElements = 10;
  const std::vector<int> all_ones(kNumElements, 1);
  RunR1ToR0PredTest(/*and_reduce=*/false, all_ones);
}
432 
// OR-reduces ten elements alternating 0, 1, 0, 1, ...
XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) {
  constexpr int kNumElements = 10;
  std::vector<int> alternating;
  alternating.reserve(kNumElements);
  for (int index = 0; index < kNumElements; ++index) {
    alternating.push_back(index % 2);
  }
  RunR1ToR0PredTest(/*and_reduce=*/false, alternating);
}
441 
// Applies an elementwise Log to a 111x50 F32 parameter and add-reduces the
// result over dimension 0. The input literal is relaid out with
// MakeLayout({0, 1}) before transfer.
//
// NOTE(review): if FillRandom(3.14f, 0.04) can produce non-positive values,
// Log yields NaN on both the device and host side and the comparison no longer
// exercises real sums - confirm against Array2D::FillRandom.
XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
  constexpr int64_t kRows = 111;
  constexpr int64_t kCols = 50;

  XlaBuilder builder(TestName());
  XlaComputation sum_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto logged = Log(param);
  Reduce(logged, init, sum_f32, /*dimensions_to_reduce=*/{0});

  Array2D<float> host_values(kRows, kCols);
  host_values.FillRandom(3.14f, 0.04);
  Literal host_literal = LiteralUtil::CreateR2FromArray2D(host_values);
  host_literal = host_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(host_literal).value();

  // Reference: per-column sum of log(x), accumulated top to bottom.
  std::vector<float> expected(kCols);
  for (int64_t col = 0; col < kCols; ++col) {
    float sum = 0;
    for (int64_t row = 0; row < kRows; ++row) {
      sum += std::log(host_values(row, col));
    }
    expected[col] = sum;
  }
  ComputeAndCompareR1<float>(&builder, expected, {device_input.get()},
                             ErrorSpec(0.01, 1e-4));
}
472 
// Applies an elementwise Log to a 111x50 F32 parameter, transposes it with
// permutation {1, 0}, and add-reduces the transposed value over dimension 1.
// The expected values are the per-column sums of log(x) of the original
// matrix.
//
// NOTE(review): if FillRandom(3.14f, 0.04) can produce non-positive values,
// Log yields NaN on both the device and host side and the comparison no longer
// exercises real sums - confirm against Array2D::FillRandom.
XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
  constexpr int64_t kRows = 111;
  constexpr int64_t kCols = 50;

  XlaBuilder builder(TestName());
  XlaComputation sum_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto logged = Log(param);
  auto transposed = Transpose(logged, {1, 0});
  Reduce(transposed, init, sum_f32, /*dimensions_to_reduce=*/{1});

  Array2D<float> host_values(kRows, kCols);
  host_values.FillRandom(3.14f, 0.04);
  Literal host_literal = LiteralUtil::CreateR2FromArray2D(host_values);
  host_literal = host_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(host_literal).value();

  // Reference: per-column sum of log(x), accumulated top to bottom.
  std::vector<float> expected(kCols);
  for (int64_t col = 0; col < kCols; ++col) {
    float sum = 0;
    for (int64_t row = 0; row < kRows; ++row) {
      sum += std::log(host_values(row, col));
    }
    expected[col] = sum;
  }
  ComputeAndCompareR1<float>(&builder, expected, {device_input.get()},
                             ErrorSpec(0.01, 1e-4));
}
504 
505 // Test that algebraic simplifier does not incorrectly fold a transpose into a
506 // reduction operation.
XLA_TEST_F(ReduceTest,TransposeAndReduceR3_12x111x50_To_R2)507 XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
508   XlaBuilder builder(TestName());
509   XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
510   const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50});
511   XlaOp input = Parameter(&builder, 0, input_shape, "input");
512   XlaOp zero = ConstantR0<float>(&builder, 0.0);
513   XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
514   Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
515 
516   TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
517 
518   ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
519 }
520 
// Applies an elementwise Tanh to an F32[111, 2, 25] parameter, reshapes the
// result to [111, 50], and add-reduces it over dimension 0. The expected
// vector lists, for each index of the middle dimension and then each index of
// the last dimension, the sum of tanh(x) over the first dimension.
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
  constexpr int64_t kRows = 111;
  constexpr int64_t kCols = 50;
  constexpr int64_t kHalfCols = kCols / 2;

  XlaBuilder builder(TestName());
  XlaComputation sum_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kRows, 2, kHalfCols});
  auto param = Parameter(&builder, 0, param_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  auto squashed = Tanh(param);
  auto flattened = Reshape(squashed, {kRows, kCols});
  Reduce(flattened, init, sum_f32, /*dimensions_to_reduce=*/{0});

  Array3D<float> host_values(kRows, 2, kHalfCols);
  host_values.FillRandom(3.14f, 0.04);
  Literal host_literal = LiteralUtil::CreateR3FromArray3D(host_values);
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(host_literal).value();

  std::vector<float> expected;
  expected.reserve(kCols);
  for (int64_t half = 0; half < 2; ++half) {
    for (int64_t col = 0; col < kHalfCols; ++col) {
      float sum = 0;
      for (int64_t row = 0; row < kRows; ++row) {
        sum += std::tanh(host_values(row, half, col));
      }
      expected.push_back(sum);
    }
  }
  ComputeAndCompareR1<float>(&builder, expected, {device_input.get()},
                             ErrorSpec(0.01, 1e-4));
}
553 
// Describes one reduction test case; PrintTo below renders it as a name.
struct BoundsLayout {
  // Dimension sizes of the input; its length is the input rank.
  std::vector<int64_t> bounds;

  // Layout of the input.
  // NOTE(review): presumably a minor-to-major order - confirm at the use site.
  std::vector<int64_t> layout;

  // Dimensions to reduce over.
  std::vector<int64_t> reduce_dims;
};
559 
PrintTo(const BoundsLayout & spec,std::ostream * os)560 void PrintTo(const BoundsLayout& spec, std::ostream* os) {
561   *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(),
562                          spec.bounds.size() - spec.reduce_dims.size(),
563                          absl::StrJoin(spec.bounds, "x"),
564                          absl::StrJoin(spec.layout, ""),
565                          absl::StrJoin(spec.reduce_dims, ""));
566 }
567 
568 // Add-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,AddReduce2DScalarToR0)569 XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
570   XlaBuilder builder(TestName());
571   auto add = CreateScalarAddComputation(F32, &builder);
572   auto scalar = ConstantR0<float>(&builder, 42.0);
573   auto broadcasted = Broadcast(scalar, {500, 500});
574   Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
575 
576   float expected = 42.0f * static_cast<float>(500 * 500);
577   ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
578 }
579 
580 // Max-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DScalarToR0)581 XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
582   XlaBuilder builder(TestName());
583   auto max = CreateScalarMaxComputation(F32, &builder);
584   auto scalar = ConstantR0<float>(&builder, 42.0);
585   auto broadcasted = Broadcast(scalar, {500, 500});
586   Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), max, {0, 1});
587 
588   float expected = 42.0f;
589   ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
590 }
591 
592 // Max-reduces a matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DToR0)593 XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
594   XlaBuilder builder(TestName());
595   auto max = CreateScalarMaxComputation(F32, &builder);
596   Array2D<float> input(300, 250);
597   input.FillRandom(214.0f);
598   auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
599   Reduce(ConstantLiteral(&builder, input_literal),
600          ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
601   auto input_max = FLT_MIN;
602   input.Each(
603       [&](int64_t, int64_t, float* v) { input_max = std::max(input_max, *v); });
604   ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
605 }
606 
607 // Min-reduces matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MinReduce2DToR0)608 XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
609   XlaBuilder builder(TestName());
610   auto min = CreateScalarMinComputation(F32, &builder);
611   Array2D<float> input(150, 130);
612   input.FillRandom(214.0f);
613   auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
614   Reduce(ConstantLiteral(&builder, input_literal),
615          ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
616 
617   auto input_min = FLT_MAX;
618   input.Each(
619       [&](int64_t, int64_t, float* v) { input_min = std::min(input_min, *v); });
620   ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
621 }
622 
// Min-reduces the U32 matrix {{1}, {2}} over both dimensions, starting from
// the largest uint32_t value; the expected result is 1.
XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32_t> matrix({{1}, {2}});
  auto min_u32 = CreateScalarMinComputation(U32, &builder);
  auto matrix_literal = LiteralUtil::CreateR2FromArray2D(matrix);
  auto init =
      ConstantR0<uint32_t>(&builder, std::numeric_limits<uint32_t>::max());

  Reduce(ConstantLiteral(&builder, matrix_literal), init, min_u32, {0, 1});
  ComputeAndCompareR0<uint32_t>(&builder, 1, {});
}
634 
// Max-reduces the U32 matrix {{1}, {2}} over both dimensions, starting from
// the smallest uint32_t value; the expected result is 2.
XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32_t> matrix({{1}, {2}});
  auto max_u32 = CreateScalarMaxComputation(U32, &builder);
  auto matrix_literal = LiteralUtil::CreateR2FromArray2D(matrix);
  auto init =
      ConstantR0<uint32_t>(&builder, std::numeric_limits<uint32_t>::min());

  Reduce(ConstantLiteral(&builder, matrix_literal), init, max_u32, {0, 1});
  ComputeAndCompareR0<uint32_t>(&builder, 2, {});
}
646 
647 // Reduces a matrix among dimension 1.
// Sums the fixture's literal_2d_ over dimension 1 and expects the rank-1
// result {6, 15}.
XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantLiteral(&builder, literal_2d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(matrix, zero, sum_fn, /*dimensions_to_reduce=*/{1});

  const std::vector<float> want = {6.f, 15.f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
657 
// Sums the fixture's literal_2d_ over dimensions 0 and 1, i.e. collapses the
// whole matrix to one scalar, and expects 21.
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantLiteral(&builder, literal_2d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(matrix, zero, sum_fn, /*dimensions_to_reduce=*/{0, 1});

  const float want = 21.0f;
  ComputeAndCompareR0<float>(&builder, want, {}, ErrorSpec(0.0001, 1e-4));
}
667 
668 // Tests 2D matrix ReduceToRow operation.
// Sums the fixture's literal_2d_ over dimension 0 and expects the rank-1
// result {5, 7, 9}. Unlike its neighbours, this test names its builder with a
// fixed string rather than TestName().
XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
  XlaBuilder builder("reduce_among_y");
  auto matrix = ConstantLiteral(&builder, literal_2d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(matrix, zero, sum_fn, /*dimensions_to_reduce=*/{0});

  const std::vector<float> want = {5.f, 7.f, 9.f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
678 
// Sums the fixture's literal_3d_ over dimensions 1 and 2 and expects the
// rank-1 result {21, 21, 21, 21}.
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_fn, /*dimensions_to_reduce=*/{1, 2});

  const std::vector<float> want = {21.f, 21.f, 21.f, 21.f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
688 
// Sums the fixture's literal_3d_ over dimensions 0 and 1 and expects the
// rank-1 result {20, 28, 36}.
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_fn, /*dimensions_to_reduce=*/{0, 1});

  const std::vector<float> want = {20.f, 28.f, 36.f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
698 
// Sums the fixture's literal_3d_ over all three dimensions and expects the
// scalar 4 * 21 = 84.
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_fn, /*dimensions_to_reduce=*/{0, 1, 2});

  const float want = 4.0f * 21.0f;
  ComputeAndCompareR0<float>(&builder, want, {}, ErrorSpec(0.0001));
}
708 
// Sums the fixture's literal_3d_ over dimension 0 and expects a 2x3 matrix.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_fn, /*dimensions_to_reduce=*/{0});

  // clang-format off
  const Array2D<float> want({
      { 4.f,  8.f, 12.f},
      {16.f, 20.f, 24.f},
  });
  // clang-format on
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
723 
// Sums the fixture's literal_3d_ over dimension 1 and expects a 4x3 matrix
// whose rows are all {5, 7, 9}.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_fn, /*dimensions_to_reduce=*/{1});

  // clang-format off
  const Array2D<float> want({
      {5.f, 7.f, 9.f},
      {5.f, 7.f, 9.f},
      {5.f, 7.f, 9.f},
      {5.f, 7.f, 9.f},
  });
  // clang-format on
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
740 
// Sums the fixture's literal_3d_ over dimension 2 and expects a 4x2 matrix
// whose rows are all {6, 15}.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(cube, zero, sum_fn, /*dimensions_to_reduce=*/{2});

  // clang-format off
  const Array2D<float> want({
      {6.f, 15.f},
      {6.f, 15.f},
      {6.f, 15.f},
      {6.f, 15.f},
  });
  // clang-format on
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
757 
// Runs the shared vectorized-reduce helper with addition as the reducer and
// zero as the initial value for float, int32 and uint32.
XLA_TEST_F(ReduceTest, VectorizedReduce_Add) {
  const auto add_f32 = [](float a, float b) { return a + b; };
  // Signed addition is carried out in uint32_t and cast back, so the reference
  // wraps around instead of hitting signed-overflow undefined behaviour.
  const auto add_s32 = [](int32_t a, int32_t b) {
    const uint32_t wrapped = static_cast<uint32_t>(a) + static_cast<uint32_t>(b);
    return static_cast<int32_t>(wrapped);
  };
  const auto add_u32 = [](uint32_t a, uint32_t b) { return a + b; };

  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarAddComputation), add_f32,
      add_s32, add_u32, 0.0, 0, 0);
}
768 
// Runs the shared vectorized-reduce helper with multiplication as the reducer
// and one as the initial value for float, int32 and uint32.
XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) {
  const auto mul_f32 = [](float a, float b) { return a * b; };
  // Signed multiplication is carried out in uint32_t and cast back, so the
  // reference wraps around instead of hitting signed-overflow undefined
  // behaviour.
  const auto mul_s32 = [](int32_t a, int32_t b) {
    const uint32_t wrapped = static_cast<uint32_t>(a) * static_cast<uint32_t>(b);
    return static_cast<int32_t>(wrapped);
  };
  const auto mul_u32 = [](uint32_t a, uint32_t b) { return a * b; };

  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarMultiplyComputation),
      mul_f32, mul_s32, mul_u32, 1.0, 1, 1);
}
779 
// Runs the shared vectorized-reduce helper with max as the reducer, passing
// numeric_limits<T>::min() as the initial value for each element type.
//
// NOTE(review): numeric_limits<float>::min() is the smallest positive
// normalized float, not the most negative one. Whether that is intentional
// depends on how RunVectorizedReduceTest uses the initial value (the helper is
// defined outside this section) -- confirm before changing it.
XLA_TEST_F(ReduceTest, VectorizedReduce_Max) {
  const auto max_f32 = [](float a, float b) { return std::max(a, b); };
  const auto max_s32 = [](int32_t a, int32_t b) { return std::max(a, b); };
  const auto max_u32 = [](uint32_t a, uint32_t b) { return std::max(a, b); };

  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarMaxComputation), max_f32,
      max_s32, max_u32, std::numeric_limits<float>::min(),
      std::numeric_limits<int32_t>::min(),
      std::numeric_limits<uint32_t>::min());
}
789 
// Runs the shared vectorized-reduce helper with min as the reducer, passing
// numeric_limits<T>::max() as the initial value for each element type.
XLA_TEST_F(ReduceTest, VectorizedReduce_Min) {
  const auto min_f32 = [](float a, float b) { return std::min(a, b); };
  const auto min_s32 = [](int32_t a, int32_t b) { return std::min(a, b); };
  const auto min_u32 = [](uint32_t a, uint32_t b) { return std::min(a, b); };

  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarMinComputation), min_f32,
      min_s32, min_u32, std::numeric_limits<float>::max(),
      std::numeric_limits<int32_t>::max(),
      std::numeric_limits<uint32_t>::max());
}
799 
// Runs the typed vectorized-reduce helper on bool with a PRED "and" reducer;
// the host reference is logical AND and the initial value is true.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) {
  const auto make_and = static_cast<FuncGenerator>([](XlaBuilder* b) {
    return CreateScalarAndComputation(PRED, b);
  });
  const auto reference_and = [](bool lhs, bool rhs) { return lhs && rhs; };

  RunVectorizedReduceTestForType<bool>(make_and, reference_and, true);
}
807 
// Runs the typed vectorized-reduce helper on bool with a PRED "or" reducer;
// the host reference is logical OR and the initial value is false.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) {
  const auto make_or = static_cast<FuncGenerator>([](XlaBuilder* b) {
    return CreateScalarOrComputation(PRED, b);
  });
  const auto reference_or = [](bool lhs, bool rhs) { return lhs || rhs; };

  RunVectorizedReduceTestForType<bool>(make_or, reference_or, false);
}
815 
816 class ReduceR3ToR2Test : public ReduceTest,
817                          public ::testing::WithParamInterface<BoundsLayout> {};
818 
// Sums an all-ones rank-3 array, laid out as requested by the test parameter,
// over the parameter's reduce_dims and compares the rank-2 result against
// ReferenceUtil::Reduce3DTo2D.
XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
  const BoundsLayout& param = GetParam();
  XlaBuilder builder(TestName());

  // The input is filled with ones; a random fill (FillRandom(3.14f, 0.05)) is
  // deliberately left disabled here.
  Array3D<float> ones(param.bounds[0], param.bounds[1], param.bounds[2]);
  ones.Fill(1.0f);

  // Build the literal in the default layout, then relayout it to the layout
  // under test before uploading it.
  const Literal operand_literal =
      LiteralUtil::CreateR3FromArray3D(ones).Relayout(
          LayoutUtil::MakeLayout(param.layout));
  std::unique_ptr<GlobalData> operand_handle =
      client_->TransferToServer(operand_literal).value();

  auto operand = Parameter(&builder, 0, operand_literal.shape(), "input");
  XlaComputation sum_fn = CreateScalarAddComputation(F32, &builder);
  auto zero = ConstantR0<float>(&builder, 0.0f);
  Reduce(operand, zero, sum_fn, param.reduce_dims);

  const auto host_add = [](float a, float b) { return a + b; };
  auto want =
      ReferenceUtil::Reduce3DTo2D(ones, 0.0f, param.reduce_dims, host_add);

  ComputeAndCompareR2<float>(&builder, *want, {operand_handle.get()},
                             ErrorSpec(1e-3, 1e-3));
}
845 
846 INSTANTIATE_TEST_CASE_P(
847     ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test,
848     // Specifies (shape, layout, reduction dimensions).
849     ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}},
850                       BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}},
851                       BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}},
852                       // These should be simplified into a reshape.
853                       BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}},
854                       BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}},
855                       BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}},
856                       BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}},
857                       BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}},
858                       BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}},
859                       BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}},
860                       BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}},
861                       BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}},
862                       BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}},
863                       BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}},
864                       BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}},
865                       BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}},
866                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}},
867                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
868                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
869 
// Uses the result of an operation on a constant (Abs(2.0f)) as the initial
// value of a max reduction over the parameter {1, 4}; the expected result is 4.
XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
  XlaBuilder builder(TestName());
  XlaComputation max_fn = CreateScalarMaxComputation(F32, &builder);

  // The init value is not a plain constant but an op applied to one.
  auto two = ConstantR0<float>(&builder, 2.0f);
  auto init = Abs(two);

  const Literal operand_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
  std::unique_ptr<GlobalData> operand_handle =
      client_->TransferToServer(operand_literal).value();
  auto operand = Parameter(&builder, 0, operand_literal.shape(), "b");
  Reduce(operand, init, max_fn, /*dimensions_to_reduce=*/{0});

  ComputeAndCompareR0<float>(&builder, 4.0f, {operand_handle.get()});
}
885 
// And-reduces a PRED matrix with 128 rows and 64 columns via the shared
// R2-to-R1 predicate helper.
XLA_TEST_F(ReduceTest, ReduceAndPredR2_128x64_To_R1) {
  constexpr int kCols = 64;
  constexpr int kRows = 128;
  RunR2ToR1PredTest<kCols>(/*and_reduce=*/true, /*rows=*/kRows);
}
// Or-reduces a PRED matrix with 64 rows and 32 columns via the shared
// R2-to-R1 predicate helper.
XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) {
  constexpr int kCols = 32;
  constexpr int kRows = 64;
  RunR2ToR1PredTest<kCols>(/*and_reduce=*/false, /*rows=*/kRows);
}
892 
893 // Tests reductions with different initial values.  There's no test macro that
894 // combines TYPED_TEST and TYPED_P, so we have to do it manually.
895 class ReduceInitializerTest : public ReduceTest {
896  protected:
897   template <typename T>
DoTest(T initializer,int num_elems)898   void DoTest(T initializer, int num_elems) {
899     XlaBuilder builder(TestName());
900     XlaComputation max_fn = CreateScalarMaxComputation(
901         primitive_util::NativeToPrimitiveType<T>(), &builder);
902 
903     auto init = ConstantR0<T>(&builder, initializer);
904     std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
905     auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
906     auto input_data = client_->TransferToServer(input_literal).value();
907     Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
908            {0});
909 
910     ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
911   }
912 };
913 
// uint8 initializer 42 over a 2-element input.
XLA_TEST_F(ReduceInitializerTest, U8Small) {
  constexpr uint8_t kInitializer = 42;
  constexpr int kNumElems = 2;
  DoTest<uint8_t>(kInitializer, kNumElems);
}
915 
// uint8 initializer 42 over a 4096-element (power-of-two sized) input.
XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) {
  constexpr uint8_t kInitializer = 42;
  constexpr int kNumElems = 4096;
  DoTest<uint8_t>(kInitializer, kNumElems);
}
917 
// uint8 initializer 42 over a 4095-element (non-power-of-two sized) input.
XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) {
  constexpr uint8_t kInitializer = 42;
  constexpr int kNumElems = 4095;
  DoTest<uint8_t>(kInitializer, kNumElems);
}
921 
// uint64 initializer 0 over a 1024-element input.
XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) {
  constexpr uint64_t kInitializer = 0;
  constexpr int kNumElems = 1024;
  DoTest<uint64_t>(kInitializer, kNumElems);
}
925 
// uint64 initializer 1 over a 1024-element input.
XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) {
  constexpr uint64_t kInitializer = 1;
  constexpr int kNumElems = 1024;
  DoTest<uint64_t>(kInitializer, kNumElems);
}
929 
// uint64 initializer that does not fit in 32 bits, over a 1024-element input.
XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
  constexpr uint64_t kInitializer = 1234556789123;
  constexpr int kNumElems = 1024;
  DoTest<uint64_t>(kInitializer, kNumElems);
}
933 
934 // Test the operational semantic that the init value is passed on the lhs for
935 // reduces. Can be tested by performing an "identity" reduce (that simply
936 // returns one of the parameters). In this case, we return the rhs, which for
937 // a 1D array with one element, should not be the init value.
// Reduces a one-element vector with a reducer made of just two scalar
// parameters ("lhs-unused", "rhs-used"). The expected result is the operand
// element (42), not the init value (58.5), which is what the test relies on to
// show the init value is supplied on the lhs.
XLA_TEST_F(ReduceTest, ReduceIdentity) {
  XlaBuilder builder(TestName());
  Shape single_float = ShapeUtil::MakeShape(F32, {});

  // First use of the builder: build the reducer computation itself.
  Parameter(&builder, 0, single_float, "lhs-unused");
  Parameter(&builder, 1, single_float, "rhs-used");
  auto computation_status = builder.Build();
  TF_ASSERT_OK(computation_status.status());

  // Second use of the same builder: the entry computation, with the operand
  // and the init value both fed in as parameters.
  Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
  // value() replaces the deprecated ValueOrDie(); the status was already
  // asserted OK above, and the rest of this file unwraps StatusOr the same way.
  Reduce(Parameter(&builder, 0, operand_shape, "operand"),
         Parameter(&builder, 1, single_float, "init"),
         computation_status.value(), {0});

  float operand[] = {42.0f};
  float init = 58.5f;
  float expected = 42.0f;
  Literal input_literal = LiteralUtil::CreateR1<float>(operand);
  std::unique_ptr<GlobalData> input_global_data =
      client_->TransferToServer(input_literal).value();
  Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
  std::unique_ptr<GlobalData> input_global_data2 =
      client_->TransferToServer(input_literal2).value();
  ComputeAndCompareR0<float>(
      &builder, expected, {input_global_data.get(), input_global_data2.get()},
      ErrorSpec(0.0001));
}
964 
// Bitwise-AND-reduces each row of a 3x2 U64 matrix (dimension 1), starting
// from an all-ones value, and checks the three per-row results.
XLA_TEST_F(ReduceTest, AndReduceU64) {
  XlaBuilder builder(TestName());
  constexpr uint64_t kAllOnes = 0xFFFFFFFFFFFFFFFFULL;
  Array2D<uint64_t> rows = {
      {0x123456789ABCDEF0ULL, 0x3BCDEF12A4567890ULL},
      {0xFFFFFFFFFFFFFFD6ULL, 101},
      {1, kAllOnes}};
  auto and_fn = CreateScalarAndComputation(U64, &builder);
  auto operand = ConstantR2FromArray2D(&builder, rows);
  auto init = ConstantR0<uint64_t>(&builder, kAllOnes);
  Reduce(operand, init, and_fn, /*dimensions_to_reduce=*/{1});

  const std::vector<uint64_t> want = {0x1204461080145890ULL, 68, 1};
  ComputeAndCompareR1<uint64_t>(&builder, want, {});
}
978 
// Bitwise-OR-reduces each row of a 3x2 U64 matrix (dimension 1), starting
// from zero, and checks the three per-row results.
XLA_TEST_F(ReduceTest, OrReduceU64) {
  XlaBuilder builder(TestName());
  Array2D<uint64_t> rows = {
      {0x123456789ABCDEF0ULL, 0x3BCDEF12A4567890ULL},
      {0xFFFFFFFFFFFFFFD6ULL, 101},
      {1, 0xCAFEBEEFABABABABULL}};
  auto or_fn = CreateScalarOrComputation(U64, &builder);
  auto operand = ConstantR2FromArray2D(&builder, rows);
  auto init = ConstantR0<uint64_t>(&builder, 0);
  Reduce(operand, init, or_fn, /*dimensions_to_reduce=*/{1});

  const std::vector<uint64_t> want = {
      0x3BFDFF7ABEFEFEF0ULL, 0xFFFFFFFFFFFFFFF7ULL, 0xCAFEBEEFABABABABULL};
  ComputeAndCompareR1<uint64_t>(&builder, want, {});
}
993 
// Sums a [127, 1] parameter over dimension 0. The result is rank-1 with a
// single element, which is compared with the host-side sum of the input.
XLA_TEST_F(ReduceTest, R0ReduceInDisguise) {
  constexpr int kElementCount = 127;

  XlaBuilder builder(TestName());
  XlaComputation sum_fn = CreateScalarAddComputation(F32, &builder);
  const Shape column_shape = ShapeUtil::MakeShape(F32, {kElementCount, 1});
  auto operand = Parameter(&builder, 0, column_shape, "input");
  auto init = ConstantR0<float>(&builder, 0.0);
  Reduce(operand, init, sum_fn, /*dimensions_to_reduce=*/{0});

  Array2D<float> column(kElementCount, 1);
  column.FillRandom(3.0f);
  const Literal column_literal = LiteralUtil::CreateR2FromArray2D(column);
  std::unique_ptr<GlobalData> column_handle =
      client_->TransferToServer(column_literal).value();

  const float want = absl::c_accumulate(column, 0.0f);
  ComputeAndCompareR1<float>(&builder, {want}, {column_handle.get()},
                             ErrorSpec(0.001));
}
1013 
1014 class ReduceHloTest : public HloTestBase {};
1015 
// Runs a module whose root tuple returns two reductions: reduce.1 over a
// masked 4-D parameter, and reduce.0 over log(abs(reduce.1)).
XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) {
  const absl::string_view module_text = R"(
  HloModule HandleReductionToVectorAndOtherReduction

  add {
    acc = f32[] parameter(1)
    op = f32[] parameter(0)
    ROOT out = f32[] add(acc, op)
  }

  ENTRY main {
    iota.3 = s32[2,2]{1,0} iota(), iota_dimension=0
    iota.2 = s32[2,2]{1,0} iota(), iota_dimension=1
    compare.0 = pred[2,2]{1,0} compare(iota.3, iota.2), direction=EQ
    broadcast = pred[2,2,2,2]{3,2,1,0} broadcast(compare.0), dimensions={2,3}
    param_0.16 = f32[2,2,2,2]{3,2,1,0} parameter(0)
    constant_4 = f32[] constant(0)
    broadcast.9 = f32[2,2,2,2]{3,2,1,0} broadcast(constant_4), dimensions={}
    select.0 = f32[2,2,2,2]{3,2,1,0} select(broadcast, param_0.16, broadcast.9)
    reduce.1 = f32[2,2,2]{2,1,0} reduce(select.0, constant_4), dimensions={2},
               to_apply=add
    abs.0 = f32[2,2,2]{2,1,0} abs(reduce.1)
    log.0 = f32[2,2,2]{2,1,0} log(abs.0)
    reduce.0 = f32[2,2]{1,0} reduce(log.0, constant_4), dimensions={2},
               to_apply=add
    ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce.0, reduce.1)
  }
  )";
  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1046 
1047 class VariadicReduceTest : public HloTestBase {};
1048 
// Two-operand add reduction of f32[3,4,5] inputs over dimension 1, producing a
// pair of f32[3,5] results.
XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) {
  // The module is named after this test. It previously carried the name of the
  // R3x2 -> R1x2 test below (a copy-paste leftover), which made the two
  // modules indistinguishable by name.
  absl::string_view hlo_string = R"(
  HloModule Reduce_R3x2_to_R2x2_simple

  add {
    op1 = f32[] parameter(0)
    op2 = f32[] parameter(1)
    acc1 = f32[] parameter(2)
    acc2 = f32[] parameter(3)
    out1 = f32[] add(acc1, op1)
    out2 = f32[] add(acc2, op2)
    ROOT result = (f32[], f32[]) tuple(out1, out2)
  }

  ENTRY main {
    inp1 = f32[3,4,5] parameter(0)
    inp2 = f32[3,4,5] parameter(1)
    zero = f32[] constant(0)

    ROOT out = (f32[3,5], f32[3,5]) reduce(inp1, inp2, zero, zero),
      dimensions={1},
      to_apply=add
  }
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1076 
// Two-operand add reduction of f32[10,20,3] inputs over dimensions 1 and 2,
// producing a pair of f32[10] results.
XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R1x2_simple) {
  const absl::string_view module_text = R"(
  HloModule Reduce_R3x2_to_R1x2_simple

  add {
    op1 = f32[] parameter(0)
    op2 = f32[] parameter(1)
    acc1 = f32[] parameter(2)
    acc2 = f32[] parameter(3)
    out1 = f32[] add(acc1, op1)
    out2 = f32[] add(acc2, op2)
    ROOT result = (f32[], f32[]) tuple(out1, out2)
  }

  ENTRY main {
    inp1 = f32[10,20,3] parameter(0)
    inp2 = f32[10,20,3] parameter(1)
    zero = f32[] constant(0)

    ROOT out = (f32[10], f32[10]) reduce(inp1, inp2, zero, zero),
      dimensions={1,2},
      to_apply=add
  }
)";

  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1104 
// Two-operand add reduction of f32[100] inputs over dimension 0, producing a
// pair of scalars.
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_simple) {
  const absl::string_view module_text = R"(
  HloModule Reduce_R1x2_to_R0x2_simple

  add {
    op1 = f32[] parameter(0)
    op2 = f32[] parameter(1)
    acc1 = f32[] parameter(2)
    acc2 = f32[] parameter(3)
    out1 = f32[] add(acc1, op1)
    out2 = f32[] add(acc2, op2)
    ROOT result = (f32[], f32[]) tuple(out1, out2)
  }

  ENTRY main {
    inp1 = f32[100] parameter(0)
    inp2 = f32[100] parameter(1)
    zero = f32[] constant(0)

    ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero),
      dimensions={0},
      to_apply=add
  }
)";

  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1132 
// Argmax written as a variadic reduce: an f32[100] input is reduced together
// with an iota of indices, yielding a (value, index) pair of scalars.
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax) {
  const absl::string_view module_text = R"(
    HloModule Reduce_R1x2_to_R0x2_argmax

    argmax {
      running_max = f32[] parameter(0)
      running_max_idx = u32[] parameter(1)
      current_value = f32[] parameter(2)
      current_value_idx = u32[] parameter(3)

      current = (f32[], u32[]) tuple(running_max, running_max_idx)
      potential = (f32[], u32[]) tuple(current_value, current_value_idx)

      cmp_code = pred[] compare(current_value, running_max), direction=GT

      new_max = f32[] select(cmp_code, current_value, running_max)
      new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

      ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
    }

    ENTRY main {
      input = f32[100] parameter(0)
      idxs = u32[100]{0} iota(), iota_dimension=0
      zero = f32[] constant(0)
      zero_idx = u32[] constant(0)

      ROOT out = (f32[], u32[]) reduce(
        input, idxs, zero, zero_idx),
        dimensions={0},
        to_apply=%argmax
    }
)";

  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1169 
// Column-wise argmax written as a variadic reduce: an f32[32,128] input is
// reduced over dimension 0 together with an iota of row indices, yielding
// (f32[128], u32[128]).
//
// NOTE(review): the `add` computation in the module text is not referenced by
// ENTRY (only %argmax is applied); it is kept as-is.
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax_column) {
  // The module is named after this test. It previously reused the name of the
  // non-column argmax test, so the two modules could not be told apart by
  // name.
  absl::string_view hlo_string = R"(
    HloModule Reduce_R1x2_to_R0x2_argmax_column

    add {
      acc = f32[] parameter(1)
      op = f32[] parameter(0)
      ROOT out = f32[] add(acc, op)
    }

    argmax {
      running_max = f32[] parameter(0)
      running_max_idx = u32[] parameter(1)
      current_value = f32[] parameter(2)
      current_value_idx = u32[] parameter(3)

      current = (f32[], u32[]) tuple(running_max, running_max_idx)
      potential = (f32[], u32[]) tuple(current_value, current_value_idx)

      cmp_code = pred[] compare(current_value, running_max), direction=GT

      new_max = f32[] select(cmp_code, current_value, running_max)
      new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

      ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
    }

    ENTRY main {
      input = f32[32,128] parameter(0)
      idxs = u32[32,128] iota(), iota_dimension=0
      zero = f32[] constant(0)
      zero_idx = u32[] constant(0)

      ROOT argmax_result = (f32[128], u32[128]) reduce(
        input, idxs, zero, zero_idx),
        dimensions={0},
        to_apply=%argmax
    }
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1212 
// Variadic reduce over a (pred, u32) pair where only the second tuple element
// of the result is returned from the entry computation.
XLA_TEST_F(VariadicReduceTest, ReduceMultiOutputVariadicAnd) {
  const absl::string_view module_text = R"(
    HloModule VariadicReduceMultiOutput

    VariadicAnd {
      value = pred[] parameter(0)
      value_idx = u32[] parameter(1)
      current_value = pred[] parameter(2)
      current_value_idx = u32[] parameter(3)
      ROOT out = (pred[], u32[]) tuple(value, value_idx)
    }

    ENTRY CheckBuffer {
      test_value = f32[] parameter(0)
      buffer = f32[100] parameter(1)
      value_broadcast = f32[100] broadcast(test_value), dimensions={}
      comparison_result = pred[100] compare(buffer, value_broadcast), direction=EQ
      true_constant = pred[] constant(true)

      zero_idx = u32[] constant(0)
      idxs = u32[100]{0} iota(), iota_dimension=0
      out = (pred[], u32[]) reduce(
         comparison_result, idxs, true_constant, zero_idx
      ), dimensions={0}, to_apply=VariadicAnd

      ROOT returned = u32[] get-tuple-element(out), index=1
    }
)";

  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1244 
// Variadic argmax over dimension 3 where the two operands, and the two
// results, are declared with different layouts.
XLA_TEST_F(VariadicReduceTest, ReduceMultiOutputVariadicDifferentLayout) {
  const absl::string_view module_text = R"(
HloModule ReduceWithLayoutChangeVariadicDifferent

argmax {
  running_max = f32[] parameter(0)
  running_max_idx = u32[] parameter(1)
  current_value = f32[] parameter(2)
  current_value_idx = u32[] parameter(3)

  current = (f32[], u32[]) tuple(running_max, running_max_idx)
  potential = (f32[], u32[]) tuple(current_value, current_value_idx)

  cmp_code = pred[] compare(current_value, running_max), direction=GT

  new_max = f32[] select(cmp_code, current_value, running_max)
  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
}

ENTRY main {
  arg0 = f32[2,3,4,1024]{2,1,0,3}  parameter(0)
  idxs = u32[2,3,4,1024]{3,2,1,0}  parameter(1)
  constant0 = f32[] constant(0)
  constant1 = u32[] constant(0)
  ROOT reduce0 = (
      f32[2,3,4]{2,1,0},
      u32[2,3,4]{1,0,2}
    ) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
}
)";

  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(module_text, tolerance));
}
1280 
1281 }  // namespace
1282 }  // namespace xla
1283