xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <functional>
18 #include <limits>
19 #include <memory>
20 #include <numeric>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/base/casts.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/client/global_data.h"
31 #include "tensorflow/compiler/xla/client/local_client.h"
32 #include "tensorflow/compiler/xla/client/xla_builder.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/literal.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/test.h"
37 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
38 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
39 #include "tensorflow/compiler/xla/tests/test_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 
42 namespace xla {
43 namespace {
44 
45 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
46  public:
47   ErrorSpec error_spec_{0.0001, 0.0001};
48   ErrorSpec strict_error_spec_{3.6e-15, 3.6e-15};
49 };
50 
51 class ArrayElementwiseOpTestParamCount
52     : public ArrayElementwiseOpTest,
53       public ::testing::WithParamInterface<int> {};
54 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
  // Negating an empty F32 vector must produce an empty vector.
  XlaBuilder b(TestName());
  Neg(ConstantR1<float>(&b, {}));
  ComputeAndCompareR1<float>(&b, {}, {}, error_spec_);
}
62 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
  // Element-wise negation of an F32 vector with mixed signs.
  const std::vector<float> input = {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f};
  const std::vector<float> negated = {2.5f, -3.14f, -2.25f, 10.0f, -6.0f};

  XlaBuilder b(TestName());
  Neg(ConstantR1<float>(&b, input));
  ComputeAndCompareR1<float>(&b, negated, {}, error_spec_);
}
71 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF64) {
  // F64 negation. No expected literal is written out; ComputeAndCompare
  // derives the expectation itself (presumably from a reference backend) and
  // checks it with the strict double-precision tolerance.
  const std::vector<double> input = {-2.5, 3.14, 2.25, -10.0, 6.0};

  XlaBuilder b(TestName());
  Neg(ConstantR1<double>(&b, input));
  ComputeAndCompare(&b, {}, strict_error_spec_);
}
79 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();

  XlaBuilder b(TestName());
  Neg(ConstantR1<int32_t>(&b, {-1, 0, 1, 324, kMin, kMax}));

  // Negating INT32_MIN overflows and wraps back to INT32_MIN. That computation
  // is undefined behavior in C++, but XLA has not declared it undefined, so the
  // wrapped value is what we expect here.
  ComputeAndCompareR1<int32_t>(&b, {1, 0, -1, -324, kMin, -kMax}, {});
}
96 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
  // Negating an empty C64 vector must produce an empty vector.
  XlaBuilder b(TestName());
  Neg(ConstantR1<complex64>(&b, {}));
  ComputeAndCompareR1<complex64>(&b, {}, {}, error_spec_);
}
104 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
  // Complex negation flips the sign of both the real and imaginary parts.
  const std::vector<complex64> input = {
      {-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}};
  const std::vector<complex64> negated = {
      {2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}};

  XlaBuilder b(TestName());
  Neg(ConstantR1<complex64>(&b, input));
  ComputeAndCompareR1<complex64>(&b, negated, {}, error_spec_);
}
115 
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
  // S64 negation, including values with the high 32 bits set and the two
  // values nearest INT64_MIN.
  constexpr int64_t kMin = std::numeric_limits<int64_t>::min();
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();

  XlaBuilder builder(TestName());
  auto a = ConstantR1<int64_t>(&builder,
                               {
                                   -1,
                                   1,
                                   0,
                                   0x12345678,
                                   // High 32 bits all set; equals -0xEDCBA988.
                                   static_cast<int64_t>(0xffffffff12345678ULL),
                                   kMin,
                                   kMin + 1,
                               });
  Neg(a);

  // -INT64_MIN overflows and wraps back to INT64_MIN (two's complement);
  // -(INT64_MIN + 1) is INT64_MAX.
  ComputeAndCompareR1<int64_t>(&builder,
                               {
                                   1,
                                   -1,
                                   0,
                                   -0x12345678,
                                   0xedcba988,
                                   kMin,
                                   kMax,
                               },
                               {});
}
143 
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
  // IsFinite on an empty F32 vector produces an empty PRED vector.
  XlaBuilder b(TestName());
  IsFinite(ConstantR1<float>(&b, {}));
  ComputeAndCompareR1<bool>(&b, {}, {});
}
151 
XLA_TEST_F(ArrayElementwiseOpTest, IntPow) {
  // S32 integer power. Cases cover a zero exponent, negative bases, and
  // negative exponents (3^-100 truncates to 0, 1^-2 stays 1).
  const std::vector<int32_t> bases = {0, 1, 2, 3, 4, 5, -1, -2, 3, 5, 3, 1};
  const std::vector<int32_t> exponents = {0, 3, 3, 3,  3,    3,
                                          2, 3, 2, 10, -100, -2};
  const std::vector<int32_t> powers = {1, 1,  8, 27,      64, 125,
                                       1, -8, 9, 9765625, 0,  1};

  XlaBuilder b(TestName());
  XlaOp base_op = ConstantR1<int32_t>(&b, bases);
  XlaOp exponent_op = ConstantR1<int32_t>(&b, exponents);
  Pow(base_op, exponent_op);

  ComputeAndCompareR1<int32_t>(&b, powers, {});
}
165 
XLA_TEST_F(ArrayElementwiseOpTest, IntPowLarge) {
  // 2^62 is the largest power of two that fits in a positive int64_t.
  XlaBuilder b(TestName());
  XlaOp base = ConstantR1<int64_t>(&b, {2});
  XlaOp exponent = ConstantR1<int64_t>(&b, {62});
  Pow(base, exponent);

  const std::vector<int64_t> two_to_the_62 = {int64_t{1} << 62};
  ComputeAndCompareR1<int64_t>(&b, two_to_the_62, {});
}
176 
177 // A non-canonical quiet NaN value.
178 static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
179 
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
  const float inf = std::numeric_limits<float>::infinity();
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  // (input, expected IsFinite result). A single builder is reused across all
  // cases, matching a straight-line sequence of IsFinite + compare calls.
  const std::vector<std::pair<float, bool>> cases = {
      {NAN, false}, {kNonCanonicalNaN, false}, {inf, false},
      {-inf, false}, {0.0f, true}};

  XlaBuilder builder(TestName());
  for (const std::pair<float, bool>& c : cases) {
    IsFinite(ConstantR0<float>(&builder, c.first));
    ComputeAndCompareR0<bool>(&builder, c.second, {});
  }
}
199 
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
  // Vector IsFinite: only the ordinary finite values (7, -1) are finite.
  const float inf = std::numeric_limits<float>::infinity();
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  const std::vector<float> values = {NAN, 7.0f, kNonCanonicalNaN,
                                     -1.0f, inf, -inf};
  XlaBuilder builder(TestName());
  IsFinite(ConstantR1<float>(&builder, values));

  ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
                            {});
}
211 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
  // Element-wise F32 addition of two constant vectors.
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<float>(&b, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  XlaOp rhs = ConstantR1<float>(&b, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Add(lhs, rhs);

  const std::vector<float> sums = {97.5f, 6.27f, 5.0f, 0.5f, -993.0f};
  ComputeAndCompareR1<float>(&b, sums, {}, error_spec_);
}
221 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
  // Adding two empty F32 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<float>(&b, {});
  XlaOp rhs = ConstantR1<float>(&b, {});
  Add(lhs, rhs);
  ComputeAndCompareR1<float>(&b, {}, {}, error_spec_);
}
230 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
  // Complex addition is component-wise on the real and imaginary parts.
  const std::vector<complex64> lhs_values = {
      {-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}};
  const std::vector<complex64> rhs_values = {
      {100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}};
  const std::vector<complex64> sums = {
      {97.5f, 0.0f}, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<complex64>(&b, lhs_values);
  XlaOp rhs = ConstantR1<complex64>(&b, rhs_values);
  Add(lhs, rhs);

  ComputeAndCompareR1<complex64>(&b, sums, {}, error_spec_);
}
243 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
  // Adding two empty C64 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<complex64>(&b, {});
  XlaOp rhs = ConstantR1<complex64>(&b, {});
  Add(lhs, rhs);
  ComputeAndCompareR1<complex64>(&b, {}, {}, error_spec_);
}
252 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
  // Adds two U64 parameters. The pairs include sums that carry out of the low
  // 32 bits and sums that wrap past 2^64.
  XlaBuilder b(TestName());

  std::vector<uint64_t> lhs{0xFFFFFFFF,
                            static_cast<uint64_t>(-1),
                            0,
                            0,
                            0x7FFFFFFFFFFFFFFFLL,
                            0x7FFFFFFFFFFFFFFLL,
                            0x8000000000000000ULL,
                            0x8000000000000000ULL,
                            1};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64_t>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();

  std::vector<uint64_t> rhs{1,
                            0x7FFFFFFFFFFFFFFLL,
                            0x7FFFFFFFFFFFFFFFLL,
                            0x8000000000000000ULL,
                            0,
                            static_cast<uint64_t>(-1),
                            0,
                            1,
                            0x8000000000000000ULL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64_t>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  Add(lhs_param, rhs_param);

  // The loop below indexes rhs with lhs's indices.
  ASSERT_EQ(lhs.size(), rhs.size());
  std::vector<uint64_t> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    // C++ unsigned arithmetic wraps modulo 2^64, giving the wrapped sums this
    // test expects from XLA.
    expected[i] = lhs[i] + rhs[i];
  }

  ComputeAndCompareR1<uint64_t>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
293 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
  // Subtracts two S64 parameters with operands at or near INT64_MIN and
  // INT64_MAX.
  XlaBuilder b(TestName());

  std::vector<int64_t> lhs{static_cast<int64_t>(0x8000000000000000LL),
                           static_cast<int64_t>(0x8000000000000000LL),
                           -1,
                           0x7FFFFFFFFFFFFFFLL,
                           0x7FFFFFFFFFFFFFFFLL,
                           1,
                           0,
                           -1};
  Literal lhs_literal = LiteralUtil::CreateR1<int64_t>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();

  std::vector<int64_t> rhs{-1,
                           0,
                           static_cast<int64_t>(0x8000000000000000LL),
                           1,
                           0,
                           0x7FFFFFFFFFFFFFFLL,
                           0x7FFFFFFFFFFFFFFFLL,
                           0x7FFFFFFFFFFFFFFFLL};
  Literal rhs_literal = LiteralUtil::CreateR1<int64_t>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  Sub(lhs_param, rhs_param);

  // The loop below indexes rhs with lhs's indices.
  ASSERT_EQ(lhs.size(), rhs.size());
  std::vector<int64_t> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    // Subtract in uint64_t to get two's-complement wraparound without relying
    // on signed overflow, which is undefined behavior in C++. None of the
    // current pairs overflow, so the result matches plain signed subtraction.
    expected[i] = static_cast<int64_t>(static_cast<uint64_t>(lhs[i]) -
                                       static_cast<uint64_t>(rhs[i]));
  }

  ComputeAndCompareR1<int64_t>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
332 
XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
  // Lt on U64 must compare as unsigned: 0x8000000000000000 is larger than
  // 0x7FFFFFFFFFFFFFFF, whereas a signed comparison would treat it as negative.
  const std::vector<uint64_t> lhs = {0x8000000000000000ULL};
  const std::vector<uint64_t> rhs = {0x7FFFFFFFFFFFFFFFULL};

  XlaBuilder b(TestName());
  Literal lhs_literal = LiteralUtil::CreateR1<uint64_t>(lhs);
  XlaOp lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  Literal rhs_literal = LiteralUtil::CreateR1<uint64_t>(rhs);
  XlaOp rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");

  Lt(lhs_param, rhs_param);

  ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}
348 
// Adds every combination of {constant, parameter} x {a, b} for vectors of
// GetParam() elements. The sum of the four partial sums should be
// 4 * (a + b).
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector<float> a_values;
  std::vector<float> b_values;
  a_values.reserve(count);
  b_values.reserve(count);
  for (int i = 0; i < count; ++i) {
    a_values.push_back(i / static_cast<float>(count));
    b_values.push_back(2 * i / static_cast<float>(count + 2));
  }

  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).value();
  auto a_constant = ConstantR1<float>(&builder, a_values);
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).value();
  // b_param describes b's literal (previously a_literal's shape was used).
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "b_param");
  auto b_constant = ConstantR1<float>(&builder, b_values);

  auto sum1 = Add(a_constant, b_param);
  auto sum2 = Add(a_constant, b_constant);
  auto sum3 = Add(a_param, b_param);
  auto sum4 = Add(a_param, b_constant);

  auto sum = Add(sum1, sum2);
  sum = Add(sum, sum3);
  sum = Add(sum, sum4);

  std::vector<float> expected;
  expected.reserve(count);
  for (int i = 0; i < count; ++i) {
    expected.push_back(4 * (a_values[i] + b_values[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
                             error_spec_);
}
391 
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
  XlaBuilder builder(TestName());
  const std::vector<float> values(30, 0.0f);
  // Signed copy of the length so the recursion below compares like types.
  const int64_t num_values = static_cast<int64_t>(values.size());
  auto a_literal = LiteralUtil::CreateR1<float>(values);
  auto a = Parameter(&builder, 0, a_literal.shape(), "a");
  auto b_literal = LiteralUtil::CreateR1<float>(values);
  auto b = Parameter(&builder, 1, b_literal.shape(), "b");

  // Construct a sequence of diamond-shaped gadgets like this:
  //
  //      add
  //    /    \
  //  slice  slice
  //     \   /
  //      add
  //
  // Each 'left' slice removes the last element, each 'right' slice removes the
  // first element. In this way, we index into the add with different
  // multi-dimensional index arrays, which defeats the caching we use to avoid
  // exponential compile time.
  std::function<XlaOp(int64_t)> generate_recursive =
      [&](int64_t slice_size) -> XlaOp {
    if (slice_size == num_values) {
      return Add(a, b);
    }
    XlaOp param = generate_recursive(slice_size + 1);
    auto slice1 = Slice(param, {0}, {slice_size}, {1});
    auto slice2 = Slice(param, {1}, {slice_size + 1}, {1});
    return Add(slice1, slice2);
  };
  // Starting at size 1 leaves a single-element result after the outermost add.
  generate_recursive(1);
  auto a_data = client_->TransferToServer(a_literal).value();
  auto b_data = client_->TransferToServer(b_literal).value();
  ComputeAndCompareR1<float>(&builder, {0.0f}, {a_data.get(), b_data.get()});
}
427 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
  // Element-wise F32 subtraction of two constant vectors.
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR1<float>(&b, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  XlaOp subtrahend =
      ConstantR1<float>(&b, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Sub(minuend, subtrahend);

  const std::vector<float> differences = {-102.5f, 0.01f, -0.5f, -20.5f,
                                          1005.0f};
  ComputeAndCompareR1<float>(&b, differences, {}, error_spec_);
}
437 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
  // Subtracting two empty F32 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR1<float>(&b, {});
  XlaOp subtrahend = ConstantR1<float>(&b, {});
  Sub(minuend, subtrahend);
  ComputeAndCompareR1<float>(&b, {}, {}, error_spec_);
}
446 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
  // Element-wise S32 subtraction, including a result just above 1e9.
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR1<int32_t>(&b, {-1, 0, 2, 1000000000});
  XlaOp subtrahend = ConstantR1<int32_t>(&b, {-1, 2, 1, -1});
  Sub(minuend, subtrahend);

  const std::vector<int32_t> differences = {0, -2, 1, 1000000001};
  ComputeAndCompareR1<int32_t>(&b, differences, {});
}
455 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
  // Subtracting two empty S32 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR1<int32_t>(&b, {});
  XlaOp subtrahend = ConstantR1<int32_t>(&b, {});
  Sub(minuend, subtrahend);
  ComputeAndCompareR1<int32_t>(&b, {}, {});
}
464 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
  // Complex subtraction is component-wise on the real and imaginary parts.
  const std::vector<complex64> minuends = {
      {-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}};
  const std::vector<complex64> subtrahends = {
      {0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}};
  const std::vector<complex64> differences = {
      {-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<complex64>(&b, minuends);
  XlaOp rhs = ConstantR1<complex64>(&b, subtrahends);
  Sub(lhs, rhs);

  ComputeAndCompareR1<complex64>(&b, differences, {}, error_spec_);
}
477 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
  // Subtracting two empty C64 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR1<complex64>(&b, {});
  XlaOp subtrahend = ConstantR1<complex64>(&b, {});
  Sub(minuend, subtrahend);
  ComputeAndCompareR1<complex64>(&b, {}, {}, error_spec_);
}
486 
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF64s) {
  // F64 subtraction, checked by ComputeAndCompare with the strict tolerance
  // (no explicit expected literal is supplied).
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR1<double>(&b, {-2.5, 3.14, 2.25, -10.0, 6.0});
  XlaOp subtrahend = ConstantR1<double>(&b, {100.0, 3.13, 2.75, 10.5, -999.0});
  Sub(minuend, subtrahend);
  ComputeAndCompare(&b, {}, strict_error_spec_);
}
495 
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
  // Element-wise F32 division of two constant vectors.
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp divisor = ConstantR1<float>(&b, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
  Div(dividend, divisor);

  const std::vector<float> quotients = {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f};
  ComputeAndCompareR1<float>(&b, quotients, {}, error_spec_);
}
505 
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
  // Dividing two empty F32 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR1<float>(&b, {});
  XlaOp divisor = ConstantR1<float>(&b, {});
  Div(dividend, divisor);
  ComputeAndCompareR1<float>(&b, {}, {}, error_spec_);
}
514 
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF64s) {
  // F64 division checked with the strict tolerance. M_PI is a POSIX extension,
  // not ISO C++ (MSVC only defines it under _USE_MATH_DEFINES), so the same
  // value is spelled out here.
  constexpr double kPi = 3.14159265358979323846;

  XlaBuilder builder(TestName());
  auto a = ConstantR1<double>(
      &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 1.0, 2.0, 3.2, -4.0, 0.45, 5.7,
                 0.1, 1.0, 2.0, 0.5, -1.0, -0.5, 1.0});
  auto b = ConstantR1<double>(
      &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 0.1, 1.0, 2.0, 0.5, -1.0, -0.5,
                 2.1, 3.1, 9.9, -4.5, -11.0, -21.5, kPi});
  Div(a, b);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
527 
528 class IntegerDivideOpTest : public ArrayElementwiseOpTest {
529  protected:
530   template <typename T>
TestDivRem(absl::Span<const T> dividends,absl::Span<const T> divisors,absl::Span<const T> quotients,absl::Span<const T> remainders)531   void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
532                   absl::Span<const T> quotients,
533                   absl::Span<const T> remainders) {
534     {
535       XlaBuilder builder(TestName());
536       XlaOp dividend;
537       XlaOp divisor;
538       auto dividend_data =
539           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
540       auto divisor_data =
541           CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
542       Div(dividend, divisor);
543 
544       ComputeAndCompareR1<T>(&builder, quotients,
545                              {dividend_data.get(), divisor_data.get()});
546     }
547 
548     // Test with a compile-time constant divisor.
549     {
550       XlaBuilder builder(TestName());
551       XlaOp dividend;
552       auto dividend_data =
553           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
554       Div(dividend, ConstantR1<T>(&builder, divisors));
555 
556       ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()});
557     }
558 
559     {
560       XlaBuilder builder(TestName());
561       XlaOp dividend;
562       XlaOp divisor;
563       auto dividend_data =
564           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
565       auto divisor_data =
566           CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
567       Rem(dividend, divisor);
568 
569       ComputeAndCompareR1<T>(&builder, remainders,
570                              {dividend_data.get(), divisor_data.get()});
571     }
572 
573     // Test with a compile-time constant divisor.
574     {
575       XlaBuilder builder(TestName());
576       XlaOp dividend;
577       auto dividend_data =
578           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
579       Rem(dividend, ConstantR1<T>(&builder, divisors));
580 
581       ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()});
582     }
583   }
584 };
585 
XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
  // clang-format off
  // Boundary values plus a spread of small and large magnitudes of both signs.
  const std::vector<int32_t> vals = {
    INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
    -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
    7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
  // clang-format on

  // Every (dividend, divisor) pair from vals x vals whose C++ result is
  // well-defined, with the expected values computed in C++.
  std::vector<int32_t> dividends, divisors, quotients, remainders;
  for (const int32_t divisor : vals) {
    if (divisor == 0) {
      continue;  // Division by zero is undefined in C++.
    }
    for (const int32_t dividend : vals) {
      if (dividend == INT32_MIN && divisor == -1) {
        continue;  // INT32_MIN / -1 overflows in C++.
      }
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<int32_t>(dividends, divisors, quotients, remainders);
}
612 
XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
  // Cases whose C++ result would be undefined. XLA is expected to produce:
  //   5 / 0          -> quotient -1,        remainder 5 (the dividend)
  //   INT32_MIN / -1 -> quotient INT32_MIN, remainder 0
  const std::vector<int32_t> dividends = {5, INT32_MIN};
  const std::vector<int32_t> divisors = {0, -1};
  const std::vector<int32_t> quotients = {-1, INT32_MIN};
  const std::vector<int32_t> remainders = {5, 0};

  TestDivRem<int32_t>(dividends, divisors, quotients, remainders);
}
619 
XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
  // clang-format off
  // Small values, values around 2^31, and values near UINT32_MAX.
  const std::vector<uint32_t> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  // Every pair from vals x vals with a non-zero divisor; expected values are
  // computed in C++.
  std::vector<uint32_t> dividends, divisors, quotients, remainders;
  for (const uint32_t divisor : vals) {
    if (divisor == 0) {
      continue;  // Division by zero is undefined in C++.
    }
    for (const uint32_t dividend : vals) {
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<uint32_t>(dividends, divisors, quotients, remainders);
}
642 
XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
  // Unsigned division by zero. Previously this test used int32_t, which only
  // duplicated the signed 5 / 0 case. With uint32_t, the quotient is expected
  // to have all bits set (UINT32_MAX) and the remainder to equal the dividend.
  const std::vector<uint32_t> dividends = {5};
  const std::vector<uint32_t> divisors = {0};
  const std::vector<uint32_t> quotients = {
      std::numeric_limits<uint32_t>::max()};
  const std::vector<uint32_t> remainders = {5};

  TestDivRem<uint32_t>(dividends, divisors, quotients, remainders);
}
649 
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
  // Complex division, including division by a purely imaginary value and a
  // value divided by itself.
  const std::vector<complex64> dividends = {
      {-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}};
  const std::vector<complex64> divisors = {
      {10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}};
  const std::vector<complex64> quotients = {
      {-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<complex64>(&b, dividends);
  XlaOp rhs = ConstantR1<complex64>(&b, divisors);
  Div(lhs, rhs);

  ComputeAndCompareR1<complex64>(&b, quotients, {}, error_spec_);
}
661 
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
  // Dividing two empty C64 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR1<complex64>(&b, {});
  XlaOp divisor = ConstantR1<complex64>(&b, {});
  Div(dividend, divisor);
  ComputeAndCompareR1<complex64>(&b, {}, {}, error_spec_);
}
670 
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
  // F32 remainder. The expected results take the sign of the dividend (e.g.
  // -1 rem 7 == -1, 3 rem -2 == 1).
  const std::vector<float> dividends = {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f,
                                        3.0f,  3.0f,  -1.0f, -8.0f};
  const std::vector<float> divisors = {10.0f, 5.1f,  1.0f, 10.0f, -6.0f,
                                       2.0f,  -2.0f, 7.0f, -4.0f};
  const std::vector<float> remainders = {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f,
                                         1.0f,  1.0f, -1.0f, -0.0f};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<float>(&b, dividends);
  XlaOp rhs = ConstantR1<float>(&b, divisors);
  Rem(lhs, rhs);

  ComputeAndCompareR1<float>(&b, remainders, {}, error_spec_);
}
683 
XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
  // Remainder of two empty F32 vectors is an empty vector.
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR1<float>(&b, {});
  XlaOp divisor = ConstantR1<float>(&b, {});
  Rem(dividend, divisor);
  ComputeAndCompareR1<float>(&b, {}, {}, error_spec_);
}
692 
XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
  // F64 remainder with the same cases as the F32 test, compared with the
  // strict tolerance.
  const std::vector<double> dividends = {-2.5, 25.5, 2.25, -10.0, 6.0,
                                         3.0,  3.0,  -1.0, -8.0};
  const std::vector<double> divisors = {10.0, 5.1,  1.0, 10.0, -6.0,
                                        2.0,  -2.0, 7.0, -4.0};
  const std::vector<double> remainders = {-2.5, 0.0, 0.25, 0.0, -0.0,
                                          1.0,  1.0, -1.0, -0.0};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<double>(&b, dividends);
  XlaOp rhs = ConstantR1<double>(&b, divisors);
  Rem(lhs, rhs);

  ComputeAndCompareR1<double>(&b, remainders, {}, strict_error_spec_);
}
705 
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
  // Element-wise F32 multiplication of two constant vectors.
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp rhs = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Mul(lhs, rhs);

  const std::vector<float> products = {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f};
  ComputeAndCompareR1<float>(&b, products, {}, error_spec_);
}
715 
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
  // Multiplying two empty F32 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<float>(&b, {});
  XlaOp rhs = ConstantR1<float>(&b, {});
  Mul(lhs, rhs);
  ComputeAndCompareR1<float>(&b, {}, {}, error_spec_);
}
724 
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
  // S32 multiplication over every pair drawn from `data`, including products
  // that overflow 32 bits and are expected to wrap.
  std::vector<int32_t> data = {0,
                               1,
                               -1,
                               1234,
                               0x1a243514,
                               std::numeric_limits<int32_t>::max(),
                               std::numeric_limits<int32_t>::min()};
  // Form the test data set using all products of 'data' with itself.
  std::vector<int32_t> a_data, b_data, expected;
  a_data.reserve(data.size() * data.size());
  b_data.reserve(data.size() * data.size());
  expected.reserve(data.size() * data.size());
  for (int32_t a : data) {
    for (int32_t b : data) {
      a_data.push_back(a);
      b_data.push_back(b);
      // Multiply in uint32_t to get wraparound without signed-overflow UB,
      // then convert back explicitly instead of via an implicit narrowing.
      expected.push_back(static_cast<int32_t>(static_cast<uint32_t>(a) *
                                              static_cast<uint32_t>(b)));
    }
  }

  XlaBuilder builder(TestName());
  auto a = ConstantR1<int32_t>(&builder, a_data);
  auto b = ConstantR1<int32_t>(&builder, b_data);
  Mul(a, b);

  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
750 
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
  // Multiplying two empty S32 vectors yields an empty vector.
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<int32_t>(&b, {});
  XlaOp rhs = ConstantR1<int32_t>(&b, {});
  Mul(lhs, rhs);
  ComputeAndCompareR1<int32_t>(&b, {}, {});
}
759 
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
  // U32 multiplication over every ordered pair drawn from `data`; products
  // wrap modulo 2^32 in both C++ and the expected XLA result.
  std::vector<uint32_t> data = {0,          1,          0xDEADBEEF, 1234,
                                0x1a243514, 0xFFFFFFFF, 0x80808080};

  // A single flat index walks the cross product in row-major order:
  // i / n selects the left operand, i % n the right one.
  const size_t n = data.size();
  std::vector<uint32_t> a_data, b_data, expected;
  for (size_t i = 0; i < n * n; ++i) {
    const uint32_t left = data[i / n];
    const uint32_t right = data[i % n];
    a_data.push_back(left);
    b_data.push_back(right);
    expected.push_back(left * right);
  }

  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<uint32_t>(&builder, a_data);
  XlaOp rhs = ConstantR1<uint32_t>(&builder, b_data);
  Mul(lhs, rhs);

  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
781 
// Element-wise complex multiplication, (a+bi)(c+di) = (ac-bd) + (ad+bc)i:
//   (-2.5 + 0i)   * (0 + 10i)  =   0   -  25i
//   (0 + 25.5i)   * (5 + 1i)   = -25.5 + 127.5i
//   (2 - 10i)     * (10 - 6i)  = -40   - 112i
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
  auto b = ConstantR1<complex64>(&builder,
                                 {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
  Mul(a, b);

  // All literals are float ('f' suffix) to match the complex64 element type.
  ComputeAndCompareR1<complex64>(
      &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0f}}, {},
      error_spec_);
}
794 
// Multiplying two empty C64 vectors must produce an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<complex64>(&builder, {});
  const auto empty_rhs = ConstantR1<complex64>(&builder, {});
  Mul(empty_lhs, empty_rhs);

  ComputeAndCompareR1<complex64>(&builder, /*expected=*/{}, /*arguments=*/{},
                                 error_spec_);
}
803 
// Logical AND over PRED vectors enumerating the full two-input truth table.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<bool>(&builder, {false, false, true, true});
  const auto y = ConstantR1<bool>(&builder, {false, true, false, true});
  And(x, y);

  // Only the (true, true) pair yields true.
  ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
}
812 
// Logical AND over 2x2 PRED arrays covering all four input combinations.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
  const Array2D<bool> expected({{false, false}, {false, true}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const auto y = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  And(x, y);

  ComputeAndCompareR2<bool>(&builder, expected, {});
}
822 
// AND of two empty PRED vectors must produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<bool>(&builder, {});
  const auto empty_rhs = ConstantR1<bool>(&builder, {});
  And(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
831 
// Bitwise AND on S32 vectors. The negative operands check that two's-complement
// bit patterns combine bit by bit (e.g. -8 & 12 == 8, -1 & -7 == -7).
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(&builder, {0, -1, -8});
  const auto y = ConstantR1<int32_t>(&builder, {5, -7, 12});
  And(x, y);

  ComputeAndCompareR1<int32_t>(&builder, {0, -7, 8}, {});
}
840 
// Bitwise AND on 2x2 S32 arrays, including negative operands (-5 & -6 == -6).
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
  const Array2D<int32_t> expected({{0, -6}, {4, 5}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<int32_t>(&builder, {{0, -5}, {-1, 5}});
  const auto y = ConstantR2<int32_t>(&builder, {{1, -6}, {4, 5}});
  And(x, y);

  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
850 
// AND of two empty S32 vectors must produce an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<int32_t>(&builder, {});
  const auto empty_rhs = ConstantR1<int32_t>(&builder, {});
  And(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}
859 
// Bitwise AND on U32 vectors. The operands and expected values are uint32_t so
// this test exercises the unsigned element type its name promises. The last
// element sets the top bit to confirm it is preserved by AND.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32_t>(&builder, {0, 1, 8, 0xF0F0F0F0u});
  auto b = ConstantR1<uint32_t>(&builder, {5, 7, 12, 0xFFFF0000u});
  And(a, b);

  ComputeAndCompareR1<uint32_t>(&builder, {0, 1, 8, 0xF0F00000u}, {});
}
868 
// Bitwise AND on 2x2 U32 arrays (e.g. 3 & 7 == 3, 8 & 6 == 0).
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
  const Array2D<uint32_t> expected({{0, 0}, {3, 0}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<uint32_t>(&builder, {{0, 1}, {3, 8}});
  const auto y = ConstantR2<uint32_t>(&builder, {{1, 0}, {7, 6}});
  And(x, y);

  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
878 
// AND of two empty U32 vectors must produce an empty U32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<uint32_t>(&builder, {});
  const auto empty_rhs = ConstantR1<uint32_t>(&builder, {});
  And(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}
887 
// Logical OR over PRED vectors enumerating the full two-input truth table.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<bool>(&builder, {false, false, true, true});
  const auto y = ConstantR1<bool>(&builder, {false, true, false, true});
  Or(x, y);

  // Only the (false, false) pair yields false.
  ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
}
896 
// Logical OR over 2x2 PRED arrays covering all four input combinations.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
  const Array2D<bool> expected({{false, true}, {true, true}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const auto y = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Or(x, y);

  ComputeAndCompareR2<bool>(&builder, expected, {});
}
906 
// OR of two empty PRED vectors must produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<bool>(&builder, {});
  const auto empty_rhs = ConstantR1<bool>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
915 
// Bitwise OR on S32 vectors; OR with -1 (all bits set) stays -1.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(&builder, {0, -1, 8});
  const auto y = ConstantR1<int32_t>(&builder, {5, -7, 4});
  Or(x, y);

  ComputeAndCompareR1<int32_t>(&builder, {5, -1, 12}, {});
}
924 
// Bitwise OR on 2x2 S32 arrays, including an all-ones (-1) operand.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
  const Array2D<int32_t> expected({{5, -1}, {12, 9}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<int32_t>(&builder, {{0, -1}, {8, 8}});
  const auto y = ConstantR2<int32_t>(&builder, {{5, -7}, {4, 1}});
  Or(x, y);

  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
934 
// OR of two empty S32 vectors must produce an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<int32_t>(&builder, {});
  const auto empty_rhs = ConstantR1<int32_t>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}
943 
// Bitwise OR on U32 vectors (e.g. 1 | 7 == 7, 8 | 4 == 12).
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<uint32_t>(&builder, {0, 1, 8});
  const auto y = ConstantR1<uint32_t>(&builder, {5, 7, 4});
  Or(x, y);

  ComputeAndCompareR1<uint32_t>(&builder, {5, 7, 12}, {});
}
952 
// Bitwise OR on 2x2 U32 arrays.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
  const Array2D<uint32_t> expected({{5, 7}, {12, 9}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<uint32_t>(&builder, {{0, 1}, {8, 8}});
  const auto y = ConstantR2<uint32_t>(&builder, {{5, 7}, {4, 1}});
  Or(x, y);

  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
962 
// OR of two empty U32 vectors must produce an empty U32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<uint32_t>(&builder, {});
  const auto empty_rhs = ConstantR1<uint32_t>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}
971 
// Logical XOR over PRED vectors enumerating the full two-input truth table.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<bool>(&builder, {false, false, true, true});
  const auto y = ConstantR1<bool>(&builder, {false, true, false, true});
  Xor(x, y);

  // True exactly when the two inputs differ.
  ComputeAndCompareR1<bool>(&builder, {false, true, true, false}, {});
}
980 
// Logical XOR over 2x2 PRED arrays covering all four input combinations.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
  const Array2D<bool> expected({{false, true}, {true, false}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const auto y = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Xor(x, y);

  ComputeAndCompareR2<bool>(&builder, expected, {});
}
990 
// XOR of two empty PRED vectors must produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<bool>(&builder, {});
  const auto empty_rhs = ConstantR1<bool>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
999 
// Bitwise XOR on S32 vectors; XOR with -1 flips every bit (-1 ^ -7 == 6).
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(&builder, {0, -1, 8});
  const auto y = ConstantR1<int32_t>(&builder, {5, -7, 4});
  Xor(x, y);

  ComputeAndCompareR1<int32_t>(&builder, {5, 6, 12}, {});
}
1008 
// Bitwise XOR on 2x2 S32 arrays, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
  const Array2D<int32_t> expected({{5, 6}, {12, 9}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<int32_t>(&builder, {{0, -1}, {8, 8}});
  const auto y = ConstantR2<int32_t>(&builder, {{5, -7}, {4, 1}});
  Xor(x, y);

  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
1018 
// XOR of two empty S32 vectors must produce an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<int32_t>(&builder, {});
  const auto empty_rhs = ConstantR1<int32_t>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1027 
// Bitwise XOR on U32 vectors (e.g. 1 ^ 7 == 6).
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<uint32_t>(&builder, {0, 1, 8});
  const auto y = ConstantR1<uint32_t>(&builder, {5, 7, 4});
  Xor(x, y);

  ComputeAndCompareR1<uint32_t>(&builder, {5, 6, 12}, {});
}
1036 
// Bitwise XOR on 2x2 U32 arrays.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
  const Array2D<uint32_t> expected({{5, 6}, {12, 9}});

  XlaBuilder builder(TestName());
  const auto x = ConstantR2<uint32_t>(&builder, {{0, 1}, {8, 8}});
  const auto y = ConstantR2<uint32_t>(&builder, {{5, 7}, {4, 1}});
  Xor(x, y);

  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
1046 
// XOR of two empty U32 vectors must produce an empty U32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const auto empty_lhs = ConstantR1<uint32_t>(&builder, {});
  const auto empty_rhs = ConstantR1<uint32_t>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}

// Logical NOT flips every element of a PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
  XlaBuilder builder(TestName());
  const auto operand = ConstantR1<bool>(&builder, {false, true, true, false});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
1062 
// Logical NOT flips every element of a 2x2 PRED array.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
  const Array2D<bool> expected({{true, false}, {false, true}});

  XlaBuilder builder(TestName());
  const auto operand =
      ConstantR2<bool>(&builder, {{false, true}, {true, false}});
  Not(operand);

  ComputeAndCompareR2<bool>(&builder, expected, {});
}
1071 
// NOT of an empty PRED vector must produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const auto empty = ConstantR1<bool>(&builder, {});
  Not(empty);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1079 
// Bitwise NOT on S32 vectors: ~x == -x - 1 in two's complement.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
  XlaBuilder builder(TestName());
  const auto operand = ConstantR1<int32_t>(&builder, {-1, 0, 1});
  Not(operand);

  ComputeAndCompareR1<int32_t>(&builder, {0, -1, -2}, {});
}
1087 
// Bitwise NOT on a 2x2 S32 array: ~x == -x - 1 in two's complement.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
  const Array2D<int32_t> expected({{0, -1}, {-2, -9}});

  XlaBuilder builder(TestName());
  const auto operand = ConstantR2<int32_t>(&builder, {{-1, 0}, {1, 8}});
  Not(operand);

  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
1096 
// NOT of an empty S32 vector must produce an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const auto empty = ConstantR1<int32_t>(&builder, {});
  Not(empty);

  ComputeAndCompareR1<int32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1104 
// Bitwise NOT on U32 vectors swaps zero and the all-ones value.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
  XlaBuilder builder(TestName());
  const auto operand = ConstantR1<uint32_t>(&builder, {0, 0xFFFFFFFFu});
  Not(operand);

  ComputeAndCompareR1<uint32_t>(&builder, {0xFFFFFFFFu, 0}, {});
}
1112 
// Bitwise NOT on a 2x2 U32 array, including values near both ends of the range.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  const Array2D<uint32_t> expected({{0xFFFFFFFFu, 0}, {0xFFFFFFFEu, 1}});

  XlaBuilder builder(TestName());
  const auto operand =
      ConstantR2<uint32_t>(&builder, {{0, 0xFFFFFFFFu}, {1, 0xFFFFFFFEu}});
  Not(operand);

  ComputeAndCompareR2<uint32_t>(&builder, expected, {});
}
1121 
// NOT of an empty U32 vector must produce an empty U32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const auto empty = ConstantR1<uint32_t>(&builder, {});
  Not(empty);

  ComputeAndCompareR1<uint32_t>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1129 
// Population count on an S32 vector. -15 is 0xFFFFFFF1 (29 set bits) and
// 341 is 0b101010101 (5 set bits).
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
  XlaBuilder builder(TestName());
  const auto operand = ConstantR1<int32_t>(&builder, {0, 1, -15, 341});
  PopulationCount(operand);

  ComputeAndCompareR1<int32_t>(&builder, {0, 1, 29, 5}, {});
}
1136 
// Population count on a 2x2 S32 array (same values as PopcntR1, reshaped).
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
  const Array2D<int32_t> expected({{0, 1}, {29, 5}});

  XlaBuilder builder(TestName());
  const auto operand = ConstantR2<int32_t>(&builder, {{0, 1}, {-15, 341}});
  PopulationCount(operand);

  ComputeAndCompareR2<int32_t>(&builder, expected, {});
}
1144 
// Population count on S64 values: -1 has all 64 bits set, INT64_MAX has 63.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
  constexpr int64_t kMax = std::numeric_limits<int64_t>::max();
  const Array2D<int64_t> expected({{0, 64}, {63, 62}});

  XlaBuilder builder(TestName());
  const auto operand =
      ConstantR2<int64_t>(&builder, {{0, -1}, {kMax, kMax - 1}});
  PopulationCount(operand);

  ComputeAndCompareR2<int64_t>(&builder, expected, {});
}
1152 
// Left shift on S32. The last three lanes use shift amounts of 32, 100 and -1,
// which the expected values pin to a result of 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  XlaBuilder builder(TestName());
  const auto values = ConstantR1<int32_t>(
      &builder, {static_cast<int32_t>(0x12345678),
                 static_cast<int32_t>(0xF0001000), 1, 3, 77, 1, -3, 77});
  const auto amounts =
      ConstantR1<int32_t>(&builder, {4, 8, 2, 7, 15, 32, 100, -1});
  ShiftLeft(values, amounts);

  ComputeAndCompareR1<int32_t>(
      &builder,
      {static_cast<int32_t>(0x23456780), 0x00100000, 0x4, 0x180, 2523136, 0,
       0, 0},
      {});
}
1166 
// Arithmetic right shift on S32 replicates the sign bit (0x92345678 >> 4 gives
// 0xF9234567). Out-of-range amounts give 0 for non-negative inputs and -1 for
// the negative input (-3 >> 100).
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  XlaBuilder builder(TestName());
  const auto values = ConstantR1<int32_t>(
      &builder, {static_cast<int32_t>(0x92345678),
                 static_cast<int32_t>(0x10001000), 1, 3, 77, 1, -3, 77});
  const auto amounts =
      ConstantR1<int32_t>(&builder, {4, 8, 2, 7, 2, 32, 100, -1});
  ShiftRightArithmetic(values, amounts);

  ComputeAndCompareR1<int32_t>(&builder,
                               {static_cast<int32_t>(0xF9234567),
                                static_cast<int32_t>(0x00100010), 0, 0, 19, 0,
                                -1, 0},
                               {});
}
1181 
// Logical right shift on S32 fills with zeros (0x92345678 >> 4 gives
// 0x09234567). Shift amounts of 32, 100 and -1 all yield 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  XlaBuilder builder(TestName());
  const auto values = ConstantR1<int32_t>(
      &builder, {static_cast<int32_t>(0x92345678),
                 static_cast<int32_t>(0x10001000), 1, 3, 77, 1, -3, 77});
  const auto amounts =
      ConstantR1<int32_t>(&builder, {4, 8, 2, 7, 5, 32, 100, -1});
  ShiftRightLogical(values, amounts);

  ComputeAndCompareR1<int32_t>(
      &builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}
1193 
// Left shift on U32. Shift amounts of 32, 100 and ~0u (all at least the bit
// width) yield 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  XlaBuilder builder(TestName());
  const auto values = ConstantR1<uint32_t>(
      &builder, {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
  const auto amounts =
      ConstantR1<uint32_t>(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u});
  ShiftLeft(values, amounts);

  ComputeAndCompareR1<uint32_t>(
      &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {});
}
1204 
// Arithmetic right shift on U32 still replicates the top bit: 0x92345678 >> 4
// gives 0xF9234567, and ~3u shifted by 100 saturates to ~0u.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  XlaBuilder builder(TestName());
  const auto values = ConstantR1<uint32_t>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  const auto amounts =
      ConstantR1<uint32_t>(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u});
  ShiftRightArithmetic(values, amounts);

  ComputeAndCompareR1<uint32_t>(
      &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {});
}
1215 
// Logical right shift on U32 fills with zeros. Amounts of 32, 100 and ~0u
// yield 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  XlaBuilder builder(TestName());
  const auto values = ConstantR1<uint32_t>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  const auto amounts =
      ConstantR1<uint32_t>(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u});
  ShiftRightLogical(values, amounts);

  ComputeAndCompareR1<uint32_t>(
      &builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}
1226 
// Eq on F32. A NaN on either side makes the comparison false. Fast math is
// turned off because the inputs contain NaNs.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1236 
// Total-order Eq on F32. Unlike plain Eq, NaN compares equal to an identical
// NaN, but NaN still differs from a regular number.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN});
  EqTotalOrder(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, true, false}, {});
}
1246 
// Eq on two empty F32 vectors must produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  XlaBuilder builder(TestName());
  const auto empty_x = ConstantR1<float>(&builder, {});
  const auto empty_y = ConstantR1<float>(&builder, {});
  Eq(empty_x, empty_y);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1255 
// Ge on F32; any NaN operand makes the result false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Ge(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1265 
// Total-order Ge on F32. Positive NaN orders above every number, so
// nan >= 10 is true and 6 >= nan is false. Negative NaN orders below every
// number, so 6 >= -nan is true.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  // The C++ standard does not say whether quiet_NaN() has its sign bit set;
  // std::fabs clears it so 'nan' is a positive NaN on every platform.
  const float nan = std::fabs(std::numeric_limits<float>::quiet_NaN());
  const auto x =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, nan, 6.0f, 6.0f});
  const auto y =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, nan, -nan});
  GeTotalOrder(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true}, {});
}
1281 
// Gt on F32; any NaN operand makes the result false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Gt(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1291 
// Le on F32, including an equal pair (5 <= 5); any NaN operand gives false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<float>(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
  const auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Le(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
}
1301 
// Lt on F32; any NaN operand makes the result false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const auto y = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Lt(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
}
1311 
// Eq on S32 over a 3x3 grid of {INT32_MIN, 0, INT32_MAX} against neighbors
// and extremes.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const auto y =
      ConstantR1<int32_t>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Eq(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1326 
// Eq on two empty S32 vectors must produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  XlaBuilder builder(TestName());
  const auto empty_x = ConstantR1<int32_t>(&builder, {});
  const auto empty_y = ConstantR1<int32_t>(&builder, {});
  Eq(empty_x, empty_y);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1335 
// Eq on C64. Both the real and imaginary parts must match, and a NaN in either
// part makes the comparison false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<complex64>(
      &builder, {{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, {NAN, 0.0f},
                 {1.0f, 6.0f}});
  const auto y = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {1.0f, 5.0f}, {2.25f, -3.0f}, {10.0f, 0.0f},
                 {1.0f, NAN}});
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1353 
// Eq on two empty C64 vectors must produce an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  XlaBuilder builder(TestName());
  const auto empty_x = ConstantR1<complex64>(&builder, {});
  const auto empty_y = ConstantR1<complex64>(&builder, {});
  Eq(empty_x, empty_y);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1362 
// Ne on C64; the element-wise negation of CompareEqC64s, so NaN-bearing pairs
// report "not equal".
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // Fast math must be off: the operands contain NaNs.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  const auto x = ConstantR1<complex64>(
      &builder, {{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, {NAN, 0.0f},
                 {1.0f, 6.0f}});
  const auto y = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {1.0f, 5.0f}, {2.25f, -3.0f}, {10.0f, 0.0f},
                 {1.0f, NAN}});
  Ne(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
}
1382 
// Ne on F32; a NaN on either side makes the pair "not equal".
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // Fast math must be off: the operands contain NaNs.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  const auto x = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const auto y =
      ConstantR1<float>(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
  Ne(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
}
1394 
// Ne on S32 over combinations of INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const auto y =
      ConstantR1<int32_t>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ne(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1408 
// Ge on S32 over combinations of INT32_MIN, 0 and INT32_MAX; checks a signed
// (not unsigned) ordering, e.g. INT32_MIN >= 0 is false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const auto y =
      ConstantR1<int32_t>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ge(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1422 
// Gt on S32 over combinations of INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const auto y =
      ConstantR1<int32_t>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Gt(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}
1437 
// Le on S32 over combinations of INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const auto y =
      ConstantR1<int32_t>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Le(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}
1451 
// Lt on S32 over combinations of INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  XlaBuilder builder(TestName());
  const auto x = ConstantR1<int32_t>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const auto y =
      ConstantR1<int32_t>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Lt(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      {});
}
1466 
// Eq on U32 over combinations of 0, 5 and UINT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  const auto x =
      ConstantR1<uint32_t>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const auto y =
      ConstantR1<uint32_t>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Eq(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1478 
// Ne on U32 over combinations of 0, 5 and UINT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  const auto x =
      ConstantR1<uint32_t>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const auto y =
      ConstantR1<uint32_t>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Ne(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1489 
// Ge on U32; checks an unsigned ordering, where UINT32_MAX is the largest value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  XlaBuilder builder(TestName());
  const auto x =
      ConstantR1<uint32_t>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const auto y =
      ConstantR1<uint32_t>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Ge(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1500 
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
  // Element-wise `lhs > rhs` on uint32 operands that include 0 and the
  // largest uint32 value.
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  const std::vector<uint32_t> lhs_in = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32_t> rhs_in = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<uint32_t>(&b, lhs_in);
  XlaOp rhs = ConstantR1<uint32_t>(&b, rhs_in);
  Gt(lhs, rhs);

  ComputeAndCompareR1<bool>(
      &b, {false, false, false, true, false, false, true, true, false}, {});
}
1512 
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
  // Element-wise `lhs <= rhs` on uint32 operands that include 0 and the
  // largest uint32 value.
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  const std::vector<uint32_t> lhs_in = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32_t> rhs_in = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<uint32_t>(&b, lhs_in);
  XlaOp rhs = ConstantR1<uint32_t>(&b, rhs_in);
  Le(lhs, rhs);

  ComputeAndCompareR1<bool>(
      &b, {true, true, true, false, true, true, false, false, true}, {});
}
1523 
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
  // Element-wise `lhs < rhs` on uint32 operands that include 0 and the
  // largest uint32 value.
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  const std::vector<uint32_t> lhs_in = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32_t> rhs_in = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR1<uint32_t>(&b, lhs_in);
  XlaOp rhs = ConstantR1<uint32_t>(&b, rhs_in);
  Lt(lhs, rhs);

  ComputeAndCompareR1<bool>(
      &b, {false, true, true, false, false, true, false, false, false}, {});
}
1535 
XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
  // Pow with fast-math disabled. Covers 0^0, a negative exponent, NaN in
  // either operand, and a negative base raised to odd and even integral
  // exponents.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const std::vector<float> bases = {0.0f, 4.0f, 2.0f,  2.0f,
                                    NAN,  6.0f, -2.0f, -2.0f};
  const std::vector<float> exponents = {0.0f,  2.0f, -2.0f, 3.0f,
                                        10.0f, NAN,  3.0f,  4.0f};
  const std::vector<float> expected = {1.0f, 16.0f, 0.25f, 8.0f,
                                       NAN,  NAN,   -8.0f, 16.0f};

  XlaOp base = ConstantR1<float>(&b, bases);
  XlaOp exponent = ConstantR1<float>(&b, exponents);
  Pow(base, exponent);

  ComputeAndCompareR1<float>(&b, expected, {}, error_spec_);
}
1549 
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
  // With fast-math disabled: a negative base with a non-integral exponent
  // gives NaN, and 0 raised to a negative exponent gives +infinity.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const std::vector<float> bases = {-2.0f, -0.6f, -0.6f, 0.0f};
  const std::vector<float> exponents = {0.5f, 0.6f, -0.6f, -0.6f};
  const std::vector<float> expected = {NAN, NAN, NAN, INFINITY};

  XlaOp base = ConstantR1<float>(&b, bases);
  XlaOp exponent = ConstantR1<float>(&b, exponents);
  Pow(base, exponent);

  ComputeAndCompareR1<float>(&b, expected, {}, error_spec_);
}
1560 
XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
  // Complex pow on purely real inputs. Negative bases with fractional
  // exponents give results with non-zero imaginary parts; 0 raised to a
  // positive power gives 0, and 0^0 gives 1.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  XlaOp base =
      ConstantR1<complex64>(&b, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f});
  XlaOp exponent =
      ConstantR1<complex64>(&b, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f});
  Pow(base, exponent);

  const std::vector<complex64> expected = {
      {0, 1.41421356},
      {-2.27443288e-01, 0.69999846},
      {-4.19847531e-01, -1.29215783},
      {0, 0},
      {0, 0},
      {1, 0},
  };
  ComputeAndCompareR1<complex64>(&b, expected, {}, error_spec_);
}
1581 
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
  // Pow over two empty vectors yields an empty vector.
  const std::vector<float> empty;
  XlaBuilder b(TestName());
  XlaOp base = ConstantR1<float>(&b, empty);
  XlaOp exponent = ConstantR1<float>(&b, empty);
  Pow(base, exponent);

  ComputeAndCompareR1<float>(&b, empty, {}, error_spec_);
}
1590 
1591 // Some Pow cases that can be implemented more efficiently.
XLA_TEST_F(ArrayElementwiseOpTest,PowSpecialF32)1592 XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
1593   XlaBuilder b(TestName());
1594 
1595   std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
1596   std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1597 
1598   Literal param_literal = LiteralUtil::CreateR1<float>(values);
1599   std::unique_ptr<GlobalData> param_data =
1600       client_->TransferToServer(param_literal).value();
1601 
1602   auto sum = ConstantR0<float>(&b, 0.0f);
1603   auto param = Parameter(&b, 0, param_literal.shape(), "param");
1604   for (float exponent : exponents) {
1605     sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
1606   }
1607 
1608   std::vector<float> expected;
1609   expected.reserve(values.size());
1610   for (auto value : values) {
1611     float sum = 0.0f;
1612     for (float exponent : exponents) {
1613       sum += std::pow(value, exponent);
1614     }
1615     expected.push_back(sum);
1616   }
1617 
1618   ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
1619 }
1620 
XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
  // Evaluates pow(exp(x), y) on the device with x and y passed as runtime
  // parameters, and compares against the same expression computed on the
  // host with std::exp / std::pow.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Pow(Exp(param0), param1);

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::pow(std::exp(values0[i]), values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1645 
XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
  // Evaluates log(pow(x, y)) on the device with x and y passed as runtime
  // parameters. The inputs include negative bases with even integral
  // exponents and 0^0, and the result is compared against std::log /
  // std::pow on the host.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, -10.0f, -2.0f, 2.0f, 3.2f,
                                4.0f, 0.5f,   5.7f,  0.0f};
  std::vector<float> values1 = {0.0f, 10.0f, -4.0f, 1.0f, 2.0f,
                                0.5f, -1.0f, -0.5f, 0.0f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Log(Pow(param0, param1));

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::log(std::pow(values0[i], values1[i]));
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1672 
XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
  // Evaluates exp(x) * exp(y) on the device with x and y passed as runtime
  // parameters, and compares against the same product computed on the host.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Mul(Exp(param0), Exp(param1));

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::exp(values0[i]) * std::exp(values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1697 
XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
  // Evaluates x / exp(y) on the device with x and y passed as runtime
  // parameters, and compares against the same quotient computed on the host.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Div(param0, Exp(param1));

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / std::exp(values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1722 
XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
  // Left-nested division (a / b) / c with all three operands passed as
  // runtime parameters, checked against the same expression on the host.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).value();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(Div(param0, param1), param2);

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / values2[i];
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1754 
XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
  // Right-nested division a / (b / c) with all three operands passed as
  // runtime parameters, checked against the same expression on the host.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).value();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(param0, Div(param1, param2));

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / (values1[i] / values2[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1787 
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
  // Evaluates a / pow(b, c) with all three operands passed as runtime
  // parameters, checked against the same expression on the host.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).value();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(param0, Pow(param1, param2));

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / std::pow(values1[i], values2[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1820 
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
  // Evaluates (a / b) / (c / d) with all four operands passed as runtime
  // parameters, checked against the same expression on the host.
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
  std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).value();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).value();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).value();

  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
  std::unique_ptr<GlobalData> data3 =
      client_->TransferToServer(literal3).value();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  // Each parameter name matches its parameter number. Parameter 3 was
  // previously also labelled "param2".
  auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
  Div(Div(param0, param1), Div(param2, param3));

  // Host reference. The index is size_t so the loop bound does not compare a
  // signed counter against the unsigned vector::size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
      error_spec_);
}
1860 
TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
  // Squares `count` values i / count (i = 0 .. count - 1) via Pow(x, 2.0f).
  // The element count comes from the test parameter.
  const int count = GetParam();
  XlaBuilder builder(TestName());

  std::vector<float> inputs;
  std::vector<float> squares;
  inputs.reserve(count);
  squares.reserve(count);
  for (int i = 0; i < count; ++i) {
    const float x_i = i / static_cast<float>(count);
    inputs.push_back(x_i);
    squares.push_back(x_i * x_i);
  }

  XlaOp x = ConstantR1<float>(&builder, inputs);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR1<float>(&builder, squares, {}, error_spec_);
}
1880 
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
  // Squares every element of a 2x2x2x2 array via Pow(x, 2.0f). Element k
  // (in SetValues order) holds k / num_elements.
  XlaBuilder builder(TestName());
  Array4D<float> values(2, 2, 2, 2);

  const auto num_elements = values.num_elements();
  std::vector<float> flat_input(num_elements);
  std::vector<float> flat_squares(num_elements);
  for (int k = 0; k < num_elements; ++k) {
    flat_input[k] = static_cast<float>(k) / values.num_elements();
    flat_squares[k] = flat_input[k] * flat_input[k];
  }
  values.SetValues(flat_input);

  Array4D<float> expected(2, 2, 2, 2, flat_squares);

  XlaOp x = ConstantR4FromArray4D<float>(&builder, values);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1903 
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
  // A 2x2x0x2 array has no elements. Squaring it must give an identically
  // shaped empty array.
  XlaBuilder builder(TestName());
  Array4D<float> empty_input(2, 2, 0, 2);
  Array4D<float> empty_expected(2, 2, 0, 2);

  XlaOp x = ConstantR4FromArray4D<float>(&builder, empty_input);
  XlaOp two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  ComputeAndCompareR4<float>(&builder, empty_expected, {}, error_spec_);
}
1914 
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
  // Element-wise Min with fast-math disabled. A NaN in either operand
  // propagates to the result.
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  const std::vector<float> lhs_in = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_in = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {1.0f, -5.0f, 1.0f, NAN, NAN};

  XlaOp lhs = ConstantR1<float>(&builder, lhs_in);
  XlaOp rhs = ConstantR1<float>(&builder, rhs_in);
  Min(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1925 
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
  // Min over two empty vectors yields an empty vector.
  const std::vector<float> empty;
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, empty);
  XlaOp rhs = ConstantR1<float>(&builder, empty);
  Min(lhs, rhs);
  ComputeAndCompareR1<float>(&builder, empty, {}, error_spec_);
}
1933 
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
  // Double-precision counterpart of MinF32s, compared with the strict error
  // spec. NaN in either operand propagates.
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  const std::vector<double> lhs_in = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_in = {2.0, -5.0, 1.0, 10.0, NAN};
  const std::vector<double> expected = {1.0, -5.0, 1.0, NAN, NAN};

  XlaOp lhs = ConstantR1<double>(&builder, lhs_in);
  XlaOp rhs = ConstantR1<double>(&builder, rhs_in);
  Min(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1944 
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
  // Element-wise Max with fast-math disabled. A NaN in either operand
  // propagates to the result.
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  const std::vector<float> lhs_in = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_in = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};

  XlaOp lhs = ConstantR1<float>(&builder, lhs_in);
  XlaOp rhs = ConstantR1<float>(&builder, rhs_in);
  Max(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1955 
XLA_TEST_F(ArrayElementwiseOpTest,
           DISABLED_ON_CPU(DefaultMaxF32sNaNPropagation)) {
  // Same inputs and expectations as MaxF32s, but without calling
  // SetFastMathDisabled: NaNs are still expected to propagate under the
  // default options. Disabled on CPU via DISABLED_ON_CPU.
  XlaBuilder builder(TestName());
  const std::vector<float> lhs_in = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_in = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};

  XlaOp lhs = ConstantR1<float>(&builder, lhs_in);
  XlaOp rhs = ConstantR1<float>(&builder, rhs_in);
  Max(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1966 
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
  // Max over two empty vectors yields an empty vector.
  const std::vector<float> empty;
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, empty);
  XlaOp rhs = ConstantR1<float>(&builder, empty);
  Max(lhs, rhs);
  ComputeAndCompareR1<float>(&builder, empty, {}, error_spec_);
}
1974 
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
  // Double-precision counterpart of MaxF32s, compared with the strict error
  // spec. NaN in either operand propagates.
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);
  const std::vector<double> lhs_in = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_in = {2.0, -5.0, 1.0, 10.0, NAN};
  const std::vector<double> expected = {2.0, 1.0, 2.25, NAN, NAN};

  XlaOp lhs = ConstantR1<double>(&builder, lhs_in);
  XlaOp rhs = ConstantR1<double>(&builder, rhs_in);
  Max(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1985 
XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
  // Element-wise Max on int32 pairs mixing the type's min/max limits, zero,
  // and small values of both signs.
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  const std::vector<int32_t> x_in = {kMin, kMin, kMin, -1,   -1,   0,   0,
                                     0,    1,    1,    kMax, kMax, kMax};
  const std::vector<int32_t> y_in = {kMin, kMax, 0,  -10, 0,    -1,  0,
                                     1,    0,    10, 0,   kMax, kMin};
  const std::vector<int32_t> expected = {kMin, kMax, 0,  -1,   0,    0,   0,
                                         1,    1,    10, kMax, kMax, kMax};

  XlaBuilder builder(TestName());
  XlaOp x = ConstantR1<int32_t>(&builder, x_in);
  XlaOp y = ConstantR1<int32_t>(&builder, y_in);
  Max(x, y);

  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2000 
XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
  // Element-wise Min on int32 pairs mixing the type's min/max limits, zero,
  // and small values of both signs.
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  const std::vector<int32_t> x_in = {kMin, kMin, kMin, -1,   -1,   0,   0,
                                     0,    1,    1,    kMax, kMax, kMax};
  const std::vector<int32_t> y_in = {kMin, kMax, 0,  -10, 0,    -1,  0,
                                     1,    0,    10, 0,   kMax, kMin};
  const std::vector<int32_t> expected = {kMin, kMin, kMin, -10, -1,   -1,  0,
                                         0,    0,    1,    0,   kMax, kMin};

  XlaBuilder builder(TestName());
  XlaOp x = ConstantR1<int32_t>(&builder, x_in);
  XlaOp y = ConstantR1<int32_t>(&builder, y_in);
  Min(x, y);

  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2015 
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  // Element-wise Max on uint32 pairs ranging from 0 to the largest uint32.
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  const std::vector<uint32_t> x_in = {0, 0, 1, 1, 1, kMax, kMax, kMax};
  const std::vector<uint32_t> y_in = {0, 1, 0, 1, 10, 0, 234234, kMax};
  const std::vector<uint32_t> expected = {0, 1, 1, 1, 10, kMax, kMax, kMax};

  XlaBuilder builder(TestName());
  XlaOp x = ConstantR1<uint32_t>(&builder, x_in);
  XlaOp y = ConstantR1<uint32_t>(&builder, y_in);
  Max(x, y);

  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
2026 
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  // Element-wise Min on uint32 pairs ranging from 0 to the largest uint32.
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  const std::vector<uint32_t> x_in = {0, 0, 1, 1, 1, kMax, kMax, kMax};
  const std::vector<uint32_t> y_in = {0, 1, 0, 1, 10, 0, 234234, kMax};
  const std::vector<uint32_t> expected = {0, 0, 0, 1, 1, 0, 234234, kMax};

  XlaBuilder builder(TestName());
  XlaOp x = ConstantR1<uint32_t>(&builder, x_in);
  XlaOp y = ConstantR1<uint32_t>(&builder, y_in);
  Min(x, y);

  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
2037 
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  // Element-wise Max of two ten-element vectors whose entries are negations
  // of each other, except slot 0, where both hold -0.0.
  XlaBuilder builder(TestName());
  const std::vector<float> x_in = {-0.0f, 1.0f, 2.0f,  -3.0f, -4.0f,
                                   5.0f,  6.0f, -7.0f, -8.0f, 9.0f};
  const std::vector<float> y_in = {-0.0f, -1.0f, -2.0f, 3.0f, 4.0f,
                                   -5.0f, -6.0f, 7.0f,  8.0f, -9.0f};
  const std::vector<float> expected = {-0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
                                       5.0f,  6.0f, 7.0f, 8.0f, 9.0f};

  XlaOp x = ConstantR1<float>(&builder, x_in);
  XlaOp y = ConstantR1<float>(&builder, y_in);
  Max(x, y);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2050 
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  // Max of a one-element vector and an empty vector yields an empty vector.
  XlaBuilder builder(TestName());
  XlaOp single = ConstantR1<float>(&builder, {3.5f});
  XlaOp empty = ConstantR1<float>(&builder, {});
  Max(single, empty);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
2059 
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  // Broadcasts a one-element vector against an empty 0x2 matrix, mapping it
  // to dimension 0 and then to dimension 1. Both times the result is an
  // empty 0x2 matrix.
  // NOTE(review): the test name says "R1S0", but `u` holds one element.
  for (const int broadcast_dim : {0, 1}) {
    XlaBuilder builder(TestName());
    Array2D<float> empty_matrix(0, 2);
    XlaOp u = ConstantR1<float>(&builder, {3.5f});
    XlaOp v = ConstantR2FromArray2D<float>(&builder, empty_matrix);
    Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});

    ComputeAndCompareR2<float>(&builder, empty_matrix, {}, error_spec_);
  }
}
2070 
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  // Broadcasts a length-3 vector along dimension 1 of a 2x3 matrix (i.e. it
  // is compared against every row) and takes the element-wise maximum.
  XlaBuilder builder(TestName());
  XlaOp row = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  XlaOp matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2081 
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  // Broadcasting an empty vector along dimension 1 of a 2x0 matrix yields a
  // 2x0 result.
  XlaBuilder builder(TestName());
  XlaOp empty_row = ConstantR1<float>(&builder, {});
  XlaOp empty_matrix = ConstantR2<float>(&builder, {{}, {}});
  Max(empty_row, empty_matrix, /*broadcast_dimensions=*/{1});

  Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2091 
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  // Max of a rank-3 int32 array with the scalar 2. Elements below 2 become
  // 2; the rest are unchanged.
  XlaBuilder builder(TestName());
  XlaOp lower_bound = ConstantR0<int32_t>(&builder, 2);
  Array3D<int32_t> input(
      {{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  XlaOp array = ConstantR3FromArray3D<int32_t>(&builder, input);
  Max(array, lower_bound, /*broadcast_dimensions=*/{});

  Array3D<int32_t> expected(
      {{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
  ComputeAndCompareR3<int32_t>(&builder, expected, {});
}
2102 
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  // Max of an empty 2x0x3 int32 array with a scalar keeps the empty shape.
  XlaBuilder builder(TestName());
  XlaOp lower_bound = ConstantR0<int32_t>(&builder, 2);
  Array3D<int32_t> empty_input(2, 0, 3);
  XlaOp array = ConstantR3FromArray3D<int32_t>(&builder, empty_input);
  Max(array, lower_bound, /*broadcast_dimensions=*/{});

  Array3D<int32_t> empty_expected(2, 0, 3);
  ComputeAndCompareR3<int32_t>(&builder, empty_expected, {});
}
2113 
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  // Broadcasts a length-2 vector along dimension 0 of a 2x3 matrix, so
  // v[r] is compared against every element of row r.
  XlaBuilder builder(TestName());
  XlaOp matrix = ConstantR2<float>(
      &builder, {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  XlaOp per_row = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(matrix, per_row, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2124 
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  // Broadcasting a length-2 vector along dimension 0 of a 2x0 matrix yields
  // a 2x0 result.
  XlaBuilder builder(TestName());
  XlaOp empty_matrix = ConstantR2<float>(&builder, {{}, {}});
  XlaOp per_row = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(empty_matrix, per_row, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2134 
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  // Maps the 2x3 matrix onto dimensions 1 and 3 of a 2x2x1x3 array, so
  // matrix[j][l] is compared against array4d[i][j][k][l] for all i and k.
  XlaBuilder builder(TestName());
  XlaOp matrix =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  Array4D<float> operand4d({{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
                            {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  XlaOp array4d = ConstantR4FromArray4D<float>(&builder, operand4d);
  Min(matrix, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2149 
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  // Same broadcast as Min2DTo4DF32s, but the rank-4 operand is 2x2x0x3, so
  // the result is empty with that shape.
  XlaBuilder builder(TestName());
  XlaOp matrix =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  Array4D<float> empty_operand(2, 2, 0, 3);
  XlaOp array4d = ConstantR4FromArray4D<float>(&builder, empty_operand);
  Min(matrix, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> empty_expected(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, empty_expected, {}, error_spec_);
}
2161 
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  // x counts up 0..9 and y counts down 9..0, so the minimum rises to 4 and
  // then falls back to 0.
  const std::vector<int32_t> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32_t> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  const std::vector<int32_t> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};

  XlaBuilder builder(TestName());
  XlaOp x = ConstantR1<int32_t>(&builder, ascending);
  XlaOp y = ConstantR1<int32_t>(&builder, descending);
  Min(x, y);

  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2171 
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  // x counts up 0..9 and y counts down 9..0, so the maximum falls to 5 and
  // then climbs back to 9.
  const std::vector<int32_t> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32_t> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  const std::vector<int32_t> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};

  XlaBuilder builder(TestName());
  XlaOp x = ConstantR1<int32_t>(&builder, ascending);
  XlaOp y = ConstantR1<int32_t>(&builder, descending);
  Max(x, y);

  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2181 
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  // Integer remainder. The expected values take the sign of the dividend,
  // e.g. -3 rem 10 == -3 and 1 rem -10 == 1.
  const std::vector<int32_t> dividends = {-3, 26, 2, -1, 1};
  const std::vector<int32_t> divisors = {10, 5, 1, 10, -10};
  const std::vector<int32_t> expected = {-3, 1, 0, -1, 1};

  XlaBuilder builder(TestName());
  XlaOp a = ConstantR1<int32_t>(&builder, dividends);
  XlaOp b = ConstantR1<int32_t>(&builder, divisors);
  Rem(a, b);

  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
2190 
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  // Element-wise Clamp(min, x, max) with finite bounds. Covers arguments
  // inside the range, above max, and below min.
  XlaBuilder builder(TestName());
  auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto argument =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  // Every literal in this float list now carries the `f` suffix (previously
  // `123.0` was a double literal).
  auto maximum =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
                             error_spec_);
}
2202 
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
  // Clamp with fast-math disabled. A NaN lower bound (slot 4) or a NaN upper
  // bound (slot 3) makes the result NaN.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const std::vector<float> lower = {1.0f, -6.5f, 1.0f, 2.25f, NAN};
  const std::vector<float> operand = {2.0f, 10.0f, -5.0f, 1.0f, 10.0f};
  const std::vector<float> upper = {3.0f, 0.5f, 25.5f, NAN, 123.0f};
  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, NAN, NAN};

  XlaOp minimum = ConstantR1<float>(&builder, lower);
  XlaOp argument = ConstantR1<float>(&builder, operand);
  XlaOp maximum = ConstantR1<float>(&builder, upper);
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2215 
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  // Clamps an F32 vector into [0, 5]; both bounds are R0 scalars that get
  // broadcast against the vector.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<float>(&b, 0.0f);
  auto values = ConstantR1<float>(&b, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto upper = ConstantR0<float>(&b, 5.0f);
  Clamp(lower, values, upper);

  const std::vector<float> expected = {2.0f, 5.0f, 0.0f, 1.0f, 4.0f};
  ComputeAndCompareR1<float>(&b, expected, {}, error_spec_);
}
2226 
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  // Exercises every mix of scalar and vector bounds for Clamp; scalar bounds
  // are broadcast against the vector argument. The four clamped results are
  // summed so that a single comparison checks all of them.
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<float>(&builder, 0.0f);
  auto min_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto arg_vector =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto max_scalar = ConstantR0<float>(&builder, 3.0f);
  // `123.0f` (not `123.0`) keeps the list uniformly float.
  auto max_vector =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  // Perform clamp with broadcasted scalar and vector. Per-clamp results:
  //   Clamp(min_vector, arg, max_scalar) = {2, 3,   1, 2.25, 3}
  //   Clamp(min_scalar, arg, max_vector) = {2, 0.5, 0, 1,    4}
  //   Clamp(min_vector, arg, max_vector) = {2, 0.5, 1, 2.25, 4}
  //   Clamp(min_scalar, arg, max_scalar) = {2, 3,   0, 1,    3}
  Add(Add(Clamp(min_vector, arg_vector, max_scalar),
          Clamp(min_scalar, arg_vector, max_vector)),
      Add(Clamp(min_vector, arg_vector, max_vector),
          Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<float>(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {},
                             error_spec_);
}
2246 
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  // Element-wise S32 clamp with vector bounds. In the last lane min (-5) is
  // greater than max (-1), and the expected result there is the max.
  XlaBuilder b(TestName());
  const std::vector<int32_t> expected = {2, 0, 1, 2, 4, -1};
  auto lows = ConstantR1<int32_t>(&b, {1, -6, 1, 2, 0, -5});
  auto values = ConstantR1<int32_t>(&b, {2, 10, -5, 1, 4, 10});
  auto highs = ConstantR1<int32_t>(&b, {3, 0, 25, 5, 123, -1});
  Clamp(lows, values, highs);

  ComputeAndCompareR1<int32_t>(&b, expected, {});
}
2256 
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  // Covers all four scalar/vector combinations of S32 Clamp bounds (scalars
  // are broadcast) and sums the results so one comparison checks them all.
  XlaBuilder b(TestName());
  auto min_scalar = ConstantR0<int32_t>(&b, 0);
  auto min_vector = ConstantR1<int32_t>(&b, {1, -6, 1, 2, 0});
  auto arg_vector = ConstantR1<int32_t>(&b, {2, 10, -5, 1, 4});
  auto max_scalar = ConstantR0<int32_t>(&b, 3);
  auto max_vector = ConstantR1<int32_t>(&b, {3, 1, 25, 5, 123});

  auto vec_lo_scalar_hi = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_lo_vec_hi = Clamp(min_scalar, arg_vector, max_vector);
  auto vec_lo_vec_hi = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_lo_scalar_hi = Clamp(min_scalar, arg_vector, max_scalar);
  auto first_sum = Add(vec_lo_scalar_hi, scalar_lo_vec_hi);
  auto second_sum = Add(vec_lo_vec_hi, scalar_lo_scalar_hi);
  Add(first_sum, second_sum);

  ComputeAndCompareR1<int32_t>(&b, {8, 8, 2, 6, 14}, {});
}
2272 
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  // Element-wise U32 clamp; the last lane uses bounds at the top of the
  // unsigned range (~0u).
  XlaBuilder b(TestName());
  constexpr uint32_t kAllOnes = ~0u;
  auto lows = ConstantR1<uint32_t>(&b, {1, 2, 1, 2, 0, kAllOnes - 4});
  auto values = ConstantR1<uint32_t>(&b, {2, 10, 5, 1, 4, 10});
  auto highs = ConstantR1<uint32_t>(&b, {3, 5, 25, 5, 123, kAllOnes});
  Clamp(lows, values, highs);

  ComputeAndCompareR1<uint32_t>(&b, {2, 5, 5, 2, 4, kAllOnes - 4}, {});
}
2282 
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  // Covers all four scalar/vector combinations of U32 Clamp bounds (scalars
  // are broadcast) and sums the results so one comparison checks them all.
  XlaBuilder b(TestName());
  auto min_scalar = ConstantR0<uint32_t>(&b, 0);
  auto min_vector = ConstantR1<uint32_t>(&b, {1, 0, 1, 2, 0});
  auto arg_vector = ConstantR1<uint32_t>(&b, {2, 10, 0, 1, 4});
  auto max_scalar = ConstantR0<uint32_t>(&b, 3);
  auto max_vector = ConstantR1<uint32_t>(&b, {3, 1, 25, 5, 123});

  auto vec_lo_scalar_hi = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_lo_vec_hi = Clamp(min_scalar, arg_vector, max_vector);
  auto vec_lo_vec_hi = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_lo_scalar_hi = Clamp(min_scalar, arg_vector, max_scalar);
  auto first_sum = Add(vec_lo_scalar_hi, scalar_lo_vec_hi);
  auto second_sum = Add(vec_lo_vec_hi, scalar_lo_scalar_hi);
  Add(first_sum, second_sum);

  ComputeAndCompareR1<uint32_t>(&b, {8, 8, 2, 6, 14}, {});
}
2298 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  // Adds two F32 vectors that are supplied as transferred parameters rather
  // than as constants.
  XlaBuilder b(TestName());

  Literal lhs_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();
  Literal rhs_literal =
      LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  auto lhs = Parameter(&b, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&b, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const std::vector<float> expected = {8.3f, 4.5f, 6.7f, 11.1f};
  ComputeAndCompareR1<float>(&b, expected, {lhs_data.get(), rhs_data.get()},
                             error_spec_);
}
2320 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  // Adding two zero-element (0 x 7 x 0) R3 parameters yields a zero-element
  // R3 of the same dimensions.
  XlaBuilder b(TestName());
  const Array3D<float> empty_3d(0, 7, 0);

  Literal lhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty_3d);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();
  Literal rhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty_3d);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  auto lhs = Parameter(&b, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&b, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  ComputeAndCompareR3<float>(&b, empty_3d, {lhs_data.get(), rhs_data.get()},
                             error_spec_);
}
2342 
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  // Adds an F32 constant vector to an F32 parameter of the same shape.
  XlaBuilder b(TestName());

  Literal param_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param_data =
      client_->TransferToServer(param_literal).value();

  auto constant = ConstantR1<float>(&b, {1.1f, 2.2f, 3.3f, 4.4f});
  auto param = Parameter(&b, 0, param_literal.shape(), "param0");
  Add(constant, param);

  const std::vector<float> expected = {2.2f, 4.4f, 6.6f, 9.9f};
  ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
}
2358 
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  // Cos at approximately pi, 0, pi/2 and -pi/4.
  XlaBuilder b(TestName());
  const std::vector<float> angles = {3.14159f, 0.0f, 1.570796f, -0.78539f};
  const std::vector<float> expected = {-1.0f, 1.0f, 0.0f, 0.707107f};
  auto input = ConstantR1<float>(&b, angles);
  Cos(input);

  ComputeAndCompareR1<float>(&b, expected, {}, error_spec_);
}
2367 
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  // Sin at approximately pi, 0, pi/2 and -pi/4.
  XlaBuilder b(TestName());
  const std::vector<float> angles = {3.14159f, 0.0f, 1.570796f, -0.78539f};
  const std::vector<float> expected = {0.0f, 0.0f, 1.0f, -0.707107f};
  auto input = ConstantR1<float>(&b, angles);
  Sin(input);

  ComputeAndCompareR1<float>(&b, expected, {}, error_spec_);
}
2376 
XLA_TEST_F(ArrayElementwiseOpTest, RealF64s) {
  // Real() of a real-valued F64 array returns the input unchanged.
  //
  // The inputs are double literals (not float ones). Float literals would be
  // rounded to float precision first, and a backend that evaluated Real in
  // F32 would then still pass the exact comparison below.
  XlaBuilder builder(TestName());
  std::vector<double> xs = {3.14159, 0.0, 1.570796, -0.78539};
  auto a = ConstantR1<double>(&builder, xs);
  Real(a);
  ComputeAndCompareR1<double>(&builder, xs, {});
}
2384 
XLA_TEST_F(ArrayElementwiseOpTest, ImagF64s) {
  // The imaginary part of a real-valued F64 array is all zeros.
  XlaBuilder b(TestName());
  const std::vector<double> inputs = {3.14159, 0.0, 1.570796, -0.78539};
  auto input = ConstantR1<double>(&b, inputs);
  Imag(input);

  const std::vector<double> zeros(inputs.size(), 0.0);
  ComputeAndCompareR1<double>(&b, zeros, {});
}
2392 
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  // Evaluates atan2 on every (y, x) pair from the sample sets below, which
  // include signed zeros and infinities. No explicit expected values are
  // given; ComputeAndCompare checks against its reference computation.
  XlaBuilder b(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  const auto y_samples = {+0.0f, -0.0f, inf,  -inf, 5.0f,
                          -3.0f, 2.0f,  -8.0f, 1.0f};
  const auto x_samples = {+0.0f, -0.0f, inf, -inf, 6.0f, -4.0f, 2.0f, 8.0f};
  const auto num_pairs = y_samples.size() * x_samples.size();

  std::vector<float> ys;
  std::vector<float> xs;
  ys.reserve(num_pairs);
  xs.reserve(num_pairs);
  for (const float y_val : y_samples) {
    for (const float x_val : x_samples) {
      ys.push_back(y_val);
      xs.push_back(x_val);
    }
  }

  auto y = ConstantR1<float>(&b, ys);
  auto x = ConstantR1<float>(&b, xs);
  Atan2(y, x);

  ComputeAndCompare(&b, {}, error_spec_);
}
2415 
XLA_TEST_F(ArrayElementwiseOpTest, Atan2C64s) {
  // Same (y, x) cross product as Atan2F32s, but with the samples widened to
  // complex<float> (zero imaginary part). Checked by ComputeAndCompare
  // against its reference computation.
  XlaBuilder b(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  const auto y_samples = {+0.0f, -0.0f, inf,  -inf, 5.0f,
                          -3.0f, 2.0f,  -8.0f, 1.0f};
  const auto x_samples = {+0.0f, -0.0f, inf, -inf, 6.0f, -4.0f, 2.0f, 8.0f};
  const auto num_pairs = y_samples.size() * x_samples.size();

  std::vector<std::complex<float>> ys;
  std::vector<std::complex<float>> xs;
  ys.reserve(num_pairs);
  xs.reserve(num_pairs);
  for (const float y_val : y_samples) {
    for (const float x_val : x_samples) {
      ys.push_back(y_val);
      xs.push_back(x_val);
    }
  }

  auto y = ConstantR1<std::complex<float>>(&b, ys);
  auto x = ConstantR1<std::complex<float>>(&b, xs);
  Atan2(y, x);

  ComputeAndCompare(&b, {}, error_spec_);
}
2438 
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  // Tanh on a small three-element F32 vector.
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f});
  Tanh(a);

  // All expected literals carry the `f` suffix (0.978026 was a double).
  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
                             error_spec_);
}
2447 
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  // Like TanhF32s, but the input tensor is large enough to exercise the
  // vectorized tanh implementation on XLA CPU.
  XlaBuilder b(TestName());
  const std::vector<float> inputs = {
      1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
      -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
      -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
      0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
      0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
      0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
      -0.79, 1.41,  1.21,  1.05};
  const std::vector<float> expected = {
      0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
      -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
      -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
      0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
      -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
      -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
      0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
      0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
      0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
      -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
      0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
      -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
      -0.65565848, 0.88789743,  0.83566397,  0.78287679};

  Literal input_literal = LiteralUtil::CreateR1<float>(inputs);
  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&b, 0, input_literal.shape(), "input");
  Tanh(input);

  // The tolerance is unusually loose because tanh is approximated with a
  // rational interpolant.
  ComputeAndCompareR1<float>(&b, expected, {input_data.get()},
                             ErrorSpec(0.004, 0.004));
}
2487 
XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  // Exp over an input tensor large enough to exercise the vectorized exp
  // implementation on XLA CPU; expected values are computed with std::exp.
  //
  // For scale: exp(89) saturates float32, and exp(-10) is already smaller
  // than our error spec.
  XlaBuilder b(TestName());

  Literal input_literal = LiteralUtil::CreateR1<float>(
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,  -8.6,   -7.5,
       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,   1.3,    2.4,
       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,  11.2,   12.3,
       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,  21.1,   22.2,
       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,  31.0,   32.1,
       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,  76.1,   77.2,
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&b, 0, input_literal.shape(), "input");
  Exp(input);

  const int64_t num_elements = input_literal.shape().dimensions(0);
  std::vector<float> expected_result(num_elements);
  for (int64_t idx = 0; idx < num_elements; ++idx) {
    expected_result[idx] = std::exp(input_literal.Get<float>({idx}));
  }

  ComputeAndCompareR1<float>(&b, expected_result, {input_data.get()},
                             error_spec_);
}
2523 
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  // The input tensor is large enough to exercise the vectorized log
  // implementation on XLA CPU. Expected values are computed with std::log.
  //
  // The first eight inputs are negative; std::log returns NaN for those, so
  // the expected value in those lanes is NaN.
  XlaBuilder builder(TestName());

  Literal input_literal = LiteralUtil::CreateR1<float>(
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
  Log(input);

  std::vector<float> expected_result;
  int64_t input_size = input_literal.shape().dimensions(0);
  expected_result.reserve(input_size);
  for (int64_t i = 0; i < input_size; i++) {
    expected_result.push_back(std::log(input_literal.Get<float>({i})));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}
2561 
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
  // Count-leading-zeros on U32 values, from 0 (32 leading zeros) up to a
  // value with the top bit set (0 leading zeros).
  XlaBuilder b(TestName());
  const std::vector<uint32_t> inputs = {
      0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678};
  auto values = ConstantR1<uint32_t>(&b, inputs);
  Clz(values);

  const std::vector<uint32_t> expected = {32, 31, 27, 15, 9, 3, 0};
  ComputeAndCompareR1<uint32_t>(&b, expected, {});
}
2570 
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  // Count-leading-zeros on S64 values; -1 has the sign bit set, so it has no
  // leading zeros.
  XlaBuilder b(TestName());
  const std::vector<int64_t> inputs = {0, 1, 0x80000000,
                                       int64_t{0x7FFFFFFFF2345678}, -1};
  auto values = ConstantR1<int64_t>(&b, inputs);
  Clz(values);

  const std::vector<int64_t> expected = {64, 63, 32, 1, 0};
  ComputeAndCompareR1<int64_t>(&b, expected, {});
}
2579 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // Computes (a + b) + c:
  //
  // a ------ (add) --------- (add)
  //         /               /
  // b -----/               /
  // c---------------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto a_plus_b = Add(a, b);
  Add(a_plus_b, c);

  const std::vector<float> expected = {-0.1f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2597 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // Computes a + (b + c):
  //
  // b ------ (add) --------- (add)
  //         /               /
  // c -----/               /
  // a---------------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto b_plus_c = Add(b, c);
  Add(a, b_plus_c);

  const std::vector<float> expected = {89.9f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2615 
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // Computes (-a) + (-b):
  //
  // a ----- (neg) ----- (add)
  //                    /
  // b ----- (neg) ----/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});

  auto minus_a = Neg(a);
  auto minus_b = Neg(b);
  Add(minus_a, minus_b);

  const std::vector<float> expected = {-93.2f, -5.4f, -7.6f, -9.8f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2632 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // Computes (a + b) + (c + d):
  //
  // a ---- (add) ----+
  //        /         |
  // b ----/        (add)
  //                  |
  // c ---- (add) ----+
  //        /
  // d ----/
  //
  // No diagram line ends in a backslash: a trailing '\' splices the next
  // source line into the // comment and triggers -Wcomment.
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  auto add_ab = Add(a, b);
  auto add_cd = Add(c, d);
  Add(add_ab, add_cd);

  ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
                             error_spec_);
}
2655 
2656 XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
2657   XlaBuilder builder(TestName());
2658   auto a = ConstantR2<float>(&builder,
2659                              {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2660   auto b = ConstantR2<float>(&builder,
2661                              {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2662   Add(a, b);
2663 
2664   Array2D<float> expected_array(
2665       {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2666   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2667 }
2668 
XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // Scalar + matrix: the R0 left operand is broadcast to every element of
  // the R2 right operand.
  XlaBuilder b(TestName());
  auto matrix = ConstantR2<float>(
      &b, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto three = ConstantR0<float>(&b, 3.0f);
  Add(three, matrix);

  const Array2D<float> expected(
      {{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&b, expected, {}, error_spec_);
}
2680 
2681 XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
2682   // Add a matrix + scalar.
2683   XlaBuilder builder(TestName());
2684   auto a = ConstantR2<float>(&builder,
2685                              {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2686   auto scalar = ConstantR0<float>(&builder, 3.0f);
2687   Add(a, scalar);
2688 
2689   Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2690   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2691 }
2692 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Test simple broadcasting of a R1F32 over R2F32. The matrix has shape
  // f32[2,3], so the 3-element vector matches only dimension 1 of it. The
  // vector is therefore broadcast along dimension 1 and added to every row.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
  // clang-format off
  auto m = ConstantR2<float>(&builder, {
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2708 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Broadcasts a vector against a matrix in an Eq comparison, once for each
  // of the two possible broadcast dimensions, and returns both results as a
  // tuple.
  XlaBuilder b(TestName());
  auto vec = ConstantR1<int32_t>(&b, {42, 73});
  auto mat = ConstantR2<int32_t>(&b, {{42, 73}, {42, 52}});

  auto eq_bcast_dim1 = Eq(vec, mat, /*broadcast_dimensions=*/{1});
  auto eq_bcast_dim0 = Eq(vec, mat, /*broadcast_dimensions=*/{0});
  Tuple(&b, {eq_bcast_dim1, eq_bcast_dim0});

  auto want = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
  ComputeAndCompareTuple(&b, want, {}, error_spec_);
}
2726 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Ne with the vector broadcast along dimension 1 of the matrix; the result
  // is checked via its string form.
  XlaBuilder b(TestName());
  auto vec = ConstantR1<int32_t>(&b, {42, 73});
  auto mat = ConstantR2<int32_t>(&b, {{42, 73}, {42, 52}});
  Ne(vec, mat, /*broadcast_dimensions=*/{1});

  const std::string golden = R"(pred[2,2] {
  { 0, 0 },
  { 0, 1 }
})";
  const std::string actual = ExecuteToString(&b, {});
  EXPECT_EQ(golden, actual);
}
2740 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Ge with the vector broadcast along dimension 1 of the matrix; the result
  // is checked via its string form.
  XlaBuilder b(TestName());
  auto vec = ConstantR1<int32_t>(&b, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&b, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(vec, mat, /*broadcast_dimensions=*/{1});

  const std::string golden = R"(pred[2,4] {
  { 1, 1, 0, 0 },
  { 0, 0, 0, 1 }
})";
  const std::string actual = ExecuteToString(&b, {});
  EXPECT_EQ(golden, actual);
}
2754 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Gt with the vector broadcast along dimension 1 of the matrix; the result
  // is checked via its string form.
  XlaBuilder b(TestName());
  auto vec = ConstantR1<int32_t>(&b, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&b, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(vec, mat, /*broadcast_dimensions=*/{1});

  const std::string golden = R"(pred[2,4] {
  { 0, 1, 0, 0 },
  { 0, 0, 0, 0 }
})";
  const std::string actual = ExecuteToString(&b, {});
  EXPECT_EQ(golden, actual);
}
2768 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Le with the vector broadcast along dimension 1 of the matrix; the result
  // is checked via its string form.
  XlaBuilder b(TestName());
  auto vec = ConstantR1<int32_t>(&b, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&b, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(vec, mat, /*broadcast_dimensions=*/{1});

  const std::string golden = R"(pred[2,4] {
  { 1, 0, 1, 1 },
  { 1, 1, 1, 1 }
})";
  const std::string actual = ExecuteToString(&b, {});
  EXPECT_EQ(golden, actual);
}
2782 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Lt with the vector broadcast along dimension 1 of the matrix; the result
  // is checked via its string form.
  XlaBuilder b(TestName());
  auto vec = ConstantR1<int32_t>(&b, {1, 2, 3, 4});
  auto mat = ConstantR2<int32_t>(&b, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(vec, mat, /*broadcast_dimensions=*/{1});

  const std::string golden = R"(pred[2,4] {
  { 0, 0, 1, 1 },
  { 1, 1, 1, 0 }
})";
  const std::string actual = ExecuteToString(&b, {});
  EXPECT_EQ(golden, actual);
}
2796 
XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Broadcasts an R1 F32 over an R2 F32 with the matrix as the *first*
  // operand (the reverse of the usual vector-first argument order).
  XlaBuilder b(TestName());
  auto matrix =
      ConstantR2<float>(&b, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto scale = ConstantR1<float>(&b, {2.0f, 4.0f, 6.0f});
  Mul(matrix, scale, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected(
      {{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&b, expected, {}, error_spec_);
}
2808 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // m has XLA shape f32[2,3] (two rows of three elements).
  // md has XLA shape f32[1,3]; its single row is broadcast over every row of
  // m, and the result has shape f32[2,3].
  // The test name appears to number dimensions in {columns, rows} order. In
  // XLA's own [rows, columns] order, md's degenerate dimension is dimension 0.
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
  Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2823 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // m has XLA shape f32[2,3] (two rows of three elements).
  // md has XLA shape f32[2,1]; its single column is broadcast over every
  // column of m, and the result has shape f32[2,3].
  // The test name appears to number dimensions in {columns, rows} order. In
  // XLA's own [rows, columns] order, md's degenerate dimension is dimension 1.
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
  Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2838 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Tests broadcasting for two degenerate arrays. This kind of broadcasting
  // effectively creates an "outer product" operation (here an outer sum).
  // This is taken from the NumPy broadcasting documentation example at:
  // https://numpy.org/doc/stable/user/basics.broadcasting.html
  XlaBuilder builder(TestName());
  // a has XLA shape f32[4,1] (a column of four values).
  // b has XLA shape f32[1,3] (a row of three values).
  // Each is broadcast along its size-1 dimension; the result is f32[4,3].
  auto a = ConstantR2<float>(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto b = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
  Add(a, b);
  Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
                                 {11.0f, 12.0f, 13.0f},
                                 {21.0f, 22.0f, 23.0f},
                                 {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2857 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Add together a (2,2) array and a (2) array, broadcasting the vector along
  // dimension 1 so it is added to every row. Both matrix dimensions have size
  // 2, so there are two ways to broadcast these shapes; broadcast_dimensions
  // selects dimension 1.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2868 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Add together a (2,2) array and a (2) array, broadcasting the vector along
  // dimension 0 so v[i] is added to every element of row i. Both matrix
  // dimensions have size 2, so there are two ways to broadcast these shapes;
  // broadcast_dimensions selects dimension 0.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{0});
  Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2879 
2880 XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
2881   // Binary add of two R3s together
2882   XlaBuilder builder(TestName());
2883   Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2884                        {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2885   auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
2886 
2887   Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
2888                        {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
2889   auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
2890   Add(a, b);
2891 
2892   Array3D<float> expected_3d(
2893       {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
2894        {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
2895   ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2896 }
2897 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Add together a (2, 3, 2) array with a (2) array, mapping the vector onto
  // dimension 2 of the R3 (there are two ways to broadcast these shapes; this
  // test picks dimension 2):
  //   result[i][j][k] = a[i][j][k] + v[k].
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{2});

  Array3D<float> expected_3d(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}
2921 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Add together a (2, 3, 2) array with a (2) array, mapping the vector onto
  // dimension 0 of the R3 (there are two ways to broadcast these shapes; this
  // test picks dimension 0):
  //   result[i][j][k] = a[i][j][k] + v[i].
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{0});

  // clang-format off
  Array3D<float> expected_3d({
    {{11.0f, 12.0f},
     {13.0f, 14.0f},
     {15.0f, 16.0f}},
    {{27.0f, 28.0f},
     {29.0f, 30.0f},
     {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}
2952 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Add together a (2, 3, 2) array with a (2, 3) array, mapping the R2's
  // dimensions onto dimensions {0, 1} of the R3 for broadcasting:
  //   result[i][j][k] = a[i][j][k] + m[i][j].
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto m = ConstantR2<float>(&builder, {
    {10.0f, 20.0f, 30.0f},
    {40.0f, 50.0f, 60.0f},
  });
  Add(a, m, /*broadcast_dimensions=*/{0, 1});

  Array3D<float> expected_3d({
    {{11.0f, 12.0f},
     {23.0f, 24.0f},
     {35.0f, 36.0f}},
    {{47.0f, 48.0f},
     {59.0f, 60.0f},
     {71.0f, 72.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}
2984 
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Comparison between two 3D arrays of compatible shapes: lhs is (2, 3, 2)
  // and rhs is (1, 3, 2). Despite the test name, the degenerate (size-1)
  // dimension of rhs is dimension 0. It is implicitly broadcast, so the result
  // is a (2, 3, 2) array of PREDs, as the expected string below shows.
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);

  Gt(a, b);

  // The expectation is checked against the result literal's string form.
  const std::string expected = R"(pred[2,3,2] {
{
  { 0, 1 },
  { 0, 0 },
  { 0, 0 }
},
{
  { 0, 1 },
  { 1, 0 },
  { 0, 1 }
}
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
3014 
3015 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
3016   XlaBuilder builder(TestName());
3017 
3018   std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
3019   std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
3020   std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
3021   float value = 0.0;
3022   for (int64_t p = 0; p < 2; ++p) {
3023     for (int64_t z = 0; z < 3; ++z) {
3024       for (int64_t y = 0; y < 4; ++y) {
3025         for (int64_t x = 0; x < 5; ++x) {
3026           (*operand_a_4d)(p, z, y, x) = value;
3027           (*operand_b_4d)(p, z, y, x) = 2.0 * value;
3028           (*expected_4d)(p, z, y, x) = 3.0 * value;
3029           value += 0.1;
3030         }
3031       }
3032     }
3033   }
3034 
3035   auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
3036   auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d);
3037   Add(a, b);
3038 
3039   ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
3040 }
3041 
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
  // Adds a length-3 R1 along dimension 1 of a (2, 3, 4, 5) R4:
  //   result[p][z][y][x] = a[p][z][y][x] + b[z].
  // `a` is a running ramp 0.0, 0.1, 0.2, ... in (p, z, y, x) loop order and
  // `b` is {1, 2, 3}.
  XlaBuilder builder(TestName());

  // The arrays are held by value; no manual new/std::unique_ptr is needed.
  Array4D<float> operand_a_4d(2, 3, 4, 5);
  Array4D<float> expected_4d(2, 3, 4, 5);
  std::vector<float> operand_b_1d(3);
  // Seed with a float so iota accumulates in the vector's element type.
  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0f);

  float value = 0.0;
  for (int64_t p = 0; p < 2; ++p) {
    for (int64_t z = 0; z < 3; ++z) {
      for (int64_t y = 0; y < 4; ++y) {
        for (int64_t x = 0; x < 5; ++x) {
          operand_a_4d(p, z, y, x) = value;
          expected_4d(p, z, y, x) = value + operand_b_1d[z];
          value += 0.1;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR1<float>(&builder, operand_b_1d);
  Add(a, b, /*broadcast_dimensions=*/{1});

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}
3069 
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
  // Broadcast-adds a length-16 R1 holding {1, ..., 16} along dimension 1 of a
  // 16x16x2x2 R4 of ones. The R4 literal is built with the explicit layout
  // {0, 1, 2, 3} passed to LayoutUtil::MakeLayout.
  constexpr int kPlanes = 16;
  constexpr int kDepth = 16;
  constexpr int kHeight = 2;
  constexpr int kWidth = 2;

  Array4D<float> values(kPlanes, kDepth, kHeight, kWidth);
  values.Fill(1.0);
  std::vector<float> addend(kDepth);
  std::iota(addend.begin(), addend.end(), 1.0);

  XlaBuilder builder(TestName());
  Literal lhs_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      values, LayoutUtil::MakeLayout({0, 1, 2, 3}));
  auto lhs = ConstantLiteral(&builder, lhs_literal);
  auto rhs = ConstantR1<float>(&builder, addend);
  Add(lhs, rhs, /*broadcast_dimensions=*/{1});

  // The literal has already been created from `values`, so `values` is now
  // reused in place as the expected result: each element gains addend[i1].
  for (int i1 = 0; i1 < kDepth; ++i1) {
    const float delta = addend[i1];
    for (int i0 = 0; i0 < kPlanes; ++i0) {
      for (int i2 = 0; i2 < kHeight; ++i2) {
        for (int i3 = 0; i3 < kWidth; ++i3) {
          values(i0, i1, i2, i3) += delta;
        }
      }
    }
  }
  ComputeAndCompareR4<float>(&builder, values, {}, error_spec_);
}
3098 
3099 // Show that we can't add two opaques.
XLA_TEST_F(ArrayElementwiseOpTest,CannotAddOpaques)3100 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
3101   XlaBuilder builder(TestName());
3102   auto shape = ShapeUtil::MakeOpaqueShape();
3103   auto x = Parameter(&builder, 0, shape, "x");
3104   Add(x, x);
3105   auto computation_status = builder.Build();
3106   ASSERT_FALSE(computation_status.ok());
3107   EXPECT_THAT(computation_status.status().ToString(),
3108               ::testing::ContainsRegex(
3109                   "Expected array argument for lhs of binary operation"));
3110 }
3111 
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
  // Two same-rank (2, 3) operands with the identity broadcast_dimensions
  // {0, 1} are accepted and behave like a plain element-wise Add.
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  // Every literal is spelled as a float (42.0f, not the double 42.0) to match
  // the ConstantR2<float> element type.
  auto b = ConstantR2<float>(&builder,
                             {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(a, b, /*broadcast_dimensions=*/{0, 1});

  Array2D<float> expected_array(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
3124 
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
  // Two same-rank (2, 3) operands with the non-identity broadcast_dimensions
  // {1, 0} must be rejected when the computation is built.
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  // Every literal is spelled as a float (42.0f, not the double 42.0) to match
  // the ConstantR2<float> element type.
  auto b = ConstantR2<float>(&builder,
                             {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(a, b, /*broadcast_dimensions=*/{1, 0});

  auto computation_status = builder.Build();
  ASSERT_FALSE(computation_status.ok());
  EXPECT_THAT(computation_status.status().error_message(),
              ::testing::ContainsRegex("must.*be the identity"));
}
3138 
3139 // Regression test for b/31927799. "slice - y" is fused and requires implicit
3140 // broadcast.
XLA_TEST_F(ArrayElementwiseOpTest,ImplicitBroadcastInFusedExpressions)3141 XLA_TEST_F(ArrayElementwiseOpTest, ImplicitBroadcastInFusedExpressions) {
3142   XlaBuilder builder(TestName());
3143   auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
3144   auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
3145   auto x_data = client_->TransferToServer(x_literal).value();
3146   auto y_data = client_->TransferToServer(y_literal).value();
3147 
3148   auto x = Parameter(&builder, 0, x_literal.shape(), "x");
3149   auto y = Parameter(&builder, 1, y_literal.shape(), "y");
3150   auto slice = Slice(x, {1}, {2}, {1});
3151   Sub(slice, y);
3152 
3153   ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
3154                              error_spec_);
3155 }
3156 
3157 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
3158                         ArrayElementwiseOpTestParamCount,
3159                         ::testing::Values(127, 128, 129, 17 * 4096));
3160 
3161 }  // namespace
3162 }  // namespace xla
3163