xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/map_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <utility>
18 
19 #include "tensorflow/compiler/xla/array2d.h"
20 #include "tensorflow/compiler/xla/client/global_data.h"
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/test_helpers.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
37 
38 namespace xla {
39 namespace {
40 
41 class MapTest : public ClientLibraryTestBase {
42  public:
MapTest(se::Platform * platform=nullptr)43   explicit MapTest(se::Platform* platform = nullptr)
44       : ClientLibraryTestBase(platform) {
45     mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
46     mutable_debug_options()->add_xla_disable_hlo_passes("inline");
47   }
48 
49   // Creates a function that adds its scalar argument with the constant 1.0.
50   //
51   // x {R0F32} ----> (add)
52   //                /
53   // 1.0f ---------/
CreateAdderToOne()54   XlaComputation CreateAdderToOne() {
55     XlaBuilder mapped_builder(TestName());
56     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
57     auto one = ConstantR0<float>(&mapped_builder, 1.0);
58     Add(x, one);
59     auto computation_status = mapped_builder.Build();
60     TF_CHECK_OK(computation_status.status());
61     return std::move(computation_status).value();
62   }
63 
CreateMax()64   XlaComputation CreateMax() {
65     XlaBuilder b(TestName());
66     auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
67     auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
68     Max(lhs, rhs);
69     auto computation_status = b.Build();
70     TF_CHECK_OK(computation_status.status());
71     return std::move(computation_status).value();
72   }
73 
74   // Creates a computation that accepts an F32 and returns T(1) (ignoring the
75   // argument).
76   template <class T>
CreateScalarOne()77   XlaComputation CreateScalarOne() {
78     XlaBuilder mapped_builder("scalar_one");
79     (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
80     ConstantR0<T>(&mapped_builder, 1);
81     auto computation_status = mapped_builder.Build();
82     TF_CHECK_OK(computation_status.status());
83     return std::move(computation_status).value();
84   }
85 
86   // Creates a function that multiplies its scalar argument by the constant 2.0
87   //
88   // x {R0F32} ----> (mul)
89   //                /
90   // 2.0f ---------/
CreateMulByTwo()91   XlaComputation CreateMulByTwo() {
92     XlaBuilder mapped_builder(TestName());
93     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
94     auto two = ConstantR0<float>(&mapped_builder, 2.0);
95     Mul(x, two);
96     auto computation_status = mapped_builder.Build();
97     TF_CHECK_OK(computation_status.status());
98     return std::move(computation_status).value();
99   }
100 
101   // Creates a function that adds its scalar argument with the constant 1.0 and
102   // then multiplies by the original element.
103   //
104   //           /------------------|
105   //          /                   |
106   // x {R0F32} ----> (add) ----> (mul)
107   //                /
108   // 1.0f ---------/
CreateAdderToOneTimesItself()109   XlaComputation CreateAdderToOneTimesItself() {
110     XlaBuilder mapped_builder(TestName());
111     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
112     auto one = ConstantR0<float>(&mapped_builder, 1.0);
113     auto adder_to_one = Add(x, one);
114     Mul(x, adder_to_one);
115     auto computation_status = mapped_builder.Build();
116     TF_CHECK_OK(computation_status.status());
117     return std::move(computation_status).value();
118   }
119 
120   // Creates a function that takes a single parameter and calls map with
121   // "embedded_computation" on it, and then adds "n" to the result.
122   //
123   // x {R0F32} -----------> (map) ----> (add)
124   //                         /           /
125   // embedded_computation --/       n --/
CreateMapPlusN(const XlaComputation & embedded_computation,float n)126   XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation,
127                                 float n) {
128     XlaBuilder builder(TestName());
129     auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
130     auto map = Map(&builder, {x}, embedded_computation, {});
131     auto constant_n = ConstantR0<float>(&builder, n);
132     Add(map, constant_n);
133     auto computation_status = builder.Build();
134     TF_CHECK_OK(computation_status.status());
135     return std::move(computation_status).value();
136   }
137 
138   // Creates a binary function with signature (F32, F32) -> Pred
139   // defined by (x, y) -> x > y.
CreateGt()140   XlaComputation CreateGt() {
141     XlaBuilder b("Gt");
142     auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
143     auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
144     Gt(x, y);
145     auto computation_status = b.Build();
146     TF_CHECK_OK(computation_status.status());
147     return std::move(computation_status).value();
148   }
149 
150   // Creates a function that adds three scalar arguments
151   //
152   // x {R0F32} -------|
153   //                  |
154   // y {R0F32} ----> (add) ---> (add)
155   //                           /
156   // z {R0F32} ---------------/
CreateTernaryAdder()157   XlaComputation CreateTernaryAdder() {
158     XlaBuilder mapped_builder("TernaryAdder");
159     auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
160     auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
161     auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z");
162     auto xy = Add(x, y);
163     Add(xy, z);
164     auto computation_status = mapped_builder.Build();
165     TF_CHECK_OK(computation_status.status());
166     return std::move(computation_status).value();
167   }
168 };
169 
TEST_F(MapTest, MapEachElemPlusOneR0) {
  // Applies (lambda (x) (+ x 1)) to a single F32 scalar: 42 -> 43.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR0<float>(42.0);
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {});

  ComputeAndCompareR0<float>(&builder, 43.0, {input_data.get()},
                             ErrorSpec(0.01f));
}
183 
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
  // Mapping (lambda (x) (+ x 1)) over a zero-length R1F32 vector must yield a
  // zero-length result.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {input_data.get()},
                             ErrorSpec(0.01f));
}
197 
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
  // Mapping (lambda (x) (+ x 1)) over a four-element R1F32 vector increments
  // every element.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
                             {input_data.get()}, ErrorSpec(0.01f));
}
212 
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
  // The mapped computation ignores its F32 argument and returns the S32
  // constant 1, so the result element type differs from the operand's.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateScalarOne<int32_t>(), {0});

  ComputeAndCompareR1<int32_t>(&builder, {1, 1, 1, 1}, {input_data.get()});
}
225 
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
  // The mapped computation ignores its F32 argument and returns the U32
  // constant 1, so the result element type differs from the operand's.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateScalarOne<uint32_t>(), {0});

  ComputeAndCompareR1<uint32_t>(&builder, {1, 1, 1, 1}, {input_data.get()});
}
238 
TEST_F(MapTest, MapEachElemLongerChainR1) {
  // Mapping (lambda (x) (* (+ x 1) x)) over an R1F32 vector: every element x
  // becomes x * (x + 1).
  XlaBuilder builder(TestName());
  const Literal input =
      LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOneTimesItself(), {0});

  ComputeAndCompareR1<float>(
      &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
      {input_data.get()}, ErrorSpec(0.01f));
}
254 
XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
  // Chains two maps over a zero-length R1F32 vector: first
  // (lambda (x) (+ x 1)), then (lambda (x) (* x 2)) on that result.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  const XlaOp incremented = Map(&builder, {operand}, CreateAdderToOne(), {0});
  Map(&builder, {incremented}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {input_data.get()},
                             ErrorSpec(0.01f));
}
270 
TEST_F(MapTest, MapMultipleMapsR1S4) {
  // Chains two maps over a four-element R1F32 vector: first
  // (lambda (x) (+ x 1)), then (lambda (x) (* x 2)) on that result.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  const XlaOp incremented = Map(&builder, {operand}, CreateAdderToOne(), {0});
  Map(&builder, {incremented}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
                             {input_data.get()}, ErrorSpec(0.01f));
}
287 
TEST_F(MapTest, MapEachElemPlusOneR2) {
  // Mapping (lambda (x) (+ x 1)) over both dimensions of a 3x2 R2F32 array
  // increments every element.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR2<float>(
      {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
  const std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input).value();

  const XlaOp operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {0, 1});

  const Array2D<float> expected(
      {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
  ComputeAndCompareR2<float>(&builder, expected, {input_data.get()},
                             ErrorSpec(0.01f));
}
304 
XLA_TEST_F(MapTest, ComplexNestedMaps) {
  // Builds a graph of computations that embed one another to exercise the
  // order in which computations are lowered. In Python terms:
  //
  //   embed1 = lambda x: x + 1                  #  x + 1
  //   embed2 = lambda x: embed1(x) + 2          #  x + 3
  //   embed3 = lambda x: embed1(x) + 4          #  x + 5
  //   embed4 = lambda x: embed2(x) + embed3(x)  # 2x + 8
  //   embed5 = lambda x: embed2(x) + 6          #  x + 9
  //   result = embed5(42) + embed4(7)           # (42 + 9) + (2 * 7 + 8) = 73

  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  XlaComputation embed1 = CreateAdderToOne();
  XlaComputation embed2 = CreateMapPlusN(embed1, 2.0);
  XlaComputation embed3 = CreateMapPlusN(embed1, 4.0);

  // embed4 maps both embed2 and embed3 over the same parameter and sums them.
  XlaBuilder embed4_builder("embed4");
  const XlaOp embed4_x = Parameter(&embed4_builder, 0, scalar_f32, "x");
  const XlaOp via_embed2 = Map(&embed4_builder, {embed4_x}, embed2, {});
  const XlaOp via_embed3 = Map(&embed4_builder, {embed4_x}, embed3, {});
  Add(via_embed2, via_embed3);
  auto embed4_or = embed4_builder.Build();
  ASSERT_IS_OK(embed4_or.status());
  XlaComputation embed4 = std::move(embed4_or).value();

  XlaComputation embed5 = CreateMapPlusN(embed2, 6.0);

  XlaBuilder builder(TestName());
  const XlaOp forty_two = ConstantR0<float>(&builder, 42.0);
  const XlaOp seven = ConstantR0<float>(&builder, 7.0);
  const XlaOp embed5_of_42 = Map(&builder, {forty_two}, embed5, {});
  const XlaOp embed4_of_7 = Map(&builder, {seven}, embed4, {});
  Add(embed5_of_42, embed4_of_7);

  ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
}
342 
TEST_F(MapTest, MapBinaryAdder) {
  // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
  XlaBuilder builder(TestName());
  Literal param0_literal =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).value();
  Literal param1_literal =
      LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).value();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
      {0});

  // All expected values are spelled as float literals so the list matches the
  // float element type without relying on implicit double/int conversions.
  ComputeAndCompareR1<float>(&builder, {7.3f, 7.7f, 4.3f, 0.0f},
                             {param0_data.get(), param1_data.get()},
                             ErrorSpec(0.01f));
}
364 
365 // Adds two rank-2 arrays with different layouts. This test exercises a path
366 // for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest,AddWithMixedLayouts)367 XLA_TEST_F(MapTest, AddWithMixedLayouts) {
368   XlaBuilder builder(TestName());
369   Literal param0_literal = LiteralUtil::CreateR2WithLayout(
370       {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
371   std::unique_ptr<GlobalData> param0_data =
372       client_->TransferToServer(param0_literal).value();
373 
374   Literal param1_literal = LiteralUtil::CreateR2WithLayout(
375       {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
376   std::unique_ptr<GlobalData> param1_data =
377       client_->TransferToServer(param1_literal).value();
378 
379   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
380   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
381   Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
382       {0, 1});
383 
384   Array2D<int32_t> expected(2, 2);
385   expected(0, 0) = 11;
386   expected(0, 1) = 22;
387   expected(1, 0) = 33;
388   expected(1, 1) = 44;
389   ComputeAndCompareR2<int32_t>(&builder, expected,
390                                {param0_data.get(), param1_data.get()});
391 }
392 
XLA_TEST_F(MapTest, AddR3_3x0x2) {
  // Adds two S32 arrays of shape 3x0x2 (no elements); the result must have the
  // same empty shape.
  XlaBuilder builder(TestName());
  const Array3D<int32_t> empty_3x0x2(3, 0, 2);

  const Literal lhs_literal =
      LiteralUtil::CreateR3FromArray3D<int32_t>(empty_3x0x2);
  const std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();

  const Literal rhs_literal =
      LiteralUtil::CreateR3FromArray3D<int32_t>(empty_3x0x2);
  const std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();

  const XlaOp lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  const XlaOp rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Map(&builder, {lhs, rhs}, CreateScalarAddComputation(S32, &builder),
      {0, 1, 2});

  ComputeAndCompareR3<int32_t>(&builder, empty_3x0x2,
                               {lhs_data.get(), rhs_data.get()});
}
413 
TEST_F(MapTest, MapTernaryAdder) {
  // Maps the three-argument adder (lambda (x y z) (+ x y z)) over three
  // R1F32 vectors of equal length.
  XlaBuilder builder(TestName());
  const Literal first = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  const std::unique_ptr<GlobalData> first_data =
      client_->TransferToServer(first).value();
  const Literal second =
      LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  const std::unique_ptr<GlobalData> second_data =
      client_->TransferToServer(second).value();
  const Literal third =
      LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
  const std::unique_ptr<GlobalData> third_data =
      client_->TransferToServer(third).value();

  const XlaOp x = Parameter(&builder, 0, first.shape(), "param0");
  const XlaOp y = Parameter(&builder, 1, second.shape(), "param1");
  const XlaOp z = Parameter(&builder, 2, third.shape(), "param2");
  Map(&builder, {x, y, z}, CreateTernaryAdder(), {0});

  ComputeAndCompareR1<float>(
      &builder, {-2.7f, -92.3f, -895.7f, -400.0f},
      {first_data.get(), second_data.get(), third_data.get()},
      ErrorSpec(0.01f));
}
440 
TEST_F(MapTest, MapGt) {
  // Element-wise (x, y) -> x > y over two constant R1F32 vectors; the result
  // is a vector of predicates.
  XlaBuilder builder(TestName());
  XlaComputation greater_than = CreateGt();
  const XlaOp lhs = ConstantR1<float>(&builder, {1, 20});
  const XlaOp rhs = ConstantR1<float>(&builder, {10, 2});
  Map(&builder, {lhs, rhs}, greater_than, {0});
  ComputeAndCompareR1<bool>(&builder, {false, true}, {});
}
449 
TEST_F(MapTest, NestedBinaryMap) {
  // The outer map applies max_with_square to every element; max_with_square
  // itself contains a (scalar) map of the binary max computation.
  XlaComputation max_with_square;
  {
    // max_with_square(x) = max(x, x * x), with the max expressed as a map.
    XlaBuilder inner("max_with_square");
    const XlaOp x = Parameter(&inner, 0, ShapeUtil::MakeShape(F32, {}), "x");
    const XlaOp x_squared = Mul(x, x);
    Map(&inner, {x, x_squared}, CreateMax(), {});
    auto built = inner.Build();
    ASSERT_IS_OK(built.status());
    max_with_square = std::move(built).value();
  }
  XlaBuilder outer(TestName());
  const XlaOp input = ConstantR1<float>(&outer, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
  Map(&outer, {input}, max_with_square, {0});
  ComputeAndCompareR1<float>(&outer, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
}
466 
TEST_F(MapTest, MapOperationWithBuildError) {
  // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported
  // type combination (F32 + U16) to test that the error is reported to the
  // outermost XlaBuilder.
  XlaBuilder builder(TestName());

  auto sub_builder = builder.CreateSubBuilder("ErrorAdd");
  auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y");
  Add(x, y);
  auto error_add = sub_builder->BuildAndNoteError();

  // The computation is never executed (the build is expected to fail), so only
  // the operand shape is needed; no literals are created or transferred.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {4});
  auto param0 = Parameter(&builder, 0, operand_shape, "param0");
  auto param1 = Parameter(&builder, 1, operand_shape, "param1");
  Map(&builder, {param0, param1}, error_add, {0});

  StatusOr<XlaComputation> computation_status = builder.Build();
  ASSERT_FALSE(computation_status.ok());
  EXPECT_THAT(computation_status.status().ToString(),
              ::testing::HasSubstr("error from: ErrorAdd: Binary op add with "
                                   "different element types: f32[] and u16[]"));
}
498 
499 class MapHloTest : public HloTestBase {};
500 
501 // TODO(b/230123847): Enable this on GPU once mhlo allows mixed-type map.
XLA_TEST_F(MapHloTest,DISABLED_ON_GPU (MapWithMixedInputTypes))502 XLA_TEST_F(MapHloTest, DISABLED_ON_GPU(MapWithMixedInputTypes)) {
503   absl::string_view hlo_string = R"(
504   HloModule MapMixedInputTypes
505 
506   add {
507     op0 = f32[] parameter(0)
508     op1 = s32[] parameter(1)
509     cop1 = f32[] convert(op1)
510     ROOT result = f32[] add(op0, cop1)
511   }
512 
513   ENTRY main {
514     in0 = f32[10,3] parameter(0)
515     in1 = s32[10,3] parameter(1)
516 
517     ROOT out = f32[10,3] map(in0, in1), to_apply=add
518   }
519 )";
520 
521   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
522 }
523 
524 // MapTest disables inline and algsimp. MapTestWithFullOpt runs all
525 // optimizations.
526 using MapTestWithFullOpt = ClientLibraryTestBase;
527 
528 // Regression test for b/31466798. The inliner simplifies map(param0, param1,
529 // power) to power(param0, param1) without deleting the old subcomputation which
530 // is the same as the new entry computation. HloSubcomputationUnification used
531 // to have issues with such patterns and maybe invalidate the pointer to entry
532 // computation.
TEST_F(MapTestWithFullOpt,MapScalarPower)533 TEST_F(MapTestWithFullOpt, MapScalarPower) {
534   XlaBuilder builder(TestName());
535 
536   auto sub_builder = builder.CreateSubBuilder("power");
537   auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
538   auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
539   Pow(x, y);
540   auto power = sub_builder->BuildAndNoteError();
541 
542   Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
543   Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
544   std::unique_ptr<GlobalData> param0_data =
545       client_->TransferToServer(param0_literal).value();
546   std::unique_ptr<GlobalData> param1_data =
547       client_->TransferToServer(param1_literal).value();
548 
549   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
550   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
551   Map(&builder, {param0, param1}, power, {});
552 
553   ComputeAndCompareR0<float>(&builder, 32.0f,
554                              {param0_data.get(), param1_data.get()},
555                              ErrorSpec(0.01f));
556 }
557 
558 // Regression test for b/35786417, where the inliner would not notice the change
559 // of parameter order inside the map.
TEST_F(MapTestWithFullOpt,MapSubtractOppositeOrder)560 TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
561   XlaBuilder builder(TestName());
562 
563   auto sub_builder = builder.CreateSubBuilder("power");
564   auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
565   auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
566   Sub(y, x);  // note that this is y - x, not x - y
567   auto sub_opposite = sub_builder->BuildAndNoteError();
568 
569   Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
570   Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
571   std::unique_ptr<GlobalData> param0_data =
572       client_->TransferToServer(param0_literal).value();
573   std::unique_ptr<GlobalData> param1_data =
574       client_->TransferToServer(param1_literal).value();
575 
576   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
577   auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
578   Map(&builder, {param0, param1}, sub_opposite, {});
579 
580   ComputeAndCompareR0<float>(
581       &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
582 }
583 
584 // Regression test for b/35786417, where the inliner would CHECK-fail due to the
585 // mul inside the map having more parameters than the map does.
TEST_F(MapTestWithFullOpt,MapSquare)586 TEST_F(MapTestWithFullOpt, MapSquare) {
587   XlaBuilder builder(TestName());
588 
589   auto sub_builder = builder.CreateSubBuilder("power");
590   auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
591   Mul(x, x);
592   auto square = sub_builder->BuildAndNoteError();
593 
594   Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
595   std::unique_ptr<GlobalData> param0_data =
596       client_->TransferToServer(param0_literal).value();
597 
598   auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
599   Map(&builder, {param0}, square, {});
600 
601   ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
602                              ErrorSpec(0.01f));
603 }
604 
605 }  // namespace
606 }  // namespace xla
607