1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17 #include <functional>
18 #include <limits>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/base/casts.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/client/global_data.h"
31 #include "tensorflow/compiler/xla/client/local_client.h"
32 #include "tensorflow/compiler/xla/client/xla_builder.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/literal.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/test.h"
37 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
38 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
39 #include "tensorflow/compiler/xla/tests/test_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41
42 namespace xla {
43 namespace {
44
45 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
46 public:
47 ErrorSpec error_spec_{0.0001, 0.0001};
48 ErrorSpec strict_error_spec_{3.6e-15, 3.6e-15};
49 };
50
51 class ArrayElementwiseOpTestParamCount
52 : public ArrayElementwiseOpTest,
53 public ::testing::WithParamInterface<int> {};
54
// Negating an empty F32 vector is expected to produce an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<float>(&builder, {}));

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
62
// Element-wise negation of a small F32 vector containing both signs.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}));

  const std::vector<float> expected = {2.5f, -3.14f, -2.25f, 10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
71
// Element-wise negation of an F64 vector. No literal expectation is passed;
// ComputeAndCompare is given only the (empty) argument list and the strict
// tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF64) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<double>(&builder, {-2.5, 3.14, 2.25, -10.0, 6.0}));

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
79
// Element-wise negation of S32 values, including both extremes of the type.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();

  XlaBuilder builder(TestName());
  Neg(ConstantR1<int32_t>(&builder, {-1, 0, 1, 324, kMin, kMax}));

  // Negating the minimum int32_t overflows and wraps back to the minimum.
  // Doing that calculation in C++ would be undefined behavior, which is why
  // the expected value is spelled as kMin rather than computed. XLA has not
  // specified this case, so it ought to work.
  const std::vector<int32_t> expected = {1, 0, -1, -324, kMin, -kMax};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
96
// Negating an empty C64 vector is expected to produce an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<complex64>(&builder, {}));

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
104
// Element-wise negation of C64 values: both the real and the imaginary part
// of every element change sign.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<complex64>(
      &builder,
      {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}));

  const std::vector<complex64> expected = {
      {2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
115
// Element-wise negation of S64 values, including values that only differ in
// the upper 32 bits and the two most negative representable values.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<int64_t>(&builder,
                               {
                                   -1,
                                   1,
                                   0,
                                   0x12345678,
                                   static_cast<int64_t>(0xffffffff12345678l),
                                   static_cast<int64_t>(0x8000000000000000LL),
                                   static_cast<int64_t>(0x8000000000000001LL),
                               });
  Neg(a);

  // Expected values, position by position:
  //  * 0xffffffff12345678 negates to 0x00000000edcba988.
  //  * 0x8000000000000000 is the minimum int64_t; its negation wraps back to
  //    itself, so the same bit pattern is expected.
  //  * 0x8000000000000001 is minimum + 1, whose negation is representable.
  ComputeAndCompareR1<int64_t>(&builder,
                               {
                                   1,
                                   -1,
                                   0,
                                   -0x12345678,
                                   0xedcba988,
                                   static_cast<int64_t>(0x8000000000000000LL),
                                   -static_cast<int64_t>(0x8000000000000001LL),
                               },
                               {});
}
143
// IsFinite over an empty F32 vector is expected to produce an empty PRED
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
  XlaBuilder builder(TestName());
  IsFinite(ConstantR1<float>(&builder, {}));

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
151
// Integer Pow on S32. Covers 0^0, negative bases with even and odd exponents,
// and negative exponents (expected 0 for base 3, 1 for base 1).
XLA_TEST_F(ArrayElementwiseOpTest, IntPow) {
  XlaBuilder builder(TestName());
  const XlaOp base =
      ConstantR1<int32_t>(&builder, {0, 1, 2, 3, 4, 5, -1, -2, 3, 5, 3, 1});
  const XlaOp exponent =
      ConstantR1<int32_t>(&builder, {0, 3, 3, 3, 3, 3, 2, 3, 2, 10, -100, -2});
  Pow(base, exponent);

  const std::vector<int32_t> expected = {1, 1,  8, 27,      64, 125,
                                         1, -8, 9, 9765625, 0,  1};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
165
// Integer Pow on S64 with a result (2^62) that does not fit in 32 bits.
XLA_TEST_F(ArrayElementwiseOpTest, IntPowLarge) {
  XlaBuilder builder(TestName());
  const XlaOp base = ConstantR1<int64_t>(&builder, {2});
  const XlaOp exponent = ConstantR1<int64_t>(&builder, {62});
  Pow(base, exponent);

  const std::vector<int64_t> expected = {4611686018427387904};
  ComputeAndCompareR1<int64_t>(&builder, expected, {});
}
176
177 // A non-canonical quiet NaN value.
178 static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
179
// IsFinite on F32 scalars: NaN (canonical and non-canonical) and both
// infinities are expected to be reported as not finite, zero as finite.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
  XlaBuilder builder(TestName());

  // Builds IsFinite(value) on the shared builder and checks the scalar
  // result. The one builder is reused for every probe.
  const auto expect_is_finite = [&](float value, bool expected) {
    IsFinite(ConstantR0<float>(&builder, value));
    ComputeAndCompareR0<bool>(&builder, expected, {});
  };

  expect_is_finite(NAN, false);

  // Sanity-check the host's view of the hand-built NaN before using it.
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
  expect_is_finite(kNonCanonicalNaN, false);

  const float inf = std::numeric_limits<float>::infinity();
  expect_is_finite(inf, false);
  expect_is_finite(-inf, false);

  expect_is_finite(0.0f, true);
}
199
// IsFinite on an F32 vector mixing NaNs, infinities and ordinary values.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  // Sanity-check the host's view of the hand-built NaN before using it.
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  const std::vector<float> inputs = {NAN,   7.0f, kNonCanonicalNaN,
                                     -1.0f, inf,  -inf};
  IsFinite(ConstantR1<float>(&builder, inputs));

  ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
                            {});
}
211
// Element-wise addition of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  const XlaOp rhs =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Add(lhs, rhs);

  const std::vector<float> expected = {97.5f, 6.27f, 5.0f, 0.5f, -993.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
221
// Adding two empty F32 vectors is expected to produce an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Add(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
230
// Element-wise addition of two C64 constant vectors; real and imaginary
// parts add independently.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
  const XlaOp rhs = ConstantR1<complex64>(
      &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
  Add(lhs, rhs);

  // The first element is purely real, so it is written as a bare float.
  const std::vector<complex64> expected = {
      97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
243
// Adding two empty C64 vectors is expected to produce an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Add(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
252
// Element-wise U64 addition on runtime parameters. The operand pairs are
// chosen around the 32-bit and 64-bit boundaries; the expectation is computed
// on the host with ordinary (wrapping) uint64_t arithmetic.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
  XlaBuilder builder(TestName());

  const std::vector<uint64_t> lhs_values{0xFFFFFFFF,
                                         static_cast<uint64_t>(-1),
                                         0,
                                         0,
                                         0x7FFFFFFFFFFFFFFFLL,
                                         0x7FFFFFFFFFFFFFFLL,
                                         0x8000000000000000ULL,
                                         0x8000000000000000ULL,
                                         1};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64_t>(lhs_values);
  auto lhs_param = Parameter(&builder, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();

  const std::vector<uint64_t> rhs_values{1,
                                         0x7FFFFFFFFFFFFFFLL,
                                         0x7FFFFFFFFFFFFFFFLL,
                                         0x8000000000000000ULL,
                                         0,
                                         static_cast<uint64_t>(-1),
                                         0,
                                         1,
                                         0x8000000000000000ULL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64_t>(rhs_values);
  auto rhs_param = Parameter(&builder, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  Add(lhs_param, rhs_param);

  // Host-side reference: unsigned addition wraps modulo 2^64.
  std::vector<uint64_t> expected;
  expected.reserve(lhs_values.size());
  for (size_t idx = 0; idx < lhs_values.size(); ++idx) {
    expected.push_back(lhs_values[idx] + rhs_values[idx]);
  }

  ComputeAndCompareR1<uint64_t>(&builder, expected,
                                {lhs_data.get(), rhs_data.get()});
}
293
// Element-wise S64 subtraction on runtime parameters, with operands at and
// around the extremes of int64_t. The expectation is computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
  XlaBuilder builder(TestName());

  const std::vector<int64_t> lhs_values{
      static_cast<int64_t>(0x8000000000000000LL),
      static_cast<int64_t>(0x8000000000000000LL),
      -1,
      0x7FFFFFFFFFFFFFFLL,
      0x7FFFFFFFFFFFFFFFLL,
      1,
      0,
      -1};
  Literal lhs_literal = LiteralUtil::CreateR1<int64_t>(lhs_values);
  auto lhs_param = Parameter(&builder, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();

  const std::vector<int64_t> rhs_values{
      -1,
      0,
      static_cast<int64_t>(0x8000000000000000LL),
      1,
      0,
      0x7FFFFFFFFFFFFFFLL,
      0x7FFFFFFFFFFFFFFFLL,
      0x7FFFFFFFFFFFFFFFLL};
  Literal rhs_literal = LiteralUtil::CreateR1<int64_t>(rhs_values);
  auto rhs_param = Parameter(&builder, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  Sub(lhs_param, rhs_param);

  // Host-side reference, computed pair by pair with int64_t subtraction.
  std::vector<int64_t> expected;
  expected.reserve(lhs_values.size());
  for (size_t idx = 0; idx < lhs_values.size(); ++idx) {
    expected.push_back(lhs_values[idx] - rhs_values[idx]);
  }

  ComputeAndCompareR1<int64_t>(&builder, expected,
                               {lhs_data.get(), rhs_data.get()});
}
332
// U64 less-than where the operands differ only in how the top bit is set:
// 0x8000000000000000 versus 0x7FFFFFFFFFFFFFFF. No literal expectation is
// passed; ComputeAndCompare receives just the two argument literals.
XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
  XlaBuilder builder(TestName());

  const std::vector<uint64_t> lhs_values{
      static_cast<uint64_t>(0x8000000000000000ULL)};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64_t>(lhs_values);
  auto lhs_param = Parameter(&builder, 0, lhs_literal.shape(), "lhs_param");

  const std::vector<uint64_t> rhs_values{
      static_cast<uint64_t>(0x7FFFFFFFFFFFFFFFULL)};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64_t>(rhs_values);
  auto rhs_param = Parameter(&builder, 1, rhs_literal.shape(), "rhs_param");

  Lt(lhs_param, rhs_param);

  ComputeAndCompare(&builder,
                    {std::move(lhs_literal), std::move(rhs_literal)});
}
348
// Adds two F32 vectors of GetParam() elements in every constant/parameter
// combination (const+param, const+const, param+param, param+const) and sums
// the four results, so each element is expected to equal 4 * (a + b).
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector<float> a_values;
  std::vector<float> b_values;
  a_values.reserve(count);
  b_values.reserve(count);
  for (int i = 0; i < count; ++i) {
    a_values.push_back(i / static_cast<float>(count));
    b_values.push_back(2 * i / static_cast<float>(count + 2));
  }

  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).value();
  auto a_constant = ConstantR1<float>(&builder, a_values);
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).value();
  // Declare parameter 1 with the shape of the literal that is actually
  // transferred for it (both literals hold `count` floats).
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "b_param");
  auto b_constant = ConstantR1<float>(&builder, b_values);

  auto sum1 = Add(a_constant, b_param);
  auto sum2 = Add(a_constant, b_constant);
  auto sum3 = Add(a_param, b_param);
  auto sum4 = Add(a_param, b_constant);

  auto sum = Add(sum1, sum2);
  sum = Add(sum, sum3);
  sum = Add(sum, sum4);

  std::vector<float> expected;
  expected.reserve(count);
  for (int i = 0; i < count; ++i) {
    expected.push_back(4 * (a_values[i] + b_values[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
                             error_spec_);
}
391
// Builds a long chain of add/slice/slice/add "diamonds" on top of a 30-element
// add and checks that the single remaining element is 0.
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
  XlaBuilder builder(TestName());
  const std::vector<float> zeros(30, 0.0);
  auto lhs_literal = LiteralUtil::CreateR1<float>(zeros);
  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "x");
  auto rhs_literal = LiteralUtil::CreateR1<float>(zeros);
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "x");

  // Construct a sequence of diamond-shaped gadgets like this:
  //
  //        add
  //       /   \
  //   slice   slice
  //       \   /
  //        add
  //
  // Each 'left' slice removes the last element, each 'right' slice removes the
  // first element. In this way, we index into the add with different
  // multi-dimensional index arrays, which defeats the caching we use to avoid
  // exponential compile time.
  //
  // The chain starts from the full-width add and shrinks by one element per
  // gadget until a single element is left.
  XlaOp current = Add(lhs, rhs);
  for (int64_t slice_size = static_cast<int64_t>(zeros.size()) - 1;
       slice_size >= 1; --slice_size) {
    auto without_last = Slice(current, {0}, {slice_size}, {1});
    auto without_first = Slice(current, {1}, {slice_size + 1}, {1});
    current = Add(without_last, without_first);
  }

  auto lhs_data = client_->TransferToServer(lhs_literal).value();
  auto rhs_data = client_->TransferToServer(rhs_literal).value();
  ComputeAndCompareR1<float>(&builder, {0.0},
                             {lhs_data.get(), rhs_data.get()});
}
427
// Element-wise subtraction of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  const XlaOp rhs =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Sub(lhs, rhs);

  const std::vector<float> expected = {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
437
// Subtracting two empty F32 vectors is expected to produce an empty F32
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
446
// Element-wise subtraction of two S32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32_t>(&builder, {-1, 0, 2, 1000000000});
  const XlaOp rhs = ConstantR1<int32_t>(&builder, {-1, 2, 1, -1});
  Sub(lhs, rhs);

  const std::vector<int32_t> expected = {0, -2, 1, 1000000001};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
455
// Subtracting two empty S32 vectors is expected to produce an empty S32
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32_t>(&builder, {});
  const XlaOp rhs = ConstantR1<int32_t>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<int32_t>(&builder, {}, {});
}
464
// Element-wise subtraction of two C64 constant vectors; real and imaginary
// parts subtract independently.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
  const XlaOp rhs = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
  Sub(lhs, rhs);

  const std::vector<complex64> expected = {
      {-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
477
// Subtracting two empty C64 vectors is expected to produce an empty C64
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
486
// Element-wise subtraction of two F64 constant vectors. No literal
// expectation is passed; ComputeAndCompare is given only the (empty) argument
// list and the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<double>(&builder, {-2.5, 3.14, 2.25, -10.0, 6.0});
  const XlaOp rhs =
      ConstantR1<double>(&builder, {100.0, 3.13, 2.75, 10.5, -999.0});
  Sub(lhs, rhs);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
495
// Element-wise division of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp numerator =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const XlaOp denominator =
      ConstantR1<float>(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
  Div(numerator, denominator);

  const std::vector<float> expected = {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
505
// Dividing two empty F32 vectors is expected to produce an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp numerator = ConstantR1<float>(&builder, {});
  const XlaOp denominator = ConstantR1<float>(&builder, {});
  Div(numerator, denominator);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
514
// Element-wise division of two 18-element F64 constant vectors. No literal
// expectation is passed; ComputeAndCompare is given only the (empty) argument
// list and the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF64s) {
  XlaBuilder builder(TestName());
  const std::vector<double> numerators = {
      -2.5, 25.5, 2.25, -10.0, 6.0, 1.0, 2.0,  3.2,  -4.0,
      0.45, 5.7,  0.1,  1.0,   2.0, 0.5, -1.0, -0.5, 1.0};
  const std::vector<double> denominators = {
      10.0, 5.1,  1.0, 10.0, -6.0, 0.1,  1.0,   2.0,   0.5,
      -1.0, -0.5, 2.1, 3.1,  9.9,  -4.5, -11.0, -21.5, M_PI};
  const XlaOp lhs = ConstantR1<double>(&builder, numerators);
  const XlaOp rhs = ConstantR1<double>(&builder, denominators);
  Div(lhs, rhs);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
527
528 class IntegerDivideOpTest : public ArrayElementwiseOpTest {
529 protected:
530 template <typename T>
TestDivRem(absl::Span<const T> dividends,absl::Span<const T> divisors,absl::Span<const T> quotients,absl::Span<const T> remainders)531 void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
532 absl::Span<const T> quotients,
533 absl::Span<const T> remainders) {
534 {
535 XlaBuilder builder(TestName());
536 XlaOp dividend;
537 XlaOp divisor;
538 auto dividend_data =
539 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
540 auto divisor_data =
541 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
542 Div(dividend, divisor);
543
544 ComputeAndCompareR1<T>(&builder, quotients,
545 {dividend_data.get(), divisor_data.get()});
546 }
547
548 // Test with a compile-time constant divisor.
549 {
550 XlaBuilder builder(TestName());
551 XlaOp dividend;
552 auto dividend_data =
553 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
554 Div(dividend, ConstantR1<T>(&builder, divisors));
555
556 ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()});
557 }
558
559 {
560 XlaBuilder builder(TestName());
561 XlaOp dividend;
562 XlaOp divisor;
563 auto dividend_data =
564 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
565 auto divisor_data =
566 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
567 Rem(dividend, divisor);
568
569 ComputeAndCompareR1<T>(&builder, remainders,
570 {dividend_data.get(), divisor_data.get()});
571 }
572
573 // Test with a compile-time constant divisor.
574 {
575 XlaBuilder builder(TestName());
576 XlaOp dividend;
577 auto dividend_data =
578 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
579 Rem(dividend, ConstantR1<T>(&builder, divisors));
580
581 ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()});
582 }
583 }
584 };
585
// S32 Div/Rem over the cross product of a set of interesting values, leaving
// out the pairs whose result C++ cannot compute (zero divisor, INT32_MIN / -1).
XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<int32_t> interesting = {
    INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
    -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
    7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
  // clang-format on

  std::vector<int32_t> dividends;
  std::vector<int32_t> divisors;
  std::vector<int32_t> quotients;
  std::vector<int32_t> remainders;
  for (const int32_t divisor : interesting) {
    if (divisor == 0) {
      continue;
    }
    for (const int32_t dividend : interesting) {
      // Avoid integer overflow.
      if (dividend == INT32_MIN && divisor == -1) {
        continue;
      }
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<int32_t>(dividends, divisors, quotients, remainders);
}
612
// S32 Div/Rem for the two overflow cases: 5 / 0 and INT32_MIN / -1. The
// expected quotients are -1 and INT32_MIN, the expected remainders 5 and 0.
XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
  const std::vector<int32_t> dividends = {5, INT32_MIN};
  const std::vector<int32_t> divisors = {0, -1};
  const std::vector<int32_t> quotients = {-1, INT32_MIN};
  const std::vector<int32_t> remainders = {5, 0};

  TestDivRem<int32_t>(dividends, divisors, quotients, remainders);
}
619
// U32 Div/Rem over the cross product of a set of interesting values, leaving
// out zero divisors.
XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<uint32_t> interesting = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  std::vector<uint32_t> dividends;
  std::vector<uint32_t> divisors;
  std::vector<uint32_t> quotients;
  std::vector<uint32_t> remainders;
  for (const uint32_t divisor : interesting) {
    if (divisor == 0) {
      continue;
    }
    for (const uint32_t dividend : interesting) {
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<uint32_t>(dividends, divisors, quotients, remainders);
}
642
// U32 Div/Rem by zero. The test now instantiates the unsigned type its name
// refers to; the expected quotient is the all-ones value (the unsigned
// counterpart of the -1 expected by SignedOverflow) and the expected
// remainder is the dividend itself.
// NOTE(review): the all-ones quotient mirrors the signed expectation in
// SignedOverflow — confirm against each backend's division-by-zero behavior.
XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
  const std::vector<uint32_t> dividends = {5};
  const std::vector<uint32_t> divisors = {0};
  const std::vector<uint32_t> quotients = {UINT32_MAX};
  const std::vector<uint32_t> remainders = {5};

  TestDivRem<uint32_t>(dividends, divisors, quotients, remainders);
}
649
// Element-wise division of two C64 constant vectors: by a real number, by a
// purely imaginary number, and of a value by itself.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp numerator = ConstantR1<complex64>(
      &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
  const XlaOp denominator = ConstantR1<complex64>(
      &builder, {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
  Div(numerator, denominator);

  const std::vector<complex64> expected = {
      {-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
661
// Dividing two empty C64 vectors is expected to produce an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp numerator = ConstantR1<complex64>(&builder, {});
  const XlaOp denominator = ConstantR1<complex64>(&builder, {});
  Div(numerator, denominator);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
670
// Element-wise F32 remainder. In the expected values, every non-zero result
// carries the sign of the dividend.
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
  XlaBuilder builder(TestName());
  const XlaOp dividend = ConstantR1<float>(
      &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
  const XlaOp divisor = ConstantR1<float>(
      &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
  Rem(dividend, divisor);

  const std::vector<float> expected = {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f,
                                       1.0f,  1.0f, -1.0f, -0.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
683
// Rem of two empty F32 vectors is expected to produce an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp dividend = ConstantR1<float>(&builder, {});
  const XlaOp divisor = ConstantR1<float>(&builder, {});
  Rem(dividend, divisor);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
692
// Element-wise F64 remainder, using the same operand values as RemF32s but
// compared under the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
  XlaBuilder builder(TestName());
  const XlaOp dividend = ConstantR1<double>(
      &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
  const XlaOp divisor = ConstantR1<double>(
      &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
  Rem(dividend, divisor);

  const std::vector<double> expected = {-2.5, 0.0, 0.25, 0.0, -0.0,
                                        1.0,  1.0, -1.0, -0.0};
  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
705
// Element-wise multiplication of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const XlaOp rhs =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Mul(lhs, rhs);

  const std::vector<float> expected = {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
715
// Multiplying two empty F32 vectors is expected to produce an empty F32
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
724
// S32 multiplication over every ordered pair drawn from a small value set
// that includes both extremes of the type.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
  const std::vector<int32_t> seeds = {0,
                                      1,
                                      -1,
                                      1234,
                                      0x1a243514,
                                      std::numeric_limits<int32_t>::max(),
                                      std::numeric_limits<int32_t>::min()};

  // Form the test data set using all products of 'seeds' with itself.
  std::vector<int32_t> lhs_values;
  std::vector<int32_t> rhs_values;
  std::vector<int32_t> expected;
  for (const int32_t x : seeds) {
    for (const int32_t y : seeds) {
      lhs_values.push_back(x);
      rhs_values.push_back(y);
      // The product is formed in uint32_t, where wrap-around is well defined,
      // and then converted back to int32_t.
      const uint32_t wrapped =
          static_cast<uint32_t>(x) * static_cast<uint32_t>(y);
      expected.push_back(wrapped);
    }
  }

  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32_t>(&builder, lhs_values);
  const XlaOp rhs = ConstantR1<int32_t>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
750
// Multiplying two empty S32 vectors is expected to produce an empty S32
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32_t>(&builder, {});
  const XlaOp rhs = ConstantR1<int32_t>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<int32_t>(&builder, {}, {});
}
759
// U32 multiplication over every ordered pair drawn from a small value set;
// the host-side expectation uses ordinary (wrapping) uint32_t arithmetic.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
  const std::vector<uint32_t> seeds = {0,          1,          0xDEADBEEF, 1234,
                                       0x1a243514, 0xFFFFFFFF, 0x80808080};

  // Form the test data set using all products of 'seeds' with itself.
  std::vector<uint32_t> lhs_values;
  std::vector<uint32_t> rhs_values;
  std::vector<uint32_t> expected;
  for (const uint32_t x : seeds) {
    for (const uint32_t y : seeds) {
      lhs_values.push_back(x);
      rhs_values.push_back(y);
      expected.push_back(x * y);
    }
  }

  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32_t>(&builder, lhs_values);
  const XlaOp rhs = ConstantR1<uint32_t>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
781
// Element-wise multiplication of two C64 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
  const XlaOp rhs = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
  Mul(lhs, rhs);

  // E.g. (2 - 10i) * (10 - 6i) = 20 - 12i - 100i + 60i^2 = -40 - 112i.
  const std::vector<complex64> expected = {
      {0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
794
// Multiplying two empty C64 vectors is expected to produce an empty C64
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
803
// Logical AND over the full two-input truth table, as rank-1 PRED vectors.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  const XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  And(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
}
812
// Logical AND over the full two-input truth table, laid out as 2x2 PRED
// arrays.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  And(lhs, rhs);

  Array2D<bool> expected({{false, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
822
// AND of two empty PRED vectors is expected to produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {});
  const XlaOp rhs = ConstantR1<bool>(&builder, {});
  And(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
831
// Bitwise AND on S32 vectors, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32_t>(&builder, {0, -1, -8});
  const XlaOp rhs = ConstantR1<int32_t>(&builder, {5, -7, 12});
  And(lhs, rhs);

  // 0 & 5 == 0, -1 & -7 == -7, -8 & 12 == 8.
  const std::vector<int32_t> expected = {0, -7, 8};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
840
// Bitwise AND on 2x2 S32 arrays, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<int32_t>(&builder, {{0, -5}, {-1, 5}});
  const XlaOp rhs = ConstantR2<int32_t>(&builder, {{1, -6}, {4, 5}});
  And(lhs, rhs);

  // 0 & 1 == 0, -5 & -6 == -6, -1 & 4 == 4, 5 & 5 == 5.
  Array2D<int32_t> expected({{0, -6}, {4, 5}});
  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
850
// AND of two empty S32 vectors is expected to produce an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32_t>(&builder, {});
  const XlaOp rhs = ConstantR1<int32_t>(&builder, {});
  And(lhs, rhs);

  ComputeAndCompareR1<int32_t>(&builder, {}, {});
}
859
// Bitwise AND on U32 vectors. The operands and the expectation use uint32_t,
// matching the test name and the sibling AndU32R2 / AndZeroElementU32R1 tests.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32_t>(&builder, {0, 1, 8});
  auto b = ConstantR1<uint32_t>(&builder, {5, 7, 12});
  And(a, b);

  // 0 & 5 == 0, 1 & 7 == 1, 8 & 12 == 8.
  ComputeAndCompareR1<uint32_t>(&builder, {0, 1, 8}, {});
}
868
// Bitwise And of two 2x2 unsigned 32-bit matrices.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32_t>(&builder, {{0, 1}, {3, 8}});
  auto rhs = ConstantR2<uint32_t>(&builder, {{1, 0}, {7, 6}});
  And(lhs, rhs);

  const Array2D<uint32_t> expected({{0, 0}, {3, 0}});
  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
878
// And applied to two empty U32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<uint32_t>(&builder, {});
  auto empty_rhs = ConstantR1<uint32_t>(&builder, {});
  And(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32_t>(&builder, {}, {});
}
887
// Element-wise Or over all four combinations of PRED inputs.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  auto rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Or(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
}
896
// Element-wise Or of two 2x2 PRED constants covering every input combination.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  auto rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Or(lhs, rhs);

  const Array2D<bool> expected({{false, true}, {true, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
906
// Or applied to two empty PRED vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<bool>(&builder, {});
  auto empty_rhs = ConstantR1<bool>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
915
// Bitwise Or of signed 32-bit vectors, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32_t>(&builder, {0, -1, 8});
  auto rhs = ConstantR1<int32_t>(&builder, {5, -7, 4});
  Or(lhs, rhs);

  const std::vector<int32_t> expected = {5, -1, 12};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
924
// Bitwise Or of two 2x2 signed 32-bit matrices.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<int32_t>(&builder, {{0, -1}, {8, 8}});
  auto rhs = ConstantR2<int32_t>(&builder, {{5, -7}, {4, 1}});
  Or(lhs, rhs);

  const Array2D<int32_t> expected({{5, -1}, {12, 9}});
  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
934
// Or applied to two empty S32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<int32_t>(&builder, {});
  auto empty_rhs = ConstantR1<int32_t>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32_t>(&builder, {}, {});
}
943
// Bitwise Or of unsigned 32-bit vectors.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32_t>(&builder, {0, 1, 8});
  auto rhs = ConstantR1<uint32_t>(&builder, {5, 7, 4});
  Or(lhs, rhs);

  const std::vector<uint32_t> expected = {5, 7, 12};
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
952
// Bitwise Or of two 2x2 unsigned 32-bit matrices.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32_t>(&builder, {{0, 1}, {8, 8}});
  auto rhs = ConstantR2<uint32_t>(&builder, {{5, 7}, {4, 1}});
  Or(lhs, rhs);

  const Array2D<uint32_t> expected({{5, 7}, {12, 9}});
  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
962
// Or applied to two empty U32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<uint32_t>(&builder, {});
  auto empty_rhs = ConstantR1<uint32_t>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32_t>(&builder, {}, {});
}
971
// Element-wise Xor over all four combinations of PRED inputs.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  auto rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false}, {});
}
980
// Element-wise Xor of two 2x2 PRED constants covering every input combination.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  auto rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Xor(lhs, rhs);

  const Array2D<bool> expected({{false, true}, {true, false}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
990
// Xor applied to two empty PRED vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<bool>(&builder, {});
  auto empty_rhs = ConstantR1<bool>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
999
// Bitwise Xor of signed 32-bit vectors, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32_t>(&builder, {0, -1, 8});
  auto rhs = ConstantR1<int32_t>(&builder, {5, -7, 4});
  Xor(lhs, rhs);

  const std::vector<int32_t> expected = {5, 6, 12};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
1008
// Bitwise Xor of two 2x2 signed 32-bit matrices.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<int32_t>(&builder, {{0, -1}, {8, 8}});
  auto rhs = ConstantR2<int32_t>(&builder, {{5, -7}, {4, 1}});
  Xor(lhs, rhs);

  const Array2D<int32_t> expected({{5, 6}, {12, 9}});
  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
1018
// Xor applied to two empty S32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<int32_t>(&builder, {});
  auto empty_rhs = ConstantR1<int32_t>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32_t>(&builder, {}, {});
}
1027
// Bitwise Xor of unsigned 32-bit vectors.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32_t>(&builder, {0, 1, 8});
  auto rhs = ConstantR1<uint32_t>(&builder, {5, 7, 4});
  Xor(lhs, rhs);

  const std::vector<uint32_t> expected = {5, 6, 12};
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
1036
// Bitwise Xor of two 2x2 unsigned 32-bit matrices.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32_t>(&builder, {{0, 1}, {8, 8}});
  auto rhs = ConstantR2<uint32_t>(&builder, {{5, 7}, {4, 1}});
  Xor(lhs, rhs);

  const Array2D<uint32_t> expected({{5, 6}, {12, 9}});
  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
1046
// Xor applied to two empty U32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<uint32_t>(&builder, {});
  auto empty_rhs = ConstantR1<uint32_t>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32_t>(&builder, {}, {});
}
// Logical Not of a PRED vector flips every element.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<bool>(&builder, {false, true, true, false});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
1062
// Logical Not of a 2x2 PRED matrix flips every element.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR2<bool>(&builder, {{false, true}, {true, false}});
  Not(operand);

  const Array2D<bool> expected({{true, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
1071
// Not applied to an empty PRED vector is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto empty_operand = ConstantR1<bool>(&builder, {});
  Not(empty_operand);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1079
// Bitwise Not of a signed 32-bit vector (~x == -x - 1 for the values used).
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32_t>(&builder, {-1, 0, 1});
  Not(operand);

  const std::vector<int32_t> expected = {0, -1, -2};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
1087
// Bitwise Not of a 2x2 signed 32-bit matrix.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR2<int32_t>(&builder, {{-1, 0}, {1, 8}});
  Not(operand);

  const Array2D<int32_t> expected({{0, -1}, {-2, -9}});
  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
1096
// Not applied to an empty S32 vector is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
  XlaBuilder builder(TestName());
  auto empty_operand = ConstantR1<int32_t>(&builder, {});
  Not(empty_operand);

  ComputeAndCompareR1<int32_t>(&builder, {}, {});
}
1104
// Bitwise Not of an unsigned 32-bit vector: 0 and the all-ones value
// (4294967295) map to each other.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32_t>(&builder, {0, 4294967295});
  Not(operand);

  const std::vector<uint32_t> expected = {4294967295, 0};
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
1112
// Bitwise Not of a 2x2 unsigned 32-bit matrix using values at both ends of the
// uint32_t range.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  XlaBuilder builder(TestName());
  auto operand =
      ConstantR2<uint32_t>(&builder, {{0, 4294967295}, {1, 4294967294}});
  Not(operand);

  const Array2D<uint32_t> expected({{4294967295, 0}, {4294967294, 1}});
  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
1121
// Not applied to an empty U32 vector is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto empty_operand = ConstantR1<uint32_t>(&builder, {});
  Not(empty_operand);

  ComputeAndCompareR1<uint32_t>(&builder, {}, {});
}
1129
// PopulationCount on a signed 32-bit vector; expected values are the number of
// set bits in each element (e.g. -15 has 29 set bits).
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32_t>(&builder, {0, 1, -15, 341});
  PopulationCount(operand);

  const std::vector<int32_t> expected = {0, 1, 29, 5};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
1136
// PopulationCount on a 2x2 signed 32-bit matrix.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR2<int32_t>(&builder, {{0, 1}, {-15, 341}});
  PopulationCount(operand);

  const Array2D<int32_t> expected({{0, 1}, {29, 5}});
  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
1144
// PopulationCount on signed 64-bit values: 0, -1 (all 64 bits set), INT64_MAX
// (63 bits set) and INT64_MAX - 1 (62 bits set).
XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
  XlaBuilder builder(TestName());
  auto operand =
      ConstantR2<int64_t>(&builder, {{0, -1}, {INT64_MAX, INT64_MAX - 1}});
  PopulationCount(operand);

  const Array2D<int64_t> expected({{0, 64}, {63, 62}});
  ComputeAndCompareR2<int64_t>(&builder, expected, {});
}
1152
// ShiftLeft on signed 32-bit values. The last three shift amounts (32, 100,
// -1) are outside [0, 31]; the expected result for each of them is 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  XlaBuilder builder(TestName());
  auto values = ConstantR1<int32_t>(
      &builder, {static_cast<int32_t>(0x12345678),
                 static_cast<int32_t>(0xF0001000), 1, 3, 77, 1, -3, 77});
  auto shift_amounts =
      ConstantR1<int32_t>(&builder, {4, 8, 2, 7, 15, 32, 100, -1});
  ShiftLeft(values, shift_amounts);

  ComputeAndCompareR1<int32_t>(
      &builder,
      {static_cast<int32_t>(0x23456780), 0x00100000, 0x4, 0x180, 2523136, 0, 0,
       0},
      {});
}
1166
// ShiftRightArithmetic on signed 32-bit values. For the out-of-range shift
// amounts (32, 100, -1) the expected result is 0 for the non-negative inputs
// and -1 for the negative input.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  XlaBuilder builder(TestName());
  auto values = ConstantR1<int32_t>(
      &builder, {static_cast<int32_t>(0x92345678),
                 static_cast<int32_t>(0x10001000), 1, 3, 77, 1, -3, 77});
  auto shift_amounts =
      ConstantR1<int32_t>(&builder, {4, 8, 2, 7, 2, 32, 100, -1});
  ShiftRightArithmetic(values, shift_amounts);

  ComputeAndCompareR1<int32_t>(
      &builder,
      {static_cast<int32_t>(0xF9234567), static_cast<int32_t>(0x00100010), 0, 0,
       19, 0, -1, 0},
      {});
}
1181
// ShiftRightLogical on signed 32-bit values. For the out-of-range shift
// amounts (32, 100, -1) the expected result is 0, including for the negative
// input.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  XlaBuilder builder(TestName());
  auto values = ConstantR1<int32_t>(
      &builder, {static_cast<int32_t>(0x92345678),
                 static_cast<int32_t>(0x10001000), 1, 3, 77, 1, -3, 77});
  auto shift_amounts =
      ConstantR1<int32_t>(&builder, {4, 8, 2, 7, 5, 32, 100, -1});
  ShiftRightLogical(values, shift_amounts);

  ComputeAndCompareR1<int32_t>(
      &builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}
1193
// ShiftLeft on unsigned 32-bit values. The last three shift amounts (32, 100,
// ~0u) are 32 or larger; the expected result for each of them is 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  XlaBuilder builder(TestName());
  auto values = ConstantR1<uint32_t>(
      &builder, {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
  auto shift_amounts =
      ConstantR1<uint32_t>(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u});
  ShiftLeft(values, shift_amounts);

  ComputeAndCompareR1<uint32_t>(
      &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {});
}
1204
// ShiftRightArithmetic on unsigned 32-bit values. The expected values show the
// top bit being replicated: 0x92345678 >> 4 gives 0xF9234567, and ~3u shifted
// by 100 gives ~0u.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  XlaBuilder builder(TestName());
  auto values = ConstantR1<uint32_t>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  auto shift_amounts =
      ConstantR1<uint32_t>(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u});
  ShiftRightArithmetic(values, shift_amounts);

  ComputeAndCompareR1<uint32_t>(
      &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {});
}
1215
// ShiftRightLogical on unsigned 32-bit values. For shift amounts of 32 or
// more (32, 100, ~0u) the expected result is 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  XlaBuilder builder(TestName());
  auto values = ConstantR1<uint32_t>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  auto shift_amounts =
      ConstantR1<uint32_t>(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u});
  ShiftRightLogical(values, shift_amounts);

  ComputeAndCompareR1<uint32_t>(
      &builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}
1226
// Eq on F32 vectors. Fast-math is disabled because the inputs contain NaNs;
// any pair with a NaN on either side is expected to compare unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto right = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1236
// EqTotalOrder on F32 vectors. Unlike Eq, the NAN/NAN pair is expected to
// compare equal; a number against NAN is still unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto right = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN});
  EqTotalOrder(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, true, false}, {});
}
1246
// Eq applied to two empty F32 vectors is expected to give an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_left = ConstantR1<float>(&builder, {});
  auto empty_right = ConstantR1<float>(&builder, {});
  Eq(empty_left, empty_right);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1255
// Ge on F32 vectors. Fast-math is disabled because the inputs contain NaNs;
// any pair with a NaN on either side is expected to yield false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Ge(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1265
// GeTotalOrder on F32 vectors, using NaNs of both signs. The expected values
// treat positive NaN as >= 10.0f, and 6.0f as >= negative NaN but not >=
// positive NaN.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  // Portability: the C++ standard leaves the sign bit of quiet_NaN()
  // unspecified, so std::fabs is applied to guarantee a NaN whose sign bit is
  // clear.
  const float positive_nan =
      std::fabs(std::numeric_limits<float>::quiet_NaN());
  auto left = ConstantR1<float>(
      &builder, {-2.5f, 25.5f, 2.25f, positive_nan, 6.0f, 6.0f});
  auto right = ConstantR1<float>(
      &builder, {10.0f, 5.0f, 1.0f, 10.0f, positive_nan, -positive_nan});
  GeTotalOrder(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, true, false, true},
                            {});
}
1281
// Gt on F32 vectors. Fast-math is disabled because the inputs contain NaNs;
// any pair with a NaN on either side is expected to yield false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Gt(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1291
// Le on F32 vectors, including an equal pair (5.0f, 5.0f). Fast-math is
// disabled because the inputs contain NaNs; NaN pairs are expected to yield
// false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto left = ConstantR1<float>(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
  auto right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Le(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
}
1301
// Lt on F32 vectors. Fast-math is disabled because the inputs contain NaNs;
// any pair with a NaN on either side is expected to yield false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Lt(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
}
1311
// Eq on S32 vectors pairing each of {min, 0, max} with a smaller, an equal and
// a larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(
      &builder,
      {lowest, lowest, lowest, 0, 0, 0, highest, highest, highest});
  auto right = ConstantR1<int32_t>(
      &builder, {lowest, 0, highest, -1, 0, 1, lowest, 0, highest});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1326
// Eq applied to two empty S32 vectors is expected to give an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto empty_left = ConstantR1<int32_t>(&builder, {});
  auto empty_right = ConstantR1<int32_t>(&builder, {});
  Eq(empty_left, empty_right);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1335
// Eq on complex64 vectors. Only the pair with identical real and imaginary
// parts is expected to compare equal; pairs containing a NaN component are
// expected to compare unequal. Fast-math is disabled because of the NaNs.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto left = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                               {1.0f, 25.5f},
                                               {2.25f, -3.0f},
                                               {NAN, 0.0f},
                                               {1.0f, 6.0f}});
  auto right = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                                {1.0f, 5.0f},
                                                {2.25f, -3.0f},
                                                {10.0f, 0.0f},
                                                {1.0f, NAN}});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1353
// Eq applied to two empty complex64 vectors is expected to give an empty PRED
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto empty_left = ConstantR1<complex64>(&builder, {});
  auto empty_right = ConstantR1<complex64>(&builder, {});
  Eq(empty_left, empty_right);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1362
// Ne on complex64 vectors. Only the pair with identical real and imaginary
// parts is expected to yield false; pairs containing a NaN component are
// expected to yield true.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // NaN inputs are involved, so fast-math has to be turned off.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  auto left = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                               {1.0f, 25.5f},
                                               {2.25f, -3.0f},
                                               {NAN, 0.0f},
                                               {1.0f, 6.0f}});
  auto right = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                                {1.0f, 5.0f},
                                                {2.25f, -3.0f},
                                                {10.0f, 0.0f},
                                                {1.0f, NAN}});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
}
1382
// Ne on F32 vectors. Pairs with a NaN on either side are expected to yield
// true.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // NaN inputs are involved, so fast-math has to be turned off.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  auto left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto right = ConstantR1<float>(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
}
1394
// Ne on S32 vectors pairing each of {min, 0, max} with a smaller, an equal and
// a larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(
      &builder,
      {lowest, lowest, lowest, 0, 0, 0, highest, highest, highest});
  auto right = ConstantR1<int32_t>(
      &builder, {lowest, 0, highest, -1, 0, 1, lowest, 0, highest});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1408
// Ge on S32 vectors pairing each of {min, 0, max} with a smaller, an equal and
// a larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(
      &builder,
      {lowest, lowest, lowest, 0, 0, 0, highest, highest, highest});
  auto right = ConstantR1<int32_t>(
      &builder, {lowest, 0, highest, -1, 0, 1, lowest, 0, highest});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1422
// Gt on S32 vectors pairing each of {min, 0, max} with a smaller, an equal and
// a larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(
      &builder,
      {lowest, lowest, lowest, 0, 0, 0, highest, highest, highest});
  auto right = ConstantR1<int32_t>(
      &builder, {lowest, 0, highest, -1, 0, 1, lowest, 0, highest});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}
1437
// Le on S32 vectors pairing each of {min, 0, max} with a smaller, an equal and
// a larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(
      &builder,
      {lowest, lowest, lowest, 0, 0, 0, highest, highest, highest});
  auto right = ConstantR1<int32_t>(
      &builder, {lowest, 0, highest, -1, 0, 1, lowest, 0, highest});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}
1451
// Lt on S32 vectors pairing each of {min, 0, max} with a smaller, an equal and
// a larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<int32_t>(
      &builder,
      {lowest, lowest, lowest, 0, 0, 0, highest, highest, highest});
  auto right = ConstantR1<int32_t>(
      &builder, {lowest, 0, highest, -1, 0, 1, lowest, 0, highest});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      {});
}
1466
// Eq on U32 vectors pairing each of {0, 5, max} with a smaller, an equal and a
// larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<uint32_t>(
      &builder, {0, 0, 0, 5, 5, 5, highest, highest, highest});
  auto right =
      ConstantR1<uint32_t>(&builder, {0, 1, highest, 4, 5, 6, 0, 1, highest});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1478
// Ne on U32 vectors pairing each of {0, 5, max} with a smaller, an equal and a
// larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<uint32_t>(
      &builder, {0, 0, 0, 5, 5, 5, highest, highest, highest});
  auto right =
      ConstantR1<uint32_t>(&builder, {0, 1, highest, 4, 5, 6, 0, 1, highest});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1489
// Ge on U32 vectors pairing each of {0, 5, max} with a smaller, an equal and a
// larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<uint32_t>(
      &builder, {0, 0, 0, 5, 5, 5, highest, highest, highest});
  auto right =
      ConstantR1<uint32_t>(&builder, {0, 1, highest, 4, 5, 6, 0, 1, highest});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1500
// Gt on U32 vectors pairing each of {0, 5, max} with a smaller, an equal and a
// larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<uint32_t>(
      &builder, {0, 0, 0, 5, 5, 5, highest, highest, highest});
  auto right =
      ConstantR1<uint32_t>(&builder, {0, 1, highest, 4, 5, 6, 0, 1, highest});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}
1512
// Le on U32 vectors pairing each of {0, 5, max} with a smaller, an equal and a
// larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<uint32_t>(
      &builder, {0, 0, 0, 5, 5, 5, highest, highest, highest});
  auto right =
      ConstantR1<uint32_t>(&builder, {0, 1, highest, 4, 5, 6, 0, 1, highest});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}
1523
// Lt on U32 vectors pairing each of {0, 5, max} with a smaller, an equal and a
// larger value (where one exists).
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  auto left = ConstantR1<uint32_t>(
      &builder, {0, 0, 0, 5, 5, 5, highest, highest, highest});
  auto right =
      ConstantR1<uint32_t>(&builder, {0, 1, highest, 4, 5, 6, 0, 1, highest});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      {});
}
1535
// Pow on F32 vectors: 0^0, positive and negative integer exponents, negative
// bases with integer exponents, and NaN as either base or exponent (both
// expected to give NaN). Compared within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto bases = ConstantR1<float>(
      &builder, {0.0f, 4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
  auto exponents = ConstantR1<float>(
      &builder, {0.0f, 2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
  Pow(bases, exponents);

  ComputeAndCompareR1<float>(
      &builder, {1.0f, 16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {},
      error_spec_);
}
1549
// Pow on F32 with non-integer exponents: negative bases are expected to give
// NaN, and 0 raised to a negative exponent is expected to give +infinity.
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto bases = ConstantR1<float>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f});
  auto exponents = ConstantR1<float>(&builder, {0.5f, 0.6f, -0.6f, -0.6f});
  Pow(bases, exponents);

  ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {},
                             error_spec_);
}
1560
// Pow on complex64 vectors built from real-valued bases and exponents. The
// negative bases are expected to give complex results, and the zero bases are
// expected to give 0, except 0^0 which is expected to give 1.
XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto bases =
      ConstantR1<complex64>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f});
  auto exponents =
      ConstantR1<complex64>(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f});
  Pow(bases, exponents);

  ComputeAndCompareR1<complex64>(&builder,
                                 {
                                     {0, 1.41421356},
                                     {-2.27443288e-01, 0.69999846},
                                     {-4.19847531e-01, -1.29215783},
                                     {0, 0},
                                     {0, 0},
                                     {1, 0},
                                 },
                                 {}, error_spec_);
}
1581
// Pow applied to two empty F32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_bases = ConstantR1<float>(&builder, {});
  auto empty_exponents = ConstantR1<float>(&builder, {});
  Pow(empty_bases, empty_exponents);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
1590
1591 // Some Pow cases that can be implemented more efficiently.
XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
  XlaBuilder b(TestName());

  const std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
  const std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal param_literal = LiteralUtil::CreateR1<float>(values);
  std::unique_ptr<GlobalData> param_data =
      client_->TransferToServer(param_literal).value();

  // Build sum_k Pow(param, exponent_k) so that every exponent is exercised by
  // a single computation; the final Add is the last op built.
  auto sum = ConstantR0<float>(&b, 0.0f);
  auto param = Parameter(&b, 0, param_literal.shape(), "param");
  for (const float exponent : exponents) {
    sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
  }

  // Host-side reference: the same sum of powers, one entry per input value.
  // The accumulator gets its own name so it does not shadow the XlaOp `sum`
  // above.
  std::vector<float> expected;
  expected.reserve(values.size());
  for (const float value : values) {
    float expected_sum = 0.0f;
    for (const float exponent : exponents) {
      expected_sum += std::pow(value, exponent);
    }
    expected.push_back(expected_sum);
  }

  ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
}
1620
// Checks Pow(Exp(param0), param1) on two F32 parameters against
// std::pow(std::exp(x), y) computed element-wise on the host, within
// error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Pow(Exp(param0), param1);

  std::vector<float> expected(values0.size());
  // size_t index: avoids comparing a signed counter against the unsigned
  // size().
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::pow(std::exp(values0[i]), values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1645
// Checks Log(Pow(param0, param1)) on two F32 parameters against
// std::log(std::pow(x, y)) computed element-wise on the host, within
// error_spec_. Inputs include negative bases with integer exponents and 0^0.
XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, -10.0f, -2.0f, 2.0f, 3.2f,
                                4.0f, 0.5f,   5.7f,  0.0f};
  std::vector<float> values1 = {0.0f, 10.0f, -4.0f, 1.0f, 2.0f,
                                0.5f, -1.0f, -0.5f, 0.0f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Log(Pow(param0, param1));

  std::vector<float> expected(values0.size());
  // size_t index: avoids comparing a signed counter against the unsigned
  // size().
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::log(std::pow(values0[i], values1[i]));
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1672
// Builds Mul(Exp(param0), Exp(param1)) over two R1 F32 parameters and compares
// the result with std::exp(x) * std::exp(y) computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> lhs_args = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> rhs_args = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  // Transfer both operands so they can be passed as computation arguments.
  Literal lhs_literal = LiteralUtil::CreateR1<float>(lhs_args);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();
  Literal rhs_literal = LiteralUtil::CreateR1<float>(rhs_args);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  auto lhs_param = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs_param = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Mul(Exp(lhs_param), Exp(rhs_param));

  // Host-side reference values, one per input pair.
  std::vector<float> expected;
  expected.reserve(lhs_args.size());
  for (size_t i = 0; i < lhs_args.size(); ++i) {
    expected.push_back(std::exp(lhs_args[i]) * std::exp(rhs_args[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
1697
// Builds Div(param0, Exp(param1)) over two R1 F32 parameters and compares the
// result with x / std::exp(y) computed on the host, using error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> numerators = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> exp_args = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  // Transfer both operands so they can be passed as computation arguments.
  Literal numerators_literal = LiteralUtil::CreateR1<float>(numerators);
  std::unique_ptr<GlobalData> numerators_data =
      client_->TransferToServer(numerators_literal).value();
  Literal exp_args_literal = LiteralUtil::CreateR1<float>(exp_args);
  std::unique_ptr<GlobalData> exp_args_data =
      client_->TransferToServer(exp_args_literal).value();

  auto numerators_param =
      Parameter(&builder, 0, numerators_literal.shape(), "param0");
  auto exp_args_param =
      Parameter(&builder, 1, exp_args_literal.shape(), "param1");
  Div(numerators_param, Exp(exp_args_param));

  // Host-side reference values, one per input pair.
  std::vector<float> expected;
  expected.reserve(numerators.size());
  for (size_t i = 0; i < numerators.size(); ++i) {
    expected.push_back(numerators[i] / std::exp(exp_args[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected,
                             {numerators_data.get(), exp_args_data.get()},
                             error_spec_);
}
1722
// Builds the left-nested division Div(Div(param0, param1), param2) over three
// R1 F32 parameters and compares it with (a / b) / c computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
  XlaBuilder builder(TestName());

  const std::vector<float> a = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> b = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> c = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  // Uploads a literal and hands back the handle used as an execution argument.
  auto upload = [this](const Literal& literal) -> std::unique_ptr<GlobalData> {
    return client_->TransferToServer(literal).value();
  };

  Literal a_literal = LiteralUtil::CreateR1<float>(a);
  std::unique_ptr<GlobalData> a_data = upload(a_literal);
  Literal b_literal = LiteralUtil::CreateR1<float>(b);
  std::unique_ptr<GlobalData> b_data = upload(b_literal);
  Literal c_literal = LiteralUtil::CreateR1<float>(c);
  std::unique_ptr<GlobalData> c_data = upload(c_literal);

  auto a_param = Parameter(&builder, 0, a_literal.shape(), "param0");
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "param1");
  auto c_param = Parameter(&builder, 2, c_literal.shape(), "param2");
  Div(Div(a_param, b_param), c_param);

  // Host-side reference values with the same association as the graph.
  std::vector<float> expected;
  expected.reserve(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    expected.push_back((a[i] / b[i]) / c[i]);
  }

  ComputeAndCompareR1<float>(&builder, expected,
                             {a_data.get(), b_data.get(), c_data.get()},
                             error_spec_);
}
1754
// Builds the right-nested division Div(param0, Div(param1, param2)) over three
// R1 F32 parameters and compares it with a / (b / c) computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
  XlaBuilder builder(TestName());

  const std::vector<float> a = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> b = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> c = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  // Uploads a literal and hands back the handle used as an execution argument.
  auto upload = [this](const Literal& literal) -> std::unique_ptr<GlobalData> {
    return client_->TransferToServer(literal).value();
  };

  Literal a_literal = LiteralUtil::CreateR1<float>(a);
  std::unique_ptr<GlobalData> a_data = upload(a_literal);
  Literal b_literal = LiteralUtil::CreateR1<float>(b);
  std::unique_ptr<GlobalData> b_data = upload(b_literal);
  Literal c_literal = LiteralUtil::CreateR1<float>(c);
  std::unique_ptr<GlobalData> c_data = upload(c_literal);

  auto a_param = Parameter(&builder, 0, a_literal.shape(), "param0");
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "param1");
  auto c_param = Parameter(&builder, 2, c_literal.shape(), "param2");
  Div(a_param, Div(b_param, c_param));

  // Host-side reference values with the same association as the graph.
  std::vector<float> expected;
  expected.reserve(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    expected.push_back(a[i] / (b[i] / c[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected,
                             {a_data.get(), b_data.get(), c_data.get()},
                             error_spec_);
}
1787
// Builds Div(param0, Pow(param1, param2)) over three R1 F32 parameters and
// compares it with a / std::pow(b, c) computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> numerators = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> bases = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
  const std::vector<float> powers = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};

  // Uploads a literal and hands back the handle used as an execution argument.
  auto upload = [this](const Literal& literal) -> std::unique_ptr<GlobalData> {
    return client_->TransferToServer(literal).value();
  };

  Literal numerators_literal = LiteralUtil::CreateR1<float>(numerators);
  std::unique_ptr<GlobalData> numerators_data = upload(numerators_literal);
  Literal bases_literal = LiteralUtil::CreateR1<float>(bases);
  std::unique_ptr<GlobalData> bases_data = upload(bases_literal);
  Literal powers_literal = LiteralUtil::CreateR1<float>(powers);
  std::unique_ptr<GlobalData> powers_data = upload(powers_literal);

  auto numerators_param =
      Parameter(&builder, 0, numerators_literal.shape(), "param0");
  auto bases_param = Parameter(&builder, 1, bases_literal.shape(), "param1");
  auto powers_param = Parameter(&builder, 2, powers_literal.shape(), "param2");
  Div(numerators_param, Pow(bases_param, powers_param));

  // Host-side reference values, one per input triple.
  std::vector<float> expected;
  expected.reserve(numerators.size());
  for (size_t i = 0; i < numerators.size(); ++i) {
    expected.push_back(numerators[i] / std::pow(bases[i], powers[i]));
  }

  ComputeAndCompareR1<float>(
      &builder, expected,
      {numerators_data.get(), bases_data.get(), powers_data.get()},
      error_spec_);
}
1820
// Builds Div(Div(param0, param1), Div(param2, param3)) over four R1 F32
// parameters and compares it with (a / b) / (c / d) computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
  std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};

  // Transfer every operand so it can be passed as a computation argument.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).value();

  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
  std::unique_ptr<GlobalData> data3 =
      client_->TransferToServer(literal3).value();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  // Each parameter gets a name matching its index; the fourth one previously
  // reused "param2".
  auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
  Div(Div(param0, param1), Div(param2, param3));

  // Host-side reference values with the same association as the graph. The
  // index is size_t so it is not compared against size() as a signed value.
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
      error_spec_);
}
1860
// Squares GetParam() evenly spaced values i / count via Pow(x, 2) and compares
// against value * value computed on the host.
TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
  const int num_values = GetParam();
  XlaBuilder builder(TestName());

  std::vector<float> inputs;
  inputs.reserve(num_values);
  const float denominator = static_cast<float>(num_values);
  for (int index = 0; index < num_values; ++index) {
    inputs.push_back(index / denominator);
  }

  auto operand = ConstantR1<float>(&builder, inputs);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(operand, two);

  // Reference squares, in input order.
  std::vector<float> squares;
  squares.reserve(inputs.size());
  for (size_t index = 0; index < inputs.size(); ++index) {
    squares.push_back(inputs[index] * inputs[index]);
  }

  ComputeAndCompareR1<float>(&builder, squares, {}, error_spec_);
}
1880
// Squares a 2x2x2x2 F32 constant via Pow(x, 2) and compares against the
// element-wise products computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
  XlaBuilder builder(TestName());
  Array4D<float> operand(2, 2, 2, 2);

  const auto element_count = operand.num_elements();
  std::vector<float> flat_inputs;
  std::vector<float> flat_squares;
  flat_inputs.reserve(element_count);
  flat_squares.reserve(element_count);
  for (int i = 0; i < element_count; ++i) {
    // Inputs are i / element_count, i.e. evenly spaced values below one.
    const float input = static_cast<float>(i) / operand.num_elements();
    flat_inputs.push_back(input);
    flat_squares.push_back(input * input);
  }
  operand.SetValues(flat_inputs);

  Array4D<float> expected(2, 2, 2, 2, flat_squares);

  auto x = ConstantR4FromArray4D<float>(&builder, operand);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1903
// Pow(x, 2) on a 2x2x0x2 F32 constant (no elements) must yield an array of the
// same empty shape.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
  XlaBuilder builder(TestName());
  Array4D<float> empty_operand(2, 2, 0, 2);
  Array4D<float> empty_result(2, 2, 0, 2);

  auto operand = ConstantR4FromArray4D<float>(&builder, empty_operand);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(operand, two);

  ComputeAndCompareR4<float>(&builder, empty_result, {}, error_spec_);
}
1914
// Element-wise Min on F32 vectors with fast math disabled. A NaN on either
// side is expected to produce NaN in the result.
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<float> lhs_values = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_values = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {1.0f, -5.0f, 1.0f, NAN, NAN};

  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Min(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1925
// Element-wise Min of two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> no_values;

  auto lhs = ConstantR1<float>(&builder, no_values);
  auto rhs = ConstantR1<float>(&builder, no_values);
  Min(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, no_values, {}, error_spec_);
}
1933
// Element-wise Min on F64 vectors with fast math disabled. A NaN on either
// side is expected to produce NaN; compared with strict_error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<double> lhs_values = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_values = {2.0, -5.0, 1.0, 10.0, NAN};
  const std::vector<double> expected = {1.0, -5.0, 1.0, NAN, NAN};

  auto lhs = ConstantR1<double>(&builder, lhs_values);
  auto rhs = ConstantR1<double>(&builder, rhs_values);
  Min(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1944
// Element-wise Max on F32 vectors with fast math disabled. A NaN on either
// side is expected to produce NaN in the result.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<float> lhs_values = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_values = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};

  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Max(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1955
// Same inputs and expectations as MaxF32s, but without calling
// SetFastMathDisabled: NaN operands are still expected to yield NaN. Wrapped
// in DISABLED_ON_CPU.
XLA_TEST_F(ArrayElementwiseOpTest,
           DISABLED_ON_CPU(DefaultMaxF32sNaNPropagation)) {
  XlaBuilder builder(TestName());

  const std::vector<float> lhs_values = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_values = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};

  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Max(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1966
// Element-wise Max of two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> no_values;

  auto lhs = ConstantR1<float>(&builder, no_values);
  auto rhs = ConstantR1<float>(&builder, no_values);
  Max(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, no_values, {}, error_spec_);
}
1974
// Element-wise Max on F64 vectors with fast math disabled. A NaN on either
// side is expected to produce NaN; compared with strict_error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<double> lhs_values = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_values = {2.0, -5.0, 1.0, 10.0, NAN};
  const std::vector<double> expected = {2.0, 1.0, 2.25, NAN, NAN};

  auto lhs = ConstantR1<double>(&builder, lhs_values);
  auto rhs = ConstantR1<double>(&builder, rhs_values);
  Max(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1985
// Element-wise Max on S32 vectors, including the extreme int32_t values.
XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());

  const std::vector<int32_t> x_values = {lowest, lowest,  lowest,  -1,     -1,
                                         0,      0,       0,       1,      1,
                                         highest, highest, highest};
  const std::vector<int32_t> y_values = {lowest, highest, 0, -10, 0, -1, 0,
                                         1,      0,       10, 0,  highest,
                                         lowest};
  auto x = ConstantR1<int32_t>(&builder, x_values);
  auto y = ConstantR1<int32_t>(&builder, y_values);
  Max(x, y);

  const std::vector<int32_t> expected = {lowest, highest, 0,       -1,
                                         0,      0,       0,       1,
                                         1,      10,      highest, highest,
                                         highest};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2000
// Element-wise Min on S32 vectors, including the extreme int32_t values.
XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
  const int32_t lowest = std::numeric_limits<int32_t>::min();
  const int32_t highest = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());

  const std::vector<int32_t> x_values = {lowest, lowest,  lowest,  -1,     -1,
                                         0,      0,       0,       1,      1,
                                         highest, highest, highest};
  const std::vector<int32_t> y_values = {lowest, highest, 0, -10, 0, -1, 0,
                                         1,      0,       10, 0,  highest,
                                         lowest};
  auto x = ConstantR1<int32_t>(&builder, x_values);
  auto y = ConstantR1<int32_t>(&builder, y_values);
  Min(x, y);

  const std::vector<int32_t> expected = {lowest, lowest, lowest, -10,
                                         -1,     -1,     0,      0,
                                         0,      1,      0,      highest,
                                         lowest};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2015
// Element-wise Max on U32 vectors, including the largest uint32_t value.
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());

  const std::vector<uint32_t> x_values = {0, 0, 1, 1, 1, highest, highest,
                                          highest};
  const std::vector<uint32_t> y_values = {0, 1, 0, 1, 10, 0, 234234, highest};
  auto x = ConstantR1<uint32_t>(&builder, x_values);
  auto y = ConstantR1<uint32_t>(&builder, y_values);
  Max(x, y);

  const std::vector<uint32_t> expected = {0, 1, 1, 1, 10, highest, highest,
                                          highest};
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
2026
// Element-wise Min on U32 vectors, including the largest uint32_t value.
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  const uint32_t highest = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());

  const std::vector<uint32_t> x_values = {0, 0, 1, 1, 1, highest, highest,
                                          highest};
  const std::vector<uint32_t> y_values = {0, 1, 0, 1, 10, 0, 234234, highest};
  auto x = ConstantR1<uint32_t>(&builder, x_values);
  auto y = ConstantR1<uint32_t>(&builder, y_values);
  Min(x, y);

  const std::vector<uint32_t> expected = {0, 0, 0, 1, 1, 0, 234234, highest};
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
2037
// Element-wise Max of two ten-element F32 vectors whose entries are negations
// of each other; the result is the magnitude of each entry.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  XlaBuilder builder(TestName());

  const std::vector<float> x_values = {-0.0, 1.0, 2.0,  -3.0, -4.0,
                                       5.0,  6.0, -7.0, -8.0, 9.0};
  const std::vector<float> y_values = {-0.0, -1.0, -2.0, 3.0, 4.0,
                                       -5.0, -6.0, 7.0,  8.0, -9.0};
  auto x = ConstantR1<float>(&builder, x_values);
  auto y = ConstantR1<float>(&builder, y_values);
  Max(x, y);

  const std::vector<float> magnitudes = {-0.0, 1.0, 2.0, 3.0, 4.0,
                                         5.0,  6.0, 7.0, 8.0, 9.0};
  ComputeAndCompareR1<float>(&builder, magnitudes, {}, error_spec_);
}
2050
// Max of a one-element F32 vector and an empty F32 vector; the test expects an
// empty result.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> no_values;

  auto single = ConstantR1<float>(&builder, {3.5});
  auto empty = ConstantR1<float>(&builder, no_values);
  Max(single, empty);

  ComputeAndCompareR1<float>(&builder, no_values, {}, error_spec_);
}
2059
// Max of an R1 F32 constant and a 0x2 R2 F32 constant, once for each
// broadcast dimension (0 and 1); the test expects an empty 0x2 result.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  const int kBroadcastDims[] = {0, 1};
  for (int dim : kBroadcastDims) {
    // A fresh builder per iteration so each broadcast dimension is its own
    // computation.
    XlaBuilder builder(TestName());
    auto vector_operand = ConstantR1<float>(&builder, {3.5});
    auto matrix_operand =
        ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 2));
    Max(vector_operand, matrix_operand, /*broadcast_dimensions=*/{dim});

    ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
  }
}
2070
// Max of a 3-element F32 vector with a 2x3 F32 matrix, with the vector mapped
// onto matrix dimension 1.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  XlaBuilder builder(TestName());

  const std::vector<float> row_values = {2.0f, 3.0f, 4.0f};
  auto row = ConstantR1<float>(&builder, row_values);
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  Array2D<float> maxima({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, maxima, {}, error_spec_);
}
2081
// Max of an empty F32 vector with a 2-row, zero-column F32 matrix, with the
// vector mapped onto matrix dimension 1; the result has no elements.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> no_values;

  auto row = ConstantR1<float>(&builder, no_values);
  auto matrix = ConstantR2<float>(&builder, {{}, {}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  Array2D<float> empty_result({{}, {}});
  ComputeAndCompareR2<float>(&builder, empty_result, {}, error_spec_);
}
2091
// Max of a 2x2x3 S32 array with the scalar 2: every element below 2 is
// replaced by 2.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  XlaBuilder builder(TestName());
  auto floor_value = ConstantR0<int32_t>(&builder, 2);

  Array3D<int32_t> operand_values(
      {{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  auto operand = ConstantR3FromArray3D<int32_t>(&builder, operand_values);
  Max(operand, floor_value, /*broadcast_dimensions=*/{});

  Array3D<int32_t> maxima(
      {{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
  ComputeAndCompareR3<int32_t>(&builder, maxima, {});
}
2102
// Max of a 2x0x3 S32 array (no elements) with the scalar 2 yields a 2x0x3
// array.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto floor_value = ConstantR0<int32_t>(&builder, 2);

  Array3D<int32_t> empty_operand(2, 0, 3);
  auto operand = ConstantR3FromArray3D<int32_t>(&builder, empty_operand);
  Max(operand, floor_value, /*broadcast_dimensions=*/{});

  Array3D<int32_t> empty_result(2, 0, 3);
  ComputeAndCompareR3<int32_t>(&builder, empty_result, {});
}
2113
// Min of a 2x3 F32 matrix with a 2-element F32 vector mapped onto matrix
// dimension 0, so each row is limited by its own vector entry.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  XlaBuilder builder(TestName());

  auto matrix = ConstantR2<float>(
      &builder, {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  const std::vector<float> per_row_caps = {-10.2f, 16.4f};
  auto caps = ConstantR1<float>(&builder, per_row_caps);
  Min(matrix, caps, /*broadcast_dimensions=*/{0});

  Array2D<float> minima({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, minima, {}, error_spec_);
}
2124
// Min of a 2-row, zero-column F32 matrix with a 2-element F32 vector mapped
// onto matrix dimension 0; the result has no elements.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  XlaBuilder builder(TestName());

  auto matrix = ConstantR2<float>(&builder, {{}, {}});
  const std::vector<float> per_row_caps = {-10.2f, 16.4f};
  auto caps = ConstantR1<float>(&builder, per_row_caps);
  Min(matrix, caps, /*broadcast_dimensions=*/{0});

  Array2D<float> empty_result({{}, {}});
  ComputeAndCompareR2<float>(&builder, empty_result, {}, error_spec_);
}
2134
// Min of a 2x3 F32 matrix with a 2x2x1x3 F32 array, with the matrix
// dimensions mapped onto dimensions 1 and 3 of the 4D operand.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  XlaBuilder builder(TestName());

  auto matrix = ConstantR2<float>(
      &builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  Array4D<float> tensor_values(
      {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
       {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  auto tensor = ConstantR4FromArray4D<float>(&builder, tensor_values);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> minima(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, minima, {}, error_spec_);
}
2149
// Min of a 2x3 F32 matrix with a 2x2x0x3 F32 array (no elements), with the
// matrix dimensions mapped onto dimensions 1 and 3; the result is 2x2x0x3.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  XlaBuilder builder(TestName());

  auto matrix = ConstantR2<float>(
      &builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  Array4D<float> empty_tensor(2, 2, 0, 3);
  auto tensor = ConstantR4FromArray4D<float>(&builder, empty_tensor);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> empty_result(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, empty_result, {}, error_spec_);
}
2161
// Element-wise Min of an ascending and a descending ten-element S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  XlaBuilder builder(TestName());

  const std::vector<int32_t> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32_t> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  auto x = ConstantR1<int32_t>(&builder, ascending);
  auto y = ConstantR1<int32_t>(&builder, descending);
  Min(x, y);

  const std::vector<int32_t> minima = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
  ComputeAndCompareR1<int32_t>(&builder, minima, {});
}
2171
// Element-wise Max of an ascending and a descending ten-element S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  XlaBuilder builder(TestName());

  const std::vector<int32_t> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32_t> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  auto x = ConstantR1<int32_t>(&builder, ascending);
  auto y = ConstantR1<int32_t>(&builder, descending);
  Max(x, y);

  const std::vector<int32_t> maxima = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
  ComputeAndCompareR1<int32_t>(&builder, maxima, {});
}
2181
// Element-wise Rem of two S32 constants; the expected remainders carry the
// sign of the dividend.
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  XlaBuilder builder(TestName());

  const std::vector<int32_t> dividends = {-3, 26, 2, -1, 1};
  const std::vector<int32_t> divisors = {10, 5, 1, 10, -10};
  auto lhs = ConstantR1<int32_t>(&builder, dividends);
  auto rhs = ConstantR1<int32_t>(&builder, divisors);
  Rem(lhs, rhs);

  const std::vector<int32_t> remainders = {-3, 1, 0, -1, 1};
  ComputeAndCompareR1<int32_t>(&builder, remainders, {});
}
2190
// Element-wise Clamp(min, x, max) on F32 vectors that contain no NaN values.
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> lower_bounds = {1.0f, -6.5f, 1.0f, 2.25f, 0.0f};
  const std::vector<float> inputs = {2.0f, 10.0f, -5.0f, 1.0f, 10.0f};
  const std::vector<float> upper_bounds = {3.0f, 0.5f, 25.5f, 5.0f, 123.0};

  auto minimum = ConstantR1<float>(&builder, lower_bounds);
  auto argument = ConstantR1<float>(&builder, inputs);
  auto maximum = ConstantR1<float>(&builder, upper_bounds);
  Clamp(minimum, argument, maximum);

  const std::vector<float> clamped = {2.0f, 0.5f, 1.0f, 2.25f, 10.0f};
  ComputeAndCompareR1<float>(&builder, clamped, {}, error_spec_);
}
2202
// Element-wise Clamp(min, x, max) on F32 vectors with fast math disabled. A
// NaN in either bound is expected to produce NaN in the result.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());

  const std::vector<float> lower_bounds = {1.0f, -6.5f, 1.0f, 2.25f, NAN};
  const std::vector<float> inputs = {2.0f, 10.0f, -5.0f, 1.0f, 10.0f};
  const std::vector<float> upper_bounds = {3.0f, 0.5f, 25.5f, NAN, 123.0f};

  auto minimum = ConstantR1<float>(&builder, lower_bounds);
  auto argument = ConstantR1<float>(&builder, inputs);
  auto maximum = ConstantR1<float>(&builder, upper_bounds);
  Clamp(minimum, argument, maximum);

  const std::vector<float> clamped = {2.0f, 0.5f, 1.0f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, clamped, {}, error_spec_);
}
2215
// Clamp of an F32 vector between the scalar bounds 0 and 5.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  XlaBuilder builder(TestName());

  auto lower_bound = ConstantR0<float>(&builder, 0.0f);
  const std::vector<float> inputs = {2.0f, 10.0f, -5.0f, 1.0f, 4.0f};
  auto argument = ConstantR1<float>(&builder, inputs);
  auto upper_bound = ConstantR0<float>(&builder, 5.0f);
  Clamp(lower_bound, argument, upper_bound);

  const std::vector<float> clamped = {2.0f, 5.0f, 0.0f, 1.0f, 4.0f};
  ComputeAndCompareR1<float>(&builder, clamped, {}, error_spec_);
}
2226
// Sums four F32 Clamp results covering every scalar/vector combination of the
// lower and upper bounds, and compares the total.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<float>(&builder, 0.0f);
  auto min_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto arg_vector =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto max_scalar = ConstantR0<float>(&builder, 3.0f);
  auto max_vector =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0});

  // One clamp per (lower bound kind, upper bound kind) combination.
  auto vector_min_scalar_max = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_min_vector_max = Clamp(min_scalar, arg_vector, max_vector);
  auto vector_min_vector_max = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_min_scalar_max = Clamp(min_scalar, arg_vector, max_scalar);

  auto mixed_sum = Add(vector_min_scalar_max, scalar_min_vector_max);
  auto uniform_sum = Add(vector_min_vector_max, scalar_min_scalar_max);
  Add(mixed_sum, uniform_sum);

  const std::vector<float> expected = {8.0f, 7.0f, 2.0f, 6.5f, 14.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2246
// Element-wise Clamp(min, x, max) on S32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  XlaBuilder builder(TestName());

  const std::vector<int32_t> lower_bounds = {1, -6, 1, 2, 0, -5};
  const std::vector<int32_t> inputs = {2, 10, -5, 1, 4, 10};
  const std::vector<int32_t> upper_bounds = {3, 0, 25, 5, 123, -1};

  auto min_vector = ConstantR1<int32_t>(&builder, lower_bounds);
  auto arg_vector = ConstantR1<int32_t>(&builder, inputs);
  auto max_vector = ConstantR1<int32_t>(&builder, upper_bounds);
  Clamp(min_vector, arg_vector, max_vector);

  const std::vector<int32_t> clamped = {2, 0, 1, 2, 4, -1};
  ComputeAndCompareR1<int32_t>(&builder, clamped, {});
}
2256
// Sums four S32 Clamp results covering every scalar/vector combination of the
// lower and upper bounds, and compares the total.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<int32_t>(&builder, 0);
  auto min_vector = ConstantR1<int32_t>(&builder, {1, -6, 1, 2, 0});
  auto arg_vector = ConstantR1<int32_t>(&builder, {2, 10, -5, 1, 4});
  auto max_scalar = ConstantR0<int32_t>(&builder, 3);
  auto max_vector = ConstantR1<int32_t>(&builder, {3, 1, 25, 5, 123});

  // One clamp per (lower bound kind, upper bound kind) combination.
  auto vector_min_scalar_max = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_min_vector_max = Clamp(min_scalar, arg_vector, max_vector);
  auto vector_min_vector_max = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_min_scalar_max = Clamp(min_scalar, arg_vector, max_scalar);

  auto mixed_sum = Add(vector_min_scalar_max, scalar_min_vector_max);
  auto uniform_sum = Add(vector_min_vector_max, scalar_min_scalar_max);
  Add(mixed_sum, uniform_sum);

  const std::vector<int32_t> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2272
// Element-wise Clamp(min, x, max) on U32 vectors, including bounds near the
// top of the uint32_t range.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  XlaBuilder builder(TestName());

  const std::vector<uint32_t> lower_bounds = {1, 2, 1, 2, 0, ~0u - 4};
  const std::vector<uint32_t> inputs = {2, 10, 5, 1, 4, 10};
  const std::vector<uint32_t> upper_bounds = {3, 5, 25, 5, 123, ~0u};

  auto min_vector = ConstantR1<uint32_t>(&builder, lower_bounds);
  auto arg_vector = ConstantR1<uint32_t>(&builder, inputs);
  auto max_vector = ConstantR1<uint32_t>(&builder, upper_bounds);
  Clamp(min_vector, arg_vector, max_vector);

  const std::vector<uint32_t> clamped = {2, 5, 5, 2, 4, ~0u - 4};
  ComputeAndCompareR1<uint32_t>(&builder, clamped, {});
}
2282
// Sums four U32 Clamp results covering every scalar/vector combination of the
// lower and upper bounds, and compares the total.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<uint32_t>(&builder, 0);
  auto min_vector = ConstantR1<uint32_t>(&builder, {1, 0, 1, 2, 0});
  auto arg_vector = ConstantR1<uint32_t>(&builder, {2, 10, 0, 1, 4});
  auto max_scalar = ConstantR0<uint32_t>(&builder, 3);
  auto max_vector = ConstantR1<uint32_t>(&builder, {3, 1, 25, 5, 123});

  // One clamp per (lower bound kind, upper bound kind) combination.
  auto vector_min_scalar_max = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_min_vector_max = Clamp(min_scalar, arg_vector, max_vector);
  auto vector_min_vector_max = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_min_scalar_max = Clamp(min_scalar, arg_vector, max_scalar);

  auto mixed_sum = Add(vector_min_scalar_max, scalar_min_vector_max);
  auto uniform_sum = Add(vector_min_vector_max, scalar_min_scalar_max);
  Add(mixed_sum, uniform_sum);

  const std::vector<uint32_t> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
2298
// Adds two R1 F32 parameters that were transferred to the server and compares
// the element-wise sums.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  XlaBuilder builder(TestName());

  Literal lhs_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();

  Literal rhs_literal = LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const std::vector<float> sums = {8.3f, 4.5f, 6.7f, 11.1f};
  ComputeAndCompareR1<float>(&builder, sums, {lhs_data.get(), rhs_data.get()},
                             error_spec_);
}
2320
// Adds two 0x7x0 R3 F32 parameters (no elements) and expects a 0x7x0 result.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  XlaBuilder builder(TestName());
  const Array3D<float> empty(0, 7, 0);

  Literal lhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();

  Literal rhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  ComputeAndCompareR3<float>(&builder, empty,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2342
// Adds an R1 F32 constant to an R1 F32 parameter and compares the
// element-wise sums.
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  XlaBuilder builder(TestName());

  Literal addend_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> addend_data =
      client_->TransferToServer(addend_literal).value();

  const std::vector<float> constant_values = {1.1f, 2.2f, 3.3f, 4.4f};
  auto constant = ConstantR1<float>(&builder, constant_values);
  auto addend = Parameter(&builder, 0, addend_literal.shape(), "param0");
  Add(constant, addend);

  const std::vector<float> sums = {2.2f, 4.4f, 6.6f, 9.9f};
  ComputeAndCompareR1<float>(&builder, sums, {addend_data.get()}, error_spec_);
}
2358
// Cos on a few F32 angles (approximately pi, 0, pi/2 and -pi/4), compared
// within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  XlaBuilder builder(TestName());

  const std::vector<float> angles = {3.14159f, 0.0f, 1.570796f, -0.78539f};
  auto operand = ConstantR1<float>(&builder, angles);
  Cos(operand);

  const std::vector<float> cosines = {-1.0f, 1.0f, 0.0f, 0.707107f};
  ComputeAndCompareR1<float>(&builder, cosines, {}, error_spec_);
}
2367
// Sin on a few F32 angles (approximately pi, 0, pi/2 and -pi/4), compared
// within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  XlaBuilder builder(TestName());

  const std::vector<float> angles = {3.14159f, 0.0f, 1.570796f, -0.78539f};
  auto operand = ConstantR1<float>(&builder, angles);
  Sin(operand);

  const std::vector<float> sines = {0.0f, 0.0f, 1.0f, -0.707107f};
  ComputeAndCompareR1<float>(&builder, sines, {}, error_spec_);
}
2376
// Real() applied to a real-valued F64 vector is expected to return the input
// unchanged (the expected output is the input vector itself).
XLA_TEST_F(ArrayElementwiseOpTest, RealF64s) {
  XlaBuilder builder(TestName());
  // Double literals, matching ImagF64s: float literals would feed an F64 test
  // with values rounded to float precision.
  std::vector<double> xs = {3.14159, 0.0, 1.570796, -0.78539};
  auto a = ConstantR1<double>(&builder, xs);
  Real(a);
  ComputeAndCompareR1<double>(&builder, xs, {});
}
2384
// Imag() applied to a real-valued F64 vector is expected to return all zeros.
XLA_TEST_F(ArrayElementwiseOpTest, ImagF64s) {
  XlaBuilder builder(TestName());

  const std::vector<double> reals = {3.14159, 0.0, 1.570796, -0.78539};
  auto operand = ConstantR1<double>(&builder, reals);
  Imag(operand);

  const std::vector<double> zeros = {0., 0., 0., 0.};
  ComputeAndCompareR1<double>(&builder, zeros, {});
}
2392
// Atan2 over every (y, x) pair drawn from two F32 sample lists that include
// signed zeros and infinities. No explicit expected values are given; the
// result is checked through ComputeAndCompare with error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();

  const std::vector<float> y_samples = {+0.0f, -0.0f, inf,   -inf, 5.0f,
                                        -3.0f, 2.0f,  -8.0f, 1.0f};
  const std::vector<float> x_samples = {+0.0f, -0.0f, inf,  -inf,
                                        6.0f,  -4.0f, 2.0f, 8.0f};

  // Flatten the cross product: y varies slowest, x varies fastest.
  const size_t x_count = x_samples.size();
  const size_t pair_count = y_samples.size() * x_count;
  std::vector<float> ys;
  std::vector<float> xs;
  ys.reserve(pair_count);
  xs.reserve(pair_count);
  for (size_t k = 0; k < pair_count; ++k) {
    ys.push_back(y_samples[k / x_count]);
    xs.push_back(x_samples[k % x_count]);
  }

  auto y = ConstantR1<float>(&builder, ys);
  auto x = ConstantR1<float>(&builder, xs);
  Atan2(y, x);

  ComputeAndCompare(&builder, {}, error_spec_);
}
2415
XLA_TEST_F(ArrayElementwiseOpTest, Atan2C64s) {
  // Same cross product of (y, x) candidates as the F32 variant, but each value
  // is stored as a std::complex<float> built from the real candidate alone.
  // The check is delegated to ComputeAndCompare with error_spec_.
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  const std::vector<float> y_candidates = {+0.0f, -0.0f, inf,  -inf, 5.0f,
                                           -3.0f, 2.0f,  -8.0f, 1.0f};
  const std::vector<float> x_candidates = {+0.0f, -0.0f, inf,  -inf,
                                           6.0f,  -4.0f, 2.0f, 8.0f};

  std::vector<std::complex<float>> ys;
  std::vector<std::complex<float>> xs;
  const auto pair_count = y_candidates.size() * x_candidates.size();
  ys.reserve(pair_count);
  xs.reserve(pair_count);
  for (const float y_value : y_candidates) {
    for (const float x_value : x_candidates) {
      // Real candidate -> complex with zero imaginary part.
      ys.emplace_back(y_value);
      xs.emplace_back(x_value);
    }
  }

  auto y_op = ConstantR1<std::complex<float>>(&builder, ys);
  auto x_op = ConstantR1<std::complex<float>>(&builder, xs);
  Atan2(y_op, x_op);

  ComputeAndCompare(&builder, {}, error_spec_);
}
2438
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  // Element-wise tanh of a three-element F32 vector, compared against
  // hand-written expected values within error_spec_.
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f});
  Tanh(a);

  // All expected values carry the `f` suffix so the list is uniformly float;
  // the last entry previously was a double literal narrowed implicitly.
  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
                             error_spec_);
}
2447
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  // Like ArrayElementwiseOpTest.TanhF32s, but the operand is a 64-element
  // parameter so that it is large enough to exercise the vectorized tanh
  // implementation on XLA CPU.
  XlaBuilder builder(TestName());
  const std::vector<float> inputs = {
      1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
      -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
      -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
      0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
      0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
      0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
      -0.79, 1.41,  1.21,  1.05};
  const std::vector<float> expected_tanh = {
      0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
      -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
      -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
      0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
      -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
      -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
      0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
      0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
      0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
      -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
      0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
      -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
      -0.65565848, 0.88789743,  0.83566397,  0.78287679};

  auto input_literal = LiteralUtil::CreateR1<float>(inputs);
  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
                          client_->TransferToServer(input_literal));

  Tanh(Parameter(&builder, 0, input_literal.shape(), "input"));

  // The error spec is unusually high here to account for the fact that a
  // rational interpolant is used to approximate tanh.
  const ErrorSpec loose_tolerance(0.004, 0.004);
  ComputeAndCompareR1<float>(&builder, expected_tanh, {input_data.get()},
                             loose_tolerance);
}
2487
XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  // The operand is large enough to exercise the vectorized exp implementation
  // on XLA CPU. Expected values are computed on the host with std::exp.
  XlaBuilder builder(TestName());

  // For a sense of scale: exp(89) saturates float32 and exp(-10) is smaller
  // than our error spec.
  const std::vector<float> inputs = {
      1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49,  0.8,    -1.31,
      -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,   -195.6, -194.5,
      -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6,  -18.5,  -17.4,
      -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,   -8.6,   -7.5,
      -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,    1.3,    2.4,
      3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,   11.2,   12.3,
      13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,   21.1,   22.2,
      23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,   31.0,   32.1,
      68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,   76.1,   77.2,
      78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,   85.2,   86.3,
      86.4,   86.5,   87.6,   87.7,   87.8,   87.9};
  Literal input_literal = LiteralUtil::CreateR1<float>(inputs);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  Exp(Parameter(&builder, 0, input_literal.shape(), "input"));

  // Host-side reference: one std::exp(float) per input element.
  std::vector<float> expected_result;
  expected_result.reserve(inputs.size());
  for (const float element : inputs) {
    expected_result.push_back(std::exp(element));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}
2523
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  // The operand is large enough to exercise the vectorized log implementation
  // on XLA CPU. Expected values are computed on the host with std::log.
  // NOTE(review): the previous wording of this comment said "exp", apparently
  // carried over from ExpF32sVector -- confirm "log" is what was meant.
  XlaBuilder builder(TestName());

  // The leading entries are negative, so the host reference for them is
  // whatever std::log(float) returns for a negative argument.
  const std::vector<float> inputs = {
      -1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
      -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
      198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
      1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
      1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
      1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
      1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
      1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
      2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
      1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
      1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
      1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
      1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
      1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
      1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35};
  Literal input_literal = LiteralUtil::CreateR1<float>(inputs);
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  Log(Parameter(&builder, 0, input_literal.shape(), "input"));

  // Host-side reference: one std::log(float) per input element.
  std::vector<float> expected_result;
  expected_result.reserve(inputs.size());
  for (const float element : inputs) {
    expected_result.push_back(std::log(element));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}
2561
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
  // Count-leading-zeros on 32-bit unsigned values; zero is expected to report
  // all 32 bits and a value with the top bit set is expected to report 0.
  const std::vector<uint32_t> inputs = {0,        1,          0x10,      0x10000,
                                        0x700000, 0x12345678, 0xF2345678};
  const std::vector<uint32_t> leading_zeros = {32, 31, 27, 15, 9, 3, 0};

  XlaBuilder builder(TestName());
  Clz(ConstantR1<uint32_t>(&builder, inputs));

  ComputeAndCompareR1<uint32_t>(&builder, leading_zeros, {});
}
2570
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  // Count-leading-zeros on 64-bit signed values; zero is expected to report
  // all 64 bits and -1 (every bit set) is expected to report 0.
  const std::vector<int64_t> inputs = {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul,
                                       -1};
  const std::vector<int64_t> leading_zeros = {64, 63, 32, 1, 0};

  XlaBuilder builder(TestName());
  Clz(ConstantR1<int64_t>(&builder, inputs));

  ComputeAndCompareR1<int64_t>(&builder, leading_zeros, {});
}
2579
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // Left-leaning chain of two adds: (a + b) + c.
  //
  //   a ------ (add) --------- (add)
  //          /               /
  //   b ----/               /
  //   c -------------------/
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto mid = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto rhs = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  Add(Add(lhs, mid), rhs);

  const std::vector<float> expected_sum = {-0.1f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2597
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // Right-leaning chain of two adds: a + (b + c).
  //
  //   b ------ (add) --------- (add)
  //          /               /
  //   c ----/               /
  //   a -------------------/
  XlaBuilder builder(TestName());

  auto outer = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto inner_lhs = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto inner_rhs = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  Add(outer, Add(inner_lhs, inner_rhs));

  const std::vector<float> expected_sum = {89.9f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2615
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // Adds two negated operands: (-a) + (-b).
  //
  //   a ----- (neg) ----- (add)
  //                      /
  //   b ----- (neg) ----/
  XlaBuilder builder(TestName());

  auto first = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto second = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});

  Add(Neg(first), Neg(second));

  const std::vector<float> expected_sum = {-93.2f, -5.4f, -7.6f, -9.8f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2632
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // Balanced tree of three adds: (a + b) + (c + d).
  //
  //   a --.
  //        (add) --.
  //   b --'         |
  //                 (add)
  //   c --.         |
  //        (add) --'
  //   d --'
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  auto left_sum = Add(a, b);
  auto right_sum = Add(c, d);
  Add(left_sum, right_sum);

  const std::vector<float> expected_sum = {70.9f, -0.1f, -40.1f, 0.1f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2655
2656 XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
2657 XlaBuilder builder(TestName());
2658 auto a = ConstantR2<float>(&builder,
2659 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2660 auto b = ConstantR2<float>(&builder,
2661 {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2662 Add(a, b);
2663
2664 Array2D<float> expected_array(
2665 {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2666 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2667 }
2668
XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // scalar + matrix: the scalar is the left-hand operand of Add.
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto three = ConstantR0<float>(&builder, 3.0f);
  Add(three, matrix);

  // Every matrix entry shifted by 3.
  const Array2D<float> expected_sum(
      {{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2680
2681 XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
2682 // Add a matrix + scalar.
2683 XlaBuilder builder(TestName());
2684 auto a = ConstantR2<float>(&builder,
2685 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2686 auto scalar = ConstantR0<float>(&builder, 3.0f);
2687 Add(a, scalar);
2688
2689 Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2690 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2691 }
2692
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Simple broadcast of an R1 F32 over an R2 F32. The 3-element vector is
  // mapped onto matrix dimension 1 (broadcast_dimensions={1}); its length
  // matches only the 3-wide rows of the matrix, not the row count of 2.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
  // clang-format off
  auto mat = ConstantR2<float>(&builder, {
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  Add(vec, mat, /*broadcast_dimensions=*/{1});

  // The vector is added to each row.
  const Array2D<float> expected_sum(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2708
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Broadcasting in an Eq comparison between a 2-element vector and a 2x2
  // matrix. Both possible broadcast dimensions are exercised and returned
  // together as a tuple.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32_t>(&builder, {42, 73});
  auto mat = ConstantR2<int32_t>(&builder, {{42, 73}, {42, 52}});

  auto eq_over_dim1 = Eq(vec, mat, /*broadcast_dimensions=*/{1});
  auto eq_over_dim0 = Eq(vec, mat, /*broadcast_dimensions=*/{0});
  Tuple(&builder, {eq_over_dim1, eq_over_dim0});

  // Expected element order mirrors the tuple above.
  Literal expected_over_dim1 =
      LiteralUtil::CreateR2<bool>({{true, true}, {true, false}});
  Literal expected_over_dim0 =
      LiteralUtil::CreateR2<bool>({{true, false}, {false, false}});
  auto expected_tuple = LiteralUtil::MakeTupleFromSlices(
      {expected_over_dim1, expected_over_dim0});
  ComputeAndCompareTuple(&builder, expected_tuple, {}, error_spec_);
}
2726
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Broadcasting in an Ne comparison: a 2-element vector against a 2x2 matrix
  // with broadcast_dimensions={1}. The result is checked through its textual
  // dump.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32_t>(&builder, {42, 73});
  auto mat = ConstantR2<int32_t>(&builder, {{42, 73}, {42, 52}});
  Ne(vec, mat, /*broadcast_dimensions=*/{1});

  // NOTE(review): the comparison is whitespace-sensitive; rows are assumed to
  // carry the two-space indent of the literal dump -- confirm.
  const std::string expected_dump =
      "pred[2,2] {\n"
      "  { 0, 0 },\n"
      "  { 0, 1 }\n"
      "}";
  EXPECT_EQ(expected_dump, ExecuteToString(&builder, {}));
}
2740
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Broadcasting in a Ge comparison: a 4-element vector against a 2x4 matrix
  // with broadcast_dimensions={1}. The result is checked through its textual
  // dump.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32_t>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(vec, mat, /*broadcast_dimensions=*/{1});

  // NOTE(review): the comparison is whitespace-sensitive; rows are assumed to
  // carry the two-space indent of the literal dump -- confirm.
  const std::string expected_dump =
      "pred[2,4] {\n"
      "  { 1, 1, 0, 0 },\n"
      "  { 0, 0, 0, 1 }\n"
      "}";
  EXPECT_EQ(expected_dump, ExecuteToString(&builder, {}));
}
2754
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Broadcasting in a Gt comparison: a 4-element vector against a 2x4 matrix
  // with broadcast_dimensions={1}. The result is checked through its textual
  // dump.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32_t>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(vec, mat, /*broadcast_dimensions=*/{1});

  // NOTE(review): the comparison is whitespace-sensitive; rows are assumed to
  // carry the two-space indent of the literal dump -- confirm.
  const std::string expected_dump =
      "pred[2,4] {\n"
      "  { 0, 1, 0, 0 },\n"
      "  { 0, 0, 0, 0 }\n"
      "}";
  EXPECT_EQ(expected_dump, ExecuteToString(&builder, {}));
}
2768
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Broadcasting in an Le comparison: a 4-element vector against a 2x4 matrix
  // with broadcast_dimensions={1}. The result is checked through its textual
  // dump.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32_t>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(vec, mat, /*broadcast_dimensions=*/{1});

  // NOTE(review): the comparison is whitespace-sensitive; rows are assumed to
  // carry the two-space indent of the literal dump -- confirm.
  const std::string expected_dump =
      "pred[2,4] {\n"
      "  { 1, 0, 1, 1 },\n"
      "  { 1, 1, 1, 1 }\n"
      "}";
  EXPECT_EQ(expected_dump, ExecuteToString(&builder, {}));
}
2782
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Broadcasting in an Lt comparison: a 4-element vector against a 2x4 matrix
  // with broadcast_dimensions={1}. The result is checked through its textual
  // dump.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32_t>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(vec, mat, /*broadcast_dimensions=*/{1});

  // NOTE(review): the comparison is whitespace-sensitive; rows are assumed to
  // carry the two-space indent of the literal dump -- confirm.
  const std::string expected_dump =
      "pred[2,4] {\n"
      "  { 0, 0, 1, 1 },\n"
      "  { 1, 1, 1, 0 }\n"
      "}";
  EXPECT_EQ(expected_dump, ExecuteToString(&builder, {}));
}
2796
XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Simple broadcast of an R1 F32 over an R2 F32 with the binary-op arguments
  // in matrix-first order (the vector is the right-hand operand).
  XlaBuilder builder(TestName());
  auto mat =
      ConstantR2<float>(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto vec = ConstantR1<float>(&builder, {2.0f, 4.0f, 6.0f});
  Mul(mat, vec, /*broadcast_dimensions=*/{1});

  // Each row is scaled column-by-column by the vector.
  const Array2D<float> expected_product(
      {{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&builder, expected_product, {}, error_spec_);
}
2808
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // `full` is written as 2 rows of 3 values; `single_row` as 1 row of 3
  // values. No broadcast_dimensions are passed; the single row is expected to
  // be added to each row of `full`, giving 2 rows of 3 values.
  auto full = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto single_row = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
  Add(full, single_row);

  const Array2D<float> expected_sum(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2823
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // `full` is written as 2 rows of 3 values; `single_column` as 2 rows of 1
  // value. No broadcast_dimensions are passed; each row's lone value is
  // expected to be added across the matching row of `full`.
  auto full = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto single_column = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
  Add(full, single_column);

  const Array2D<float> expected_sum(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2838
XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Broadcasting where both operands are degenerate, which effectively
  // produces an "outer product"-style operation. Taken from the Numpy docs
  // example at:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  XlaBuilder builder(TestName());
  // `column` is written as 4 rows of 1 value, `row` as 1 row of 3 values; the
  // expected result has 4 rows of 3 values.
  auto column =
      ConstantR2<float>(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto row = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
  Add(column, row);

  const Array2D<float> expected_sum({{1.0f, 2.0f, 3.0f},
                                     {11.0f, 12.0f, 13.0f},
                                     {21.0f, 22.0f, 23.0f},
                                     {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2857
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Adds a 2-element vector to a 2x2 matrix with broadcast_dimensions={1}
  // (these shapes can be broadcast in two ways; this test covers one of them).
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto mat = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, mat, /*broadcast_dimensions=*/{1});

  // The vector is added along each row: vec[j] + mat[i][j].
  const Array2D<float> expected_sum({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2868
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Adds a 2-element vector to a 2x2 matrix with broadcast_dimensions={0}
  // (these shapes can be broadcast in two ways; this test covers one of them).
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto mat = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, mat, /*broadcast_dimensions=*/{0});

  // The vector is added down each column: vec[i] + mat[i][j].
  const Array2D<float> expected_sum({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2879
2880 XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
2881 // Binary add of two R3s together
2882 XlaBuilder builder(TestName());
2883 Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2884 {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2885 auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
2886
2887 Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
2888 {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
2889 auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
2890 Add(a, b);
2891
2892 Array3D<float> expected_3d(
2893 {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
2894 {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
2895 ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2896 }
2897
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Adds a 2-element vector to a (2, 3, 2) array with broadcast_dimensions={2}
  // (these shapes can be broadcast in two ways; this test covers one of them).
  XlaBuilder builder(TestName());
  // clang-format off
  const Array3D<float> cube_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  auto vec = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(cube, vec, /*broadcast_dimensions=*/{2});

  // The vector is added along each innermost pair: +10 to the first element,
  // +20 to the second.
  const Array3D<float> expected_sum(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2921
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Adds a 2-element vector to a (2, 3, 2) array with broadcast_dimensions={0}
  // (these shapes can be broadcast in two ways; this test covers one of them).
  XlaBuilder builder(TestName());
  // clang-format off
  const Array3D<float> cube_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  auto vec = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(cube, vec, /*broadcast_dimensions=*/{0});

  // The vector is added per outermost slice: +10 to the whole first slice,
  // +20 to the whole second slice.
  // clang-format off
  const Array3D<float> expected_sum({
    {{11.0f, 12.0f},
     {13.0f, 14.0f},
     {15.0f, 16.0f}},
    {{27.0f, 28.0f},
     {29.0f, 30.0f},
     {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2952
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Adds a matrix written as 2 rows of 3 values to a (2, 3, 2) array, with
  // broadcast_dimensions={0, 1}: each matrix entry is added to both elements
  // of the corresponding innermost pair.
  XlaBuilder builder(TestName());
  // clang-format off
  const Array3D<float> cube_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  auto mat = ConstantR2<float>(&builder, {
    {10.0f, 20.0f, 30.0f},
    {40.0f, 50.0f, 60.0f},
  });
  Add(cube, mat, /*broadcast_dimensions=*/{0, 1});

  const Array3D<float> expected_sum({
    {{11.0f, 12.0f},
     {23.0f, 24.0f},
     {35.0f, 36.0f}},
    {{47.0f, 48.0f},
     {59.0f, 60.0f},
     {71.0f, 72.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2984
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Gt comparison between two R3 arrays of compatible shapes, one of which has
  // a degenerate (size == 1) dimension: `a_3d` is written as 2 slices of 3x2
  // values and `b_3d` as a single 3x2 slice. The result is a pred[2,3,2]
  // array, checked through its textual dump.
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);

  Gt(a, b);

  // The unused Array3D<int> copy of the expected values was dropped; only the
  // string below is compared.
  // NOTE(review): the comparison is whitespace-sensitive; innermost rows are
  // assumed to carry the two-space indent of the literal dump -- confirm.
  const std::string expected = R"(pred[2,3,2] {
{
  { 0, 1 },
  { 0, 0 },
  { 0, 0 }
},
{
  { 0, 1 },
  { 1, 0 },
  { 0, 1 }
}
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
3014
3015 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
3016 XlaBuilder builder(TestName());
3017
3018 std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
3019 std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
3020 std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
3021 float value = 0.0;
3022 for (int64_t p = 0; p < 2; ++p) {
3023 for (int64_t z = 0; z < 3; ++z) {
3024 for (int64_t y = 0; y < 4; ++y) {
3025 for (int64_t x = 0; x < 5; ++x) {
3026 (*operand_a_4d)(p, z, y, x) = value;
3027 (*operand_b_4d)(p, z, y, x) = 2.0 * value;
3028 (*expected_4d)(p, z, y, x) = 3.0 * value;
3029 value += 0.1;
3030 }
3031 }
3032 }
3033 }
3034
3035 auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
3036 auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d);
3037 Add(a, b);
3038
3039 ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
3040 }
3041
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
  // Adds a 3-element vector {1, 2, 3} to a 2x3x4x5 array along dimension 1
  // (broadcast_dimensions={1}); the expected value at (p, z, y, x) is the
  // operand value plus operand_b_1d[z].
  XlaBuilder builder(TestName());

  // Plain Array4D locals, as in R4_16x16x2x2_Plus_R1_16; no separate heap
  // ownership via unique_ptr/new is needed for these test-scoped arrays.
  Array4D<float> operand_a_4d(2, 3, 4, 5);
  Array4D<float> expected_4d(2, 3, 4, 5);
  std::vector<float> operand_b_1d(3);
  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);

  float value = 0.0;
  for (int64_t p = 0; p < 2; ++p) {
    for (int64_t z = 0; z < 3; ++z) {
      for (int64_t y = 0; y < 4; ++y) {
        for (int64_t x = 0; x < 5; ++x) {
          operand_a_4d(p, z, y, x) = value;
          expected_4d(p, z, y, x) = value + operand_b_1d[z];
          value += 0.1;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR1<float>(&builder, operand_b_1d);
  Add(a, b, {1});

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}
3069
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
  // Adds a 16-element vector {1, 2, ..., 16} along dimension 1 of a 16x16x2x2
  // array of ones whose literal is created with the explicit layout
  // {0, 1, 2, 3}.
  constexpr int kDim0 = 16;
  constexpr int kDim1 = 16;
  constexpr int kDim2 = 2;
  constexpr int kDim3 = 2;
  Array4D<float> values(kDim0, kDim1, kDim2, kDim3);
  values.Fill(1.0);
  std::vector<float> addend(kDim1);
  std::iota(addend.begin(), addend.end(), 1.0);

  XlaBuilder builder(TestName());
  // The operand literal is captured here, before `values` is turned into the
  // expected result below.
  Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      values, LayoutUtil::MakeLayout({0, 1, 2, 3}));
  auto operand = ConstantLiteral(&builder, operand_literal);
  auto addend_op = ConstantR1<float>(&builder, addend);
  Add(operand, addend_op, {1});

  // Reuse `values` as the expected output by applying the same add on the
  // host: element (a, b, c, d) gains addend[b].
  for (int a = 0; a < kDim0; ++a) {
    for (int b = 0; b < kDim1; ++b) {
      for (int c = 0; c < kDim2; ++c) {
        for (int d = 0; d < kDim3; ++d) {
          values(a, b, c, d) += addend[b];
        }
      }
    }
  }
  ComputeAndCompareR4<float>(&builder, values, {}, error_spec_);
}
3098
3099 // Show that we can't add two opaques.
XLA_TEST_F(ArrayElementwiseOpTest,CannotAddOpaques)3100 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
3101 XlaBuilder builder(TestName());
3102 auto shape = ShapeUtil::MakeOpaqueShape();
3103 auto x = Parameter(&builder, 0, shape, "x");
3104 Add(x, x);
3105 auto computation_status = builder.Build();
3106 ASSERT_FALSE(computation_status.ok());
3107 EXPECT_THAT(computation_status.status().ToString(),
3108 ::testing::ContainsRegex(
3109 "Expected array argument for lhs of binary operation"));
3110 }
3111
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
  // Same-rank operands with the identity mapping broadcast_dimensions={0, 1}:
  // the add is expected to behave like a plain element-wise add.
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<float>(&builder,
                               {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(&builder,
                               {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});

  const Array2D<float> expected_sum(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
3124
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
  // Same-rank operands with the permuted mapping broadcast_dimensions={1, 0}:
  // building the computation is expected to fail with an "identity" error.
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<float>(&builder,
                               {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(&builder,
                               {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 0});

  auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              ::testing::ContainsRegex("must.*be the identity"));
}
3138
3139 // Regression test for b/31927799. "slice - y" is fused and requires implicit
3140 // broadcast.
XLA_TEST_F(ArrayElementwiseOpTest,ImplicitBroadcastInFusedExpressions)3141 XLA_TEST_F(ArrayElementwiseOpTest, ImplicitBroadcastInFusedExpressions) {
3142 XlaBuilder builder(TestName());
3143 auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
3144 auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
3145 auto x_data = client_->TransferToServer(x_literal).value();
3146 auto y_data = client_->TransferToServer(y_literal).value();
3147
3148 auto x = Parameter(&builder, 0, x_literal.shape(), "x");
3149 auto y = Parameter(&builder, 1, y_literal.shape(), "y");
3150 auto slice = Slice(x, {1}, {2}, {1});
3151 Sub(slice, y);
3152
3153 ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
3154 error_spec_);
3155 }
3156
// Instantiates the int-parameterized ArrayElementwiseOpTestParamCount suite
// with the values 127, 128, 129 and 17 * 4096.
// NOTE(review): the tests that consume GetParam() are not visible in this part
// of the file; the values look like element counts around a 128 boundary plus
// one large size -- confirm against those tests.
INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
                        ArrayElementwiseOpTestParamCount,
                        ::testing::Values(127, 128, 129, 17 * 4096));
3160
3161 } // namespace
3162 } // namespace xla
3163