1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests that multi-dimensional arrays can be reduced among various
17 // user-provided dimensions.
18 //
19 // Note that comments for these tests are white-box in that they talk about the
20 // default data layout.
21 //
22 // The test space for reductions is the cartesian product of:
23 //
24 // <possible ranks> x
25 // <possible layouts for chosen rank> x
26 // <possible subsets of dimensions in chosen rank>
27
28 #include <stdlib.h>
29
30 #include <algorithm>
31 #include <cmath>
32 #include <functional>
33 #include <memory>
34 #include <random>
35 #include <string>
36 #include <utility>
37 #include <vector>
38
39 #include "absl/algorithm/container.h"
40 #include "absl/strings/str_format.h"
41 #include "absl/strings/str_join.h"
42 #include "absl/types/span.h"
43 #include "tensorflow/compiler/xla/array2d.h"
44 #include "tensorflow/compiler/xla/array4d.h"
45 #include "tensorflow/compiler/xla/client/global_data.h"
46 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
47 #include "tensorflow/compiler/xla/client/local_client.h"
48 #include "tensorflow/compiler/xla/client/xla_builder.h"
49 #include "tensorflow/compiler/xla/client/xla_computation.h"
50 #include "tensorflow/compiler/xla/layout_util.h"
51 #include "tensorflow/compiler/xla/literal_util.h"
52 #include "tensorflow/compiler/xla/reference_util.h"
53 #include "tensorflow/compiler/xla/shape_util.h"
54 #include "tensorflow/compiler/xla/status_macros.h"
55 #include "tensorflow/compiler/xla/statusor.h"
56 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
57 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
58 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
59 #include "tensorflow/compiler/xla/tests/test_macros.h"
60 #include "tensorflow/compiler/xla/util.h"
61 #include "tensorflow/compiler/xla/xla_data.pb.h"
62 #include "tensorflow/core/lib/core/status_test_util.h"
63 #include "tensorflow/core/platform/test.h"
64
65 namespace xla {
66 namespace {
67
// Factory signature for building a scalar reduction computation (e.g. add,
// and, or) for a caller-chosen element type on the given builder.
using FuncGeneratorForType = XlaComputation (*)(PrimitiveType, XlaBuilder*);

// Factory signature for building a scalar reduction computation whose element
// type is fixed by the factory itself.
using FuncGenerator = XlaComputation (*)(XlaBuilder*);
71
// Fixture for the reduction tests below: provides canned 2-D/3-D literals and
// parameterized driver routines that build a Reduce computation, run it on the
// service, and compare against a host-side reference.
class ReduceTest : public ClientLibraryTestBase {
 protected:
  ReduceTest() {
    // Implementation note: laid out z >> y >> x by default.
    // clang-format off
    literal_2d_ = LiteralUtil::CreateR2<float>({
      // x0   x1   x2
      { 1.f, 2.f, 3.f},  // y0
      { 4.f, 5.f, 6.f},  // y1
    });
    // literal_3d_ is the 2x3 plane above replicated 4 times along z.
    literal_3d_ = LiteralUtil::CreateR3Projected<float>({
      // x0   x1   x2
      { 1.f, 2.f, 3.f},  // y0
      { 4.f, 5.f, 6.f},  // y1
    }, 4);
    // clang-format on
    CHECK(ShapeUtil::Equal(
        literal_3d_.shape(),
        ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
        << literal_3d_.shape().ShortDebugString();
  }

  // Runs an R1 => R0 reduction test with the given number of elements.
  // Inputs are small integers in [-2, 2] so the float reference sum stays
  // well-conditioned even for multi-million-element arrays.
  void RunR1ToR0Test(int64_t element_count) {
    XlaBuilder builder(TestName());
    XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
    const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto zero = ConstantR0<float>(&builder, 0.0);
    Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
    // Deterministic generator: seeded with seed_ so runs are reproducible.
    // NOTE: the rng() call order below (magnitude draw, then sign draw) is
    // part of the reproducibility contract; don't reorder.
    std::minstd_rand rng(seed_);

    std::vector<float> input_data(element_count);
    for (int64_t i = 0; i < element_count; ++i) {
      input_data[i] = rng() % 3;
      if (rng() % 2 == 0) {
        input_data[i] *= -1;
      }
    }
    Literal input_literal =
        LiteralUtil::CreateR1(absl::MakeConstSpan(input_data));
    std::unique_ptr<GlobalData> input_global_data =
        client_->TransferToServer(input_literal).value();

    float expected = absl::c_accumulate(input_data, 0.0f);
    ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
                               ErrorSpec(0.001));
  }

  // Runs an R1 => R0 AND/OR reduction over the predicate `input[i] == 1`.
  // `and_reduce` selects conjunction (init true) vs. disjunction (init false).
  void RunR1ToR0PredTest(bool and_reduce, absl::Span<const int> input_data) {
    const int element_count = input_data.size();
    XlaBuilder builder(TestName());
    const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
    auto input_par = Parameter(&builder, 0, input_shape, "input");
    auto pred_values =
        Eq(input_par, ConstantR1<int>(&builder, element_count, 1));
    XlaOp init_value;
    XlaComputation reduce;
    if (and_reduce) {
      init_value = ConstantR0<bool>(&builder, true);
      reduce = CreateScalarAndComputation(PRED, &builder);
    } else {
      init_value = ConstantR0<bool>(&builder, false);
      reduce = CreateScalarOrComputation(PRED, &builder);
    }
    Reduce(pred_values, init_value, reduce,
           /*dimensions_to_reduce=*/{0});

    Literal input_literal = LiteralUtil::CreateR1(input_data);
    std::unique_ptr<GlobalData> input_global_data =
        client_->TransferToServer(input_literal).value();

    // Host-side reference fold with the same identity element.
    bool expected = and_reduce;
    for (bool item : input_data) {
      if (and_reduce) {
        expected = expected && item;
      } else {
        expected = expected || item;
      }
    }
    ComputeAndCompareR0<bool>(&builder, expected, {input_global_data.get()});
  }

  // Reduce predicate tensor with dimension rows * cols to dimension cols, to
  // test the implementation of atomic operations on misaligned small data
  // types.
  // `minor`/`major` give the input literal's layout (minor-to-major order).
  template <int64_t cols>
  void RunR2ToR1PredTest(bool and_reduce, int64_t rows, int64_t minor = 1,
                         int64_t major = 0) {
    XlaBuilder builder(TestName());
    const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols});
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto input_pred = Eq(input, ConstantR0<uint8_t>(&builder, 1));

    XlaOp init_value;
    XlaComputation reduce_op;
    if (and_reduce) {
      init_value = ConstantR0<bool>(&builder, true);
      reduce_op = CreateScalarAndComputation(PRED, &builder);
    } else {
      init_value = ConstantR0<bool>(&builder, false);
      reduce_op = CreateScalarOrComputation(PRED, &builder);
    }

    Reduce(input_pred, init_value, reduce_op,
           /*dimensions_to_reduce=*/{0});

    Array2D<uint8_t> input_data(rows, cols);
    input_data.FillRandom(0, 1);
    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
    input_literal =
        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
    std::unique_ptr<GlobalData> input_global_data =
        client_->TransferToServer(input_literal).value();

    // `cols` is a template parameter so the expected values can live in a
    // fixed-size std::array (std::vector<bool> would not convert to a Span).
    std::array<bool, cols> expected;
    for (int64_t colno = 0; colno < cols; ++colno) {
      bool column_sum = and_reduce ? true : false;
      for (int64_t rowno = 0; rowno < rows; ++rowno) {
        if (and_reduce) {
          column_sum = column_sum && input_data(rowno, colno);
        } else {
          column_sum = column_sum || input_data(rowno, colno);
        }
      }
      expected[colno] = column_sum;
    }

    ComputeAndCompareR1<bool>(&builder, expected, {input_global_data.get()});
  }

  // Runs an R2 => R0 reduction test with the given number of (rows, cols).
  // `minor`/`major` give the input literal's layout (minor-to-major order).
  void RunR2ToR0Test(int64_t rows, int64_t cols, int64_t minor = 1,
                     int64_t major = 0) {
    XlaBuilder builder(TestName());
    XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
    const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto zero = ConstantR0<float>(&builder, 0.0);
    Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});

    Array2D<float> input_data(rows, cols);
    input_data.FillRandom(3.14f, 0.04);
    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
    input_literal =
        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
    std::unique_ptr<GlobalData> input_global_data =
        client_->TransferToServer(input_literal).value();

    float expected = 0.0;
    for (int64_t rowno = 0; rowno < rows; ++rowno) {
      for (int64_t colno = 0; colno < cols; ++colno) {
        expected += input_data(rowno, colno);
      }
    }
    ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
                               ErrorSpec(0.01, 1e-4));
  }

  // Runs an R2 => R1 reduction test with the given number of (rows, cols).
  // Reduces dimension 0, producing one partial sum per column.
  void RunR2ToR1Test(int64_t rows, int64_t cols, int64_t minor = 1,
                     int64_t major = 0) {
    XlaBuilder builder(TestName());
    XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
    const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto zero = ConstantR0<float>(&builder, 0.0);
    Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});

    Array2D<float> input_data(rows, cols);
    input_data.FillRandom(3.14f, 0.04);
    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
    input_literal =
        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
    std::unique_ptr<GlobalData> input_global_data =
        client_->TransferToServer(input_literal).value();

    std::vector<float> expected;
    expected.reserve(cols);
    for (int64_t colno = 0; colno < cols; ++colno) {
      float column_sum = 0;
      for (int64_t rowno = 0; rowno < rows; ++rowno) {
        column_sum += input_data(rowno, colno);
      }
      expected.push_back(column_sum);
    }
    ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
                               ErrorSpec(0.01, 1e-4));
  }

  // Floating-point overload: results are compared with an error tolerance.
  template <typename NativeT>
  void ComputeAndCompareGeneric(
      typename std::enable_if<std::is_floating_point<NativeT>::value,
                              XlaBuilder>::type* builder,
      absl::Span<const NativeT> expected,
      absl::Span<GlobalData* const> arguments) {
    ComputeAndCompareR1<NativeT>(builder, expected, arguments,
                                 ErrorSpec(0.01, 1e-4));
  }

  // Integral overload: results must match exactly (no ErrorSpec).
  template <typename NativeT>
  void ComputeAndCompareGeneric(
      typename std::enable_if<std::is_integral<NativeT>::value,
                              XlaBuilder>::type* builder,
      absl::Span<const NativeT> expected,
      absl::Span<GlobalData* const> arguments) {
    ComputeAndCompareR1<NativeT>(builder, expected, arguments);
  }

  // Column-reduces a 64x128 array of NativeT with a generated reduction
  // computation and checks it against a host-side fold using
  // `reference_reduction_function`, starting from `initial_value`.
  template <typename NativeT>
  void RunVectorizedReduceTestForType(
      const std::function<XlaComputation(XlaBuilder*)>&
          reduction_function_generator,
      const std::function<NativeT(NativeT, NativeT)>&
          reference_reduction_function,
      const NativeT& initial_value) {
    const int rows = 64, cols = 128;
    const int minor = 1, major = 0;
    XlaBuilder builder(TestName());
    XlaComputation reduction_function = reduction_function_generator(&builder);
    const Shape input_shape = ShapeUtil::MakeShape(
        xla::primitive_util::NativeToPrimitiveType<NativeT>(), {rows, cols});
    auto input = Parameter(&builder, 0, input_shape, "input");
    auto zero = ConstantR0<NativeT>(&builder, initial_value);
    Reduce(input, zero, reduction_function,
           /*dimensions_to_reduce=*/{0});

    Array2D<NativeT> input_data(rows, cols);
    input_data.FillUnique(initial_value);
    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
    input_literal =
        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
    std::unique_ptr<GlobalData> input_global_data =
        client_->TransferToServer(input_literal).value();

    // NativeT can be bool, and std::vector<bool> does not convert to
    // Span.
    std::unique_ptr<NativeT[]> expected(new NativeT[cols]);
    for (int64_t colno = 0; colno < cols; ++colno) {
      NativeT column_result = initial_value;
      for (int64_t rowno = 0; rowno < rows; ++rowno) {
        column_result = reference_reduction_function(column_result,
                                                     input_data(rowno, colno));
      }
      expected[colno] = column_result;
    }

    ComputeAndCompareGeneric<NativeT>(
        &builder, absl::Span<const NativeT>(expected.get(), cols),
        {input_global_data.get()});
  }

  // Runs the vectorized column-reduce test once per element type (F32, S32,
  // U32), pairing each with its reference fold and identity element.
  void RunVectorizedReduceTest(
      const std::function<XlaComputation(PrimitiveType, XlaBuilder*)>&
          reduction_function_generator_for_type,
      const std::function<float(float, float)>&
          reference_reduction_function_for_floats,
      const std::function<int32_t(int32_t, int32_t)>&
          reference_reduction_function_for_ints,
      const std::function<uint32_t(uint32_t, uint32_t)>&
          reference_reduction_function_for_uints,
      float floating_point_identity, int32_t signed_int_identity,
      uint32_t unsigned_int_identity) {
    // Float version
    RunVectorizedReduceTestForType<float>(
        [&](XlaBuilder* builder) {
          return reduction_function_generator_for_type(F32, builder);
        },
        reference_reduction_function_for_floats, floating_point_identity);

    // Signed int version
    RunVectorizedReduceTestForType<int32_t>(
        [&](XlaBuilder* builder) {
          return reduction_function_generator_for_type(S32, builder);
        },
        reference_reduction_function_for_ints, signed_int_identity);

    // Unsigned int version
    RunVectorizedReduceTestForType<uint32_t>(
        [&](XlaBuilder* builder) {
          return reduction_function_generator_for_type(U32, builder);
        },
        reference_reduction_function_for_uints, unsigned_int_identity);
  }

  Literal literal_2d_;  // 2x3 F32 test matrix (see constructor).
  Literal literal_3d_;  // 4x2x3 F32 cube: literal_2d_ replicated along z.
  // Fixed seed so randomly generated inputs are reproducible across runs.
  uint32_t seed_ = 0xdeadbeef;
};
361
// R1 => R0 add-reductions across a spread of element counts: empty, tiny,
// exact powers of two, off-by-one sizes, and multi-megabyte inputs.
XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) {
  RunR1ToR0Test(0);
}
XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) {
  RunR1ToR0Test(1);
}
XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) {
  RunR1ToR0Test(2);
}
XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) {
  RunR1ToR0Test(16);
}
XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) {
  RunR1ToR0Test(128);
}
XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) {
  RunR1ToR0Test(129);
}
XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) {
  RunR1ToR0Test(240);
}
XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) {
  RunR1ToR0Test(256);
}
XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) {
  RunR1ToR0Test(1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) {
  RunR1ToR0Test(2048);
}
XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) {
  RunR1ToR0Test(16 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
  RunR1ToR0Test(16 * 1024 + 1);
}
XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) {
  RunR1ToR0Test(64 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) {
  RunR1ToR0Test(1024 * 1024);
}
XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) {
  RunR1ToR0Test(4096 * 4096);
}
379
// R2 => R0 add-reductions: degenerate (zero-sized) shapes, small squares,
// non-square shapes in both orientations, an alternate {0,1} layout, and
// large inputs.
XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) {
  RunR2ToR0Test(0, 0);
}
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) {
  RunR2ToR0Test(0, 2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) {
  RunR2ToR0Test(1, 1);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) {
  RunR2ToR0Test(2, 0);
}
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R0) {
  RunR2ToR0Test(2, 2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) {
  RunR2ToR0Test(8, 8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) {
  RunR2ToR0Test(9, 9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) {
  RunR2ToR0Test(50, 111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) {
  RunR2ToR0Test(111, 50);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) {
  RunR2ToR0Test(111, 50, 0, 1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) {
  RunR2ToR0Test(1024, 1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) {
  RunR2ToR0Test(1000, 1500);
}
394
// R2 => R1 add-reductions over dimension 0 (per-column sums).
// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) {
  RunR2ToR1Test(0, 2);
}
XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) {
  RunR2ToR1Test(1, 1);
}
// Disabled due to b/33245142. Failed on 2016-11-30.
// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); }
XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) {
  RunR2ToR1Test(2, 2);
}
XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) {
  RunR2ToR1Test(8, 8);
}
XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) {
  RunR2ToR1Test(9, 9);
}
XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) {
  RunR2ToR1Test(50, 111);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) {
  RunR2ToR1Test(111, 50);
}
XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
  RunR2ToR1Test(111, 50, 0, 1);
}
XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) {
  RunR2ToR1Test(1024, 1024);
}
XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) {
  RunR2ToR1Test(1000, 1500);
}
411
// AND-reduce over an all-true predicate vector (every element equals 1).
XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) {
  constexpr int element_count = 10;
  const std::vector<int> input(element_count, 1);
  RunR1ToR0PredTest(/*and_reduce=*/true, input);
}

// AND-reduce over an alternating 0/1 vector; the conjunction must be false.
XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) {
  constexpr int element_count = 10;
  std::vector<int> input;
  input.reserve(element_count);
  for (int i = 0; i < element_count; ++i) {
    input.push_back(i % 2);
  }
  RunR1ToR0PredTest(/*and_reduce=*/true, input);
}

// OR-reduce over an all-true predicate vector.
XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) {
  constexpr int element_count = 10;
  const std::vector<int> input(element_count, 1);
  RunR1ToR0PredTest(/*and_reduce=*/false, input);
}

// OR-reduce over an alternating 0/1 vector; the disjunction must be true.
XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) {
  constexpr int element_count = 10;
  std::vector<int> input;
  input.reserve(element_count);
  for (int i = 0; i < element_count; ++i) {
    input.push_back(i % 2);
  }
  RunR1ToR0PredTest(/*and_reduce=*/false, input);
}
441
// Fuses an elementwise Log with a column-wise add-reduction and checks the
// result against per-column sums of log(x) computed on the host.
XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
  const int64_t rows = 111, cols = 50;

  XlaBuilder builder(TestName());
  XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto zero = ConstantR0<float>(&builder, 0.0);
  auto logged = Log(input);
  Reduce(logged, zero, add_f32, /*dimensions_to_reduce=*/{0});

  Array2D<float> input_data(rows, cols);
  input_data.FillRandom(3.14f, 0.04);
  Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
  input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> input_global_data =
      client_->TransferToServer(input_literal).value();

  // Accumulate rows in ascending order, matching the builder-side reduction's
  // reference semantics.
  std::vector<float> expected(cols, 0.0f);
  for (int64_t c = 0; c < cols; ++c) {
    for (int64_t r = 0; r < rows; ++r) {
      expected[c] += std::log(input_data(r, c));
    }
  }
  ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
                             ErrorSpec(0.01, 1e-4));
}
472
// Same as ReduceElementwiseR2, but transposes before reducing: after the
// {1,0} transpose, reducing dimension 1 again yields per-column sums of the
// original matrix.
XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
  const int64_t rows = 111, cols = 50;

  XlaBuilder builder(TestName());
  XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto zero = ConstantR0<float>(&builder, 0.0);
  auto logged = Log(input);
  auto transposed = Transpose(logged, {1, 0});
  Reduce(transposed, zero, add_f32, /*dimensions_to_reduce=*/{1});

  Array2D<float> input_data(rows, cols);
  input_data.FillRandom(3.14f, 0.04);
  Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
  input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
  std::unique_ptr<GlobalData> input_global_data =
      client_->TransferToServer(input_literal).value();

  // Host reference: per-column sums of log(x), rows accumulated in ascending
  // order.
  std::vector<float> expected(cols, 0.0f);
  for (int64_t c = 0; c < cols; ++c) {
    for (int64_t r = 0; r < rows; ++r) {
      expected[c] += std::log(input_data(r, c));
    }
  }
  ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
                             ErrorSpec(0.01, 1e-4));
}
504
505 // Test that algebraic simplifier does not incorrectly fold a transpose into a
506 // reduction operation.
XLA_TEST_F(ReduceTest,TransposeAndReduceR3_12x111x50_To_R2)507 XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
508 XlaBuilder builder(TestName());
509 XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
510 const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50});
511 XlaOp input = Parameter(&builder, 0, input_shape, "input");
512 XlaOp zero = ConstantR0<float>(&builder, 0.0);
513 XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
514 Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
515
516 TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
517
518 ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
519 }
520
// Applies Tanh to an R3, reshapes it to R2, and column-reduces; exercises a
// reshape feeding directly into a reduction.
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
  const int64_t rows = 111, cols = 50;

  XlaBuilder builder(TestName());
  XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto zero = ConstantR0<float>(&builder, 0.0);
  auto activated = Tanh(input);  // renamed from `log_`: the op is Tanh
  auto flattened = Reshape(activated, {rows, cols});
  Reduce(flattened, zero, add_f32, /*dimensions_to_reduce=*/{0});

  Array3D<float> input_data(rows, 2, cols / 2);
  input_data.FillRandom(3.14f, 0.04);
  Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data);
  std::unique_ptr<GlobalData> input_global_data =
      client_->TransferToServer(input_literal).value();

  // Expected values follow the reshape's column order: (half, c) pairs in
  // row-major order, each summed over all rows.
  std::vector<float> expected;
  expected.reserve(cols);
  for (int64_t half = 0; half < 2; ++half) {
    for (int64_t c = 0; c < cols / 2; ++c) {
      float sum = 0;
      for (int64_t r = 0; r < rows; ++r) {
        sum += std::tanh(input_data(r, half, c));
      }
      expected.push_back(sum);
    }
  }
  ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
                             ErrorSpec(0.01, 1e-4));
}
553
// Parameterization of a reduction test case: the input's dimension bounds,
// its layout (presumably minor-to-major order, matching
// LayoutUtil::MakeLayout usage in this file — confirm against the consumer),
// and the set of dimensions to reduce.
struct BoundsLayout {
  std::vector<int64_t> bounds;
  std::vector<int64_t> layout;
  std::vector<int64_t> reduce_dims;
};
559
PrintTo(const BoundsLayout & spec,std::ostream * os)560 void PrintTo(const BoundsLayout& spec, std::ostream* os) {
561 *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(),
562 spec.bounds.size() - spec.reduce_dims.size(),
563 absl::StrJoin(spec.bounds, "x"),
564 absl::StrJoin(spec.layout, ""),
565 absl::StrJoin(spec.reduce_dims, ""));
566 }
567
568 // Add-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,AddReduce2DScalarToR0)569 XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
570 XlaBuilder builder(TestName());
571 auto add = CreateScalarAddComputation(F32, &builder);
572 auto scalar = ConstantR0<float>(&builder, 42.0);
573 auto broadcasted = Broadcast(scalar, {500, 500});
574 Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
575
576 float expected = 42.0f * static_cast<float>(500 * 500);
577 ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
578 }
579
580 // Max-reduces a broadcasted scalar matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DScalarToR0)581 XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
582 XlaBuilder builder(TestName());
583 auto max = CreateScalarMaxComputation(F32, &builder);
584 auto scalar = ConstantR0<float>(&builder, 42.0);
585 auto broadcasted = Broadcast(scalar, {500, 500});
586 Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), max, {0, 1});
587
588 float expected = 42.0f;
589 ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
590 }
591
592 // Max-reduces a matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MaxReduce2DToR0)593 XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
594 XlaBuilder builder(TestName());
595 auto max = CreateScalarMaxComputation(F32, &builder);
596 Array2D<float> input(300, 250);
597 input.FillRandom(214.0f);
598 auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
599 Reduce(ConstantLiteral(&builder, input_literal),
600 ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
601 auto input_max = FLT_MIN;
602 input.Each(
603 [&](int64_t, int64_t, float* v) { input_max = std::max(input_max, *v); });
604 ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
605 }
606
607 // Min-reduces matrix among dimension 1 and 0.
XLA_TEST_F(ReduceTest,MinReduce2DToR0)608 XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
609 XlaBuilder builder(TestName());
610 auto min = CreateScalarMinComputation(F32, &builder);
611 Array2D<float> input(150, 130);
612 input.FillRandom(214.0f);
613 auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
614 Reduce(ConstantLiteral(&builder, input_literal),
615 ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
616
617 auto input_min = FLT_MAX;
618 input.Each(
619 [&](int64_t, int64_t, float* v) { input_min = std::min(input_min, *v); });
620 ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
621 }
622
// Min-reduces a tiny 2x1 U32 matrix; init is the type's max (min identity).
XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32_t> input({{1}, {2}});
  auto min = CreateScalarMinComputation(U32, &builder);
  auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
  auto init =
      ConstantR0<uint32_t>(&builder, std::numeric_limits<uint32_t>::max());

  Reduce(ConstantLiteral(&builder, input_literal), init, min, {0, 1});
  ComputeAndCompareR0<uint32_t>(&builder, 1, {});
}

// Max-reduces the same matrix; init is the type's min (max identity).
XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
  XlaBuilder builder(TestName());
  Array2D<uint32_t> input({{1}, {2}});
  auto max = CreateScalarMaxComputation(U32, &builder);
  auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
  auto init =
      ConstantR0<uint32_t>(&builder, std::numeric_limits<uint32_t>::min());

  Reduce(ConstantLiteral(&builder, input_literal), init, max, {0, 1});
  ComputeAndCompareR0<uint32_t>(&builder, 2, {});
}
646
647 // Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest,Reduce2DAmong1)648 XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
649 XlaBuilder builder(TestName());
650 auto m = ConstantLiteral(&builder, literal_2d_);
651 auto add = CreateScalarAddComputation(F32, &builder);
652 Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
653
654 std::vector<float> expected = {6.f, 15.f};
655 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
656 }
657
XLA_TEST_F(ReduceTest,Reduce2DAmong0and1)658 XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
659 // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
660 XlaBuilder builder(TestName());
661 auto m = ConstantLiteral(&builder, literal_2d_);
662 auto add = CreateScalarAddComputation(F32, &builder);
663 Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
664
665 ComputeAndCompareR0<float>(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4));
666 }
667
668 // Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest,Reduce2DAmongY)669 XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
670 XlaBuilder builder("reduce_among_y");
671 auto m = ConstantLiteral(&builder, literal_2d_);
672 auto add = CreateScalarAddComputation(F32, &builder);
673 Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
674
675 std::vector<float> expected = {5.f, 7.f, 9.f};
676 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
677 }
678
XLA_TEST_F(ReduceTest,ReduceR3AmongDims_1_2)679 XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
680 XlaBuilder builder(TestName());
681 auto m = ConstantLiteral(&builder, literal_3d_);
682 auto add = CreateScalarAddComputation(F32, &builder);
683 Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2});
684
685 std::vector<float> expected = {21.f, 21.f, 21.f, 21.f};
686 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
687 }
688
// Sums the rank-3 literal over its two major dimensions, leaving one value
// per minor index.
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  Reduce(cube, ConstantR0<float>(&builder, 0.0f), sum, {0, 1});

  const std::vector<float> per_column_sums = {20.f, 28.f, 36.f};
  ComputeAndCompareR1<float>(&builder, per_column_sums, {},
                             ErrorSpec(0.0001));
}
698
// Reduces the whole rank-3 literal over every dimension down to a scalar.
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  Reduce(cube, ConstantR0<float>(&builder, 0.0f), sum, {0, 1, 2});

  // Four identical 2x3 slices, each summing to 21.
  const float total = 21.0f * 4.0f;
  ComputeAndCompareR0<float>(&builder, total, {}, ErrorSpec(0.0001));
}
708
// Reduces the rank-3 literal over its major dimension only, yielding a 2x3
// matrix of element-wise sums across the four slices.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  Reduce(cube, ConstantR0<float>(&builder, 0.0f), sum, {0});

  const Array2D<float> expected_sums({
      {4.f, 8.f, 12.f},
      {16.f, 20.f, 24.f},
  });
  ComputeAndCompareR2<float>(&builder, expected_sums, {}, ErrorSpec(0.0001));
}
723
// Reduces the rank-3 literal over its middle dimension; every one of the
// four resulting rows is the same column-sum vector.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  Reduce(cube, ConstantR0<float>(&builder, 0.0f), sum, {1});

  Array2D<float> expected_sums(4, 3);
  for (int row = 0; row < 4; ++row) {
    expected_sums(row, 0) = 5.f;
    expected_sums(row, 1) = 7.f;
    expected_sums(row, 2) = 9.f;
  }
  ComputeAndCompareR2<float>(&builder, expected_sums, {}, ErrorSpec(0.0001));
}
740
// Reduces the rank-3 literal over its minor dimension; every one of the four
// resulting rows is the same row-sum pair.
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
  XlaBuilder builder(TestName());
  auto cube = ConstantLiteral(&builder, literal_3d_);
  auto sum = CreateScalarAddComputation(F32, &builder);
  Reduce(cube, ConstantR0<float>(&builder, 0.0f), sum, {2});

  Array2D<float> expected_sums(4, 2);
  for (int row = 0; row < 4; ++row) {
    expected_sums(row, 0) = 6.f;
    expected_sums(row, 1) = 15.f;
  }
  ComputeAndCompareR2<float>(&builder, expected_sums, {}, ErrorSpec(0.0001));
}
757
// Runs the vectorized reduce test with addition as the reducer over F32, S32
// and U32 element types; 0 is the identity for each.
XLA_TEST_F(ReduceTest, VectorizedReduce_Add) {
  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarAddComputation),
      [](float a, float b) { return a + b; },
      // Do the signed addition in unsigned arithmetic so that overflow wraps
      // instead of being undefined behavior (signed overflow is UB in C++).
      [](int32_t a, int32_t b) {
        return static_cast<int32_t>(static_cast<uint32_t>(a) +
                                    static_cast<uint32_t>(b));
      },
      [](uint32_t a, uint32_t b) { return a + b; }, 0.0, 0, 0);
}
768
// Runs the vectorized reduce test with multiplication as the reducer over
// F32, S32 and U32 element types; 1 is the identity for each.
XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) {
  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarMultiplyComputation),
      [](float a, float b) { return a * b; },
      // Multiply through unsigned arithmetic so overflow wraps instead of
      // being undefined behavior (signed overflow is UB in C++).
      [](int32_t a, int32_t b) {
        return static_cast<int32_t>(static_cast<uint32_t>(a) *
                                    static_cast<uint32_t>(b));
      },
      [](uint32_t a, uint32_t b) { return a * b; }, 1.0, 1, 1);
}
779
// Runs the vectorized reduce test with max as the reducer over F32, S32 and
// U32 element types, seeded with each type's smallest value.
XLA_TEST_F(ReduceTest, VectorizedReduce_Max) {
  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarMaxComputation),
      [](float a, float b) { return std::max(a, b); },
      [](int32_t a, int32_t b) { return std::max(a, b); },
      [](uint32_t a, uint32_t b) { return std::max(a, b); },
      // Use lowest() for the float identity: numeric_limits<float>::min() is
      // the smallest *positive* normal value, not the most negative float,
      // so it is not an identity for max over negative inputs.
      std::numeric_limits<float>::lowest(),
      std::numeric_limits<int32_t>::min(),
      std::numeric_limits<uint32_t>::min());
}
789
// Runs the vectorized reduce test with min as the reducer over F32, S32 and
// U32 element types, seeded with each type's maximum value (the identity for
// min).
XLA_TEST_F(ReduceTest, VectorizedReduce_Min) {
  RunVectorizedReduceTest(
      static_cast<FuncGeneratorForType>(CreateScalarMinComputation),
      [](float a, float b) { return std::min(a, b); },
      [](int32_t a, int32_t b) { return std::min(a, b); },
      [](uint32_t a, uint32_t b) { return std::min(a, b); },
      std::numeric_limits<float>::max(), std::numeric_limits<int32_t>::max(),
      std::numeric_limits<uint32_t>::max());
}
799
// Vectorized reduce with logical AND over PRED elements; true is the
// identity for AND.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) {
  RunVectorizedReduceTestForType<bool>(
      static_cast<FuncGenerator>([](XlaBuilder* builder) {
        return CreateScalarAndComputation(PRED, builder);
      }),
      [](bool a, bool b) { return a && b; }, true);
}
807
// Vectorized reduce with logical OR over PRED elements; false is the
// identity for OR.
XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) {
  RunVectorizedReduceTestForType<bool>(
      static_cast<FuncGenerator>([](XlaBuilder* builder) {
        return CreateScalarOrComputation(PRED, builder);
      }),
      [](bool a, bool b) { return a || b; }, false);
}
815
// Parameterized over (shape bounds, layout, dimensions to reduce); the case
// list is supplied by INSTANTIATE_TEST_CASE_P further down in this file.
class ReduceR3ToR2Test : public ReduceTest,
                         public ::testing::WithParamInterface<BoundsLayout> {};
818
// Reduces an R3 array to R2 along the parameterized dimensions and checks
// the result against a host-side reference reduction.
XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
  XlaBuilder builder(TestName());
  const auto& bounds = GetParam().bounds;
  Array3D<float> input_array(bounds[0], bounds[1], bounds[2]);
  // NOTE(review): the random fill is commented out and replaced with a
  // constant fill — looks like a debugging leftover; confirm whether random
  // data was intended before re-enabling.
  // input_array.FillRandom(3.14f, 0.05);
  input_array.Fill(1.0f);

  // Relayout the literal to the parameterized layout before transferring it,
  // so the device sees the layout under test.
  auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
  input_literal =
      input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout));
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).value();

  auto input_activations =
      Parameter(&builder, 0, input_literal.shape(), "input");
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
         GetParam().reduce_dims);

  // Host-side reference result for the same dims and init value.
  auto expected =
      ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims,
                                  [](float a, float b) { return a + b; });

  ComputeAndCompareR2<float>(&builder, *expected, {input_data.get()},
                             ErrorSpec(1e-3, 1e-3));
}
845
// Instantiates ReduceR3ToR2Test over a mix of shapes, layouts, and reduced
// dimensions, including degenerate bounds expected to simplify to reshapes.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the deprecated gtest spelling;
// switch to INSTANTIATE_TEST_SUITE_P once the bundled gtest supports it.
INSTANTIATE_TEST_CASE_P(
    ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test,
    // Specifies (shape, layout, reduction dimensions).
    ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}},
                      BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}},
                      BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}},
                      // These should be simplified into a reshape.
                      BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}},
                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}},
                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}},
                      BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}},
                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}},
                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}},
                      BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}},
                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}},
                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}},
                      BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}},
                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}},
                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}},
                      BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}},
                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}},
                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
                      BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
869
// Checks that the init value may be the result of an operation (here Abs of
// a constant) rather than a bare constant.
XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
  XlaBuilder builder(TestName());
  XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);

  auto init_value = Abs(ConstantR0<float>(&builder, 2.0f));

  Literal b_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).value();
  auto b = Parameter(&builder, 0, b_literal.shape(), "b");
  Reduce(b, init_value, max_f32, {0});

  // max(|2.0|, 1.0, 4.0) == 4.0.
  ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
}
885
// AND-reduces a 128x64 PRED matrix down to a row.
XLA_TEST_F(ReduceTest, ReduceAndPredR2_128x64_To_R1) {
  RunR2ToR1PredTest</*cols=*/64>(/*and_reduce=*/true, /*rows=*/128);
}
// OR-reduces a 64x32 PRED matrix down to a row.
XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) {
  RunR2ToR1PredTest</*cols=*/32>(/*and_reduce=*/false, /*rows=*/64);
}
892
893 // Tests reductions with different initial values. There's no test macro that
894 // combines TYPED_TEST and TYPED_P, so we have to do it manually.
895 class ReduceInitializerTest : public ReduceTest {
896 protected:
897 template <typename T>
DoTest(T initializer,int num_elems)898 void DoTest(T initializer, int num_elems) {
899 XlaBuilder builder(TestName());
900 XlaComputation max_fn = CreateScalarMaxComputation(
901 primitive_util::NativeToPrimitiveType<T>(), &builder);
902
903 auto init = ConstantR0<T>(&builder, initializer);
904 std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
905 auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
906 auto input_data = client_->TransferToServer(input_literal).value();
907 Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
908 {0});
909
910 ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
911 }
912 };
913
// Two uint8 elements, initializer 42.
XLA_TEST_F(ReduceInitializerTest, U8Small) { DoTest<uint8_t>(42, 2); }
915
// Power-of-two element count, which may take a vectorized/tiled code path.
XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) { DoTest<uint8_t>(42, 4096); }
917
// Non-power-of-two element count, exercising any remainder handling.
XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) {
  DoTest<uint8_t>(42, 4095);
}
921
// 64-bit unsigned reduce with a zero initializer.
XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) {
  DoTest<uint64_t>(0, 1024);
}
925
// 64-bit unsigned reduce with initializer 1.
XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) {
  DoTest<uint64_t>(1, 1024);
}
929
// 64-bit unsigned reduce with an initializer exceeding 32-bit range.
XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
  DoTest<uint64_t>(1234556789123, 1024);
}
933
934 // Test the operational semantic that the init value is passed on the lhs for
935 // reduces. Can be tested by performing an "identity" reduce (that simply
936 // returns one of the parameters). In this case, we return the rhs, which for
937 // a 1D array with one element, should not be the init value.
XLA_TEST_F(ReduceTest, ReduceIdentity) {
  XlaBuilder builder(TestName());
  Shape single_float = ShapeUtil::MakeShape(F32, {});
  // Build a binary "identity" reducer that ignores its lhs and returns its
  // rhs unchanged.
  Parameter(&builder, 0, single_float, "lhs-unused");
  Parameter(&builder, 1, single_float, "rhs-used");
  auto computation_status = builder.Build();
  TF_ASSERT_OK(computation_status.status());

  Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
  // Use .value() rather than the deprecated ValueOrDie(), consistent with
  // the rest of this file.
  Reduce(Parameter(&builder, 0, operand_shape, "operand"),
         Parameter(&builder, 1, single_float, "init"),
         computation_status.value(), {0});

  float operand[] = {42.0f};
  float init = 58.5f;
  // The init value arrives on the lhs, the element on the rhs; the reducer
  // returns the rhs, so the result must be the element, not the init value.
  float expected = 42.0f;
  Literal input_literal = LiteralUtil::CreateR1<float>(operand);
  std::unique_ptr<GlobalData> input_global_data =
      client_->TransferToServer(input_literal).value();
  Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
  std::unique_ptr<GlobalData> input_global_data2 =
      client_->TransferToServer(input_literal2).value();
  ComputeAndCompareR0<float>(
      &builder, expected, {input_global_data.get(), input_global_data2.get()},
      ErrorSpec(0.0001));
}
964
// Bitwise-AND reduce of uint64 rows. Use ULL suffixes and lowercase 0x
// throughout: the previous LL suffix on 0xFFFFFFFFFFFFFFFF was misleading
// for a value above INT64_MAX, and the hex prefix casing was inconsistent.
XLA_TEST_F(ReduceTest, AndReduceU64) {
  XlaBuilder builder(TestName());
  Array2D<uint64_t> initializer = {
      {0x123456789ABCDEF0ULL, 0x3BCDEF12A4567890ULL},
      {0xFFFFFFFFFFFFFFD6ULL, 101},
      {1, 0xFFFFFFFFFFFFFFFFULL}};
  auto reducer = CreateScalarAndComputation(U64, &builder);
  auto m = ConstantR2FromArray2D(&builder, initializer);
  // All-ones is the identity for bitwise AND.
  Reduce(m, ConstantR0<uint64_t>(&builder, 0xFFFFFFFFFFFFFFFFULL), reducer,
         {1});

  std::vector<uint64_t> expected = {0x1204461080145890ULL, 68, 1};
  ComputeAndCompareR1<uint64_t>(&builder, expected, {});
}
978
// Bitwise-OR reduce of uint64 rows. Hex prefixes normalized to lowercase 0x
// for consistency (the expected values previously mixed 0X and 0x).
XLA_TEST_F(ReduceTest, OrReduceU64) {
  XlaBuilder builder(TestName());
  Array2D<uint64_t> initializer = {
      {0x123456789ABCDEF0ULL, 0x3BCDEF12A4567890ULL},
      {0xFFFFFFFFFFFFFFD6ULL, 101},
      {1, 0xCAFEBEEFABABABABULL}};
  auto reducer = CreateScalarOrComputation(U64, &builder);
  auto m = ConstantR2FromArray2D(&builder, initializer);
  // Zero is the identity for bitwise OR.
  Reduce(m, ConstantR0<uint64_t>(&builder, 0), reducer, {1});

  std::vector<uint64_t> expected = {
      0x3BFDFF7ABEFEFEF0ULL, 0xFFFFFFFFFFFFFFF7ULL, 0xCAFEBEEFABABABABULL};
  ComputeAndCompareR1<uint64_t>(&builder, expected, {});
}
993
// Reducing a [127, 1] array over dimension 0 is morally an R1->R0 reduce but
// produces a rank-1 result with a single element.
XLA_TEST_F(ReduceTest, R0ReduceInDisguise) {
  XlaBuilder builder(TestName());
  XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
  constexpr int element_count = 127;
  const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count, 1});
  auto input = Parameter(&builder, 0, input_shape, "input");
  auto zero = ConstantR0<float>(&builder, 0.0);
  Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});

  Array2D<float> host_values(element_count, 1);
  host_values.FillRandom(3.0f);
  Literal input_literal = LiteralUtil::CreateR2FromArray2D(host_values);
  std::unique_ptr<GlobalData> device_input =
      client_->TransferToServer(input_literal).value();

  // Host-side reference: plain accumulation over every element.
  const float host_sum = absl::c_accumulate(host_values, 0.0f);
  ComputeAndCompareR1<float>(&builder, {host_sum}, {device_input.get()},
                             ErrorSpec(0.001));
}
1013
// Reduce tests driven by HLO text modules rather than the builder API.
class ReduceHloTest : public HloTestBase {};
1015
// Runs an HLO module in which one reduction's output (reduce.1) feeds both
// the root tuple and a second, chained reduction (reduce.0), and compares
// against the reference backend.
XLA_TEST_F(ReduceHloTest, HandleReductionToVectorAndOtherReduction) {
  absl::string_view hlo_string = R"(
  HloModule HandleReductionToVectorAndOtherReduction

  add {
    acc = f32[] parameter(1)
    op = f32[] parameter(0)
    ROOT out = f32[] add(acc, op)
  }

  ENTRY main {
    iota.3 = s32[2,2]{1,0} iota(), iota_dimension=0
    iota.2 = s32[2,2]{1,0} iota(), iota_dimension=1
    compare.0 = pred[2,2]{1,0} compare(iota.3, iota.2), direction=EQ
    broadcast = pred[2,2,2,2]{3,2,1,0} broadcast(compare.0), dimensions={2,3}
    param_0.16 = f32[2,2,2,2]{3,2,1,0} parameter(0)
    constant_4 = f32[] constant(0)
    broadcast.9 = f32[2,2,2,2]{3,2,1,0} broadcast(constant_4), dimensions={}
    select.0 = f32[2,2,2,2]{3,2,1,0} select(broadcast, param_0.16, broadcast.9)
    reduce.1 = f32[2,2,2]{2,1,0} reduce(select.0, constant_4), dimensions={2},
      to_apply=add
    abs.0 = f32[2,2,2]{2,1,0} abs(reduce.1)
    log.0 = f32[2,2,2]{2,1,0} log(abs.0)
    reduce.0 = f32[2,2]{1,0} reduce(log.0, constant_4), dimensions={2},
      to_apply=add
    ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce.0, reduce.1)
  }
  )";
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1046
// Tests for variadic (multi-operand, tuple-result) reductions.
class VariadicReduceTest : public HloTestBase {};
1048
// Variadic add-reduce of two R3 inputs over one dimension, producing a pair
// of R2 results.
XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R2x2_simple) {
  // The module name now matches the test name; it previously read
  // "Reduce_R3x2_to_R1x2_simple" (copy-paste from the sibling test).
  absl::string_view hlo_string = R"(
  HloModule Reduce_R3x2_to_R2x2_simple

  add {
    op1 = f32[] parameter(0)
    op2 = f32[] parameter(1)
    acc1 = f32[] parameter(2)
    acc2 = f32[] parameter(3)
    out1 = f32[] add(acc1, op1)
    out2 = f32[] add(acc2, op2)
    ROOT result = (f32[], f32[]) tuple(out1, out2)
  }

  ENTRY main {
    inp1 = f32[3,4,5] parameter(0)
    inp2 = f32[3,4,5] parameter(1)
    zero = f32[] constant(0)

    ROOT out = (f32[3,5], f32[3,5]) reduce(inp1, inp2, zero, zero),
        dimensions={1},
        to_apply=add
  }
  )";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1076
// Variadic add-reduce of two R3 inputs over two dimensions at once,
// producing a pair of R1 results.
XLA_TEST_F(VariadicReduceTest, Reduce_R3x2_to_R1x2_simple) {
  absl::string_view hlo_string = R"(
  HloModule Reduce_R3x2_to_R1x2_simple

  add {
    op1 = f32[] parameter(0)
    op2 = f32[] parameter(1)
    acc1 = f32[] parameter(2)
    acc2 = f32[] parameter(3)
    out1 = f32[] add(acc1, op1)
    out2 = f32[] add(acc2, op2)
    ROOT result = (f32[], f32[]) tuple(out1, out2)
  }

  ENTRY main {
    inp1 = f32[10,20,3] parameter(0)
    inp2 = f32[10,20,3] parameter(1)
    zero = f32[] constant(0)

    ROOT out = (f32[10], f32[10]) reduce(inp1, inp2, zero, zero),
        dimensions={1,2},
        to_apply=add
  }
  )";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1104
// Variadic add-reduce of two R1 inputs down to a pair of scalars.
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_simple) {
  absl::string_view hlo_string = R"(
  HloModule Reduce_R1x2_to_R0x2_simple

  add {
    op1 = f32[] parameter(0)
    op2 = f32[] parameter(1)
    acc1 = f32[] parameter(2)
    acc2 = f32[] parameter(3)
    out1 = f32[] add(acc1, op1)
    out2 = f32[] add(acc2, op2)
    ROOT result = (f32[], f32[]) tuple(out1, out2)
  }

  ENTRY main {
    inp1 = f32[100] parameter(0)
    inp2 = f32[100] parameter(1)
    zero = f32[] constant(0)

    ROOT out = (f32[], f32[]) reduce(inp1, inp2, zero, zero),
        dimensions={0},
        to_apply=add
  }
  )";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1132
// Variadic argmax: reduces a (value, index) pair, selecting the larger value
// and its index at each step. NOTE(review): the `current`/`potential` tuples
// inside the reducer appear unused — the selects operate on the scalars
// directly; confirm they can be removed.
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax) {
  absl::string_view hlo_string = R"(
  HloModule Reduce_R1x2_to_R0x2_argmax

  argmax {
    running_max = f32[] parameter(0)
    running_max_idx = u32[] parameter(1)
    current_value = f32[] parameter(2)
    current_value_idx = u32[] parameter(3)

    current = (f32[], u32[]) tuple(running_max, running_max_idx)
    potential = (f32[], u32[]) tuple(current_value, current_value_idx)

    cmp_code = pred[] compare(current_value, running_max), direction=GT

    new_max = f32[] select(cmp_code, current_value, running_max)
    new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

    ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
  }

  ENTRY main {
    input = f32[100] parameter(0)
    idxs = u32[100]{0} iota(), iota_dimension=0
    zero = f32[] constant(0)
    zero_idx = u32[] constant(0)

    ROOT out = (f32[], u32[]) reduce(
      input, idxs, zero, zero_idx),
      dimensions={0},
      to_apply=%argmax
  }
  )";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1169
// Variadic argmax over dimension 0 of a 32x128 input, producing per-column
// (max, index) results. NOTE(review): the `add` computation and the
// `current`/`potential` tuples in this module appear unused by the entry
// reduce; confirm whether they are intentional.
XLA_TEST_F(VariadicReduceTest, Reduce_R1x2_to_R0x2_argmax_column) {
  absl::string_view hlo_string = R"(
  HloModule Reduce_R1x2_to_R0x2_argmax

  add {
    acc = f32[] parameter(1)
    op = f32[] parameter(0)
    ROOT out = f32[] add(acc, op)
  }

  argmax {
    running_max = f32[] parameter(0)
    running_max_idx = u32[] parameter(1)
    current_value = f32[] parameter(2)
    current_value_idx = u32[] parameter(3)

    current = (f32[], u32[]) tuple(running_max, running_max_idx)
    potential = (f32[], u32[]) tuple(current_value, current_value_idx)

    cmp_code = pred[] compare(current_value, running_max), direction=GT

    new_max = f32[] select(cmp_code, current_value, running_max)
    new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

    ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
  }

  ENTRY main {
    input = f32[32,128] parameter(0)
    idxs = u32[32,128] iota(), iota_dimension=0
    zero = f32[] constant(0)
    zero_idx = u32[] constant(0)

    ROOT argmax_result = (f32[128], u32[128]) reduce(
      input, idxs, zero, zero_idx),
      dimensions={0},
      to_apply=%argmax
  }
  )";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1212
// Variadic reduce whose result tuple is only partially consumed: the root
// extracts just the index element (index=1) from the two-element reduce
// output. The reducer unconditionally returns its accumulator pair.
XLA_TEST_F(VariadicReduceTest, ReduceMultiOutputVariadicAnd) {
  absl::string_view hlo_string = R"(
  HloModule VariadicReduceMultiOutput

  VariadicAnd {
    value = pred[] parameter(0)
    value_idx = u32[] parameter(1)
    current_value = pred[] parameter(2)
    current_value_idx = u32[] parameter(3)
    ROOT out = (pred[], u32[]) tuple(value, value_idx)
  }

  ENTRY CheckBuffer {
    test_value = f32[] parameter(0)
    buffer = f32[100] parameter(1)
    value_broadcast = f32[100] broadcast(test_value), dimensions={}
    comparison_result = pred[100] compare(buffer, value_broadcast), direction=EQ
    true_constant = pred[] constant(true)

    zero_idx = u32[] constant(0)
    idxs = u32[100]{0} iota(), iota_dimension=0
    out = (pred[], u32[]) reduce(
        comparison_result, idxs, true_constant, zero_idx
    ), dimensions={0}, to_apply=VariadicAnd

    ROOT returned = u32[] get-tuple-element(out), index=1
  }
  )";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1244
// Variadic argmax where the two inputs carry different layouts
// ({2,1,0,3} vs {3,2,1,0}) and the two outputs also differ in layout,
// exercising layout handling in multi-output reduce emission.
XLA_TEST_F(VariadicReduceTest, ReduceMultiOutputVariadicDifferentLayout) {
  absl::string_view hlo_string = R"(
  HloModule ReduceWithLayoutChangeVariadicDifferent

  argmax {
    running_max = f32[] parameter(0)
    running_max_idx = u32[] parameter(1)
    current_value = f32[] parameter(2)
    current_value_idx = u32[] parameter(3)

    current = (f32[], u32[]) tuple(running_max, running_max_idx)
    potential = (f32[], u32[]) tuple(current_value, current_value_idx)

    cmp_code = pred[] compare(current_value, running_max), direction=GT

    new_max = f32[] select(cmp_code, current_value, running_max)
    new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

    ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
  }

  ENTRY main {
    arg0 = f32[2,3,4,1024]{2,1,0,3}  parameter(0)
    idxs = u32[2,3,4,1024]{3,2,1,0}  parameter(1)
    constant0 = f32[] constant(0)
    constant1 = u32[] constant(0)
    ROOT reduce0 = (
        f32[2,3,4]{2,1,0},
        u32[2,3,4]{1,0,2}
      ) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
  }
  )";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
1280
1281 } // namespace
1282 } // namespace xla
1283