1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18
19 #include "tensorflow/compiler/xla/array2d.h"
20 #include "tensorflow/compiler/xla/client/global_data.h"
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/test_helpers.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
37
38 namespace xla {
39 namespace {
40
41 class MapTest : public ClientLibraryTestBase {
42 public:
MapTest(se::Platform * platform=nullptr)43 explicit MapTest(se::Platform* platform = nullptr)
44 : ClientLibraryTestBase(platform) {
45 mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
46 mutable_debug_options()->add_xla_disable_hlo_passes("inline");
47 }
48
49 // Creates a function that adds its scalar argument with the constant 1.0.
50 //
51 // x {R0F32} ----> (add)
52 // /
53 // 1.0f ---------/
CreateAdderToOne()54 XlaComputation CreateAdderToOne() {
55 XlaBuilder mapped_builder(TestName());
56 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
57 auto one = ConstantR0<float>(&mapped_builder, 1.0);
58 Add(x, one);
59 auto computation_status = mapped_builder.Build();
60 TF_CHECK_OK(computation_status.status());
61 return std::move(computation_status).value();
62 }
63
CreateMax()64 XlaComputation CreateMax() {
65 XlaBuilder b(TestName());
66 auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
67 auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
68 Max(lhs, rhs);
69 auto computation_status = b.Build();
70 TF_CHECK_OK(computation_status.status());
71 return std::move(computation_status).value();
72 }
73
74 // Creates a computation that accepts an F32 and returns T(1) (ignoring the
75 // argument).
76 template <class T>
CreateScalarOne()77 XlaComputation CreateScalarOne() {
78 XlaBuilder mapped_builder("scalar_one");
79 (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
80 ConstantR0<T>(&mapped_builder, 1);
81 auto computation_status = mapped_builder.Build();
82 TF_CHECK_OK(computation_status.status());
83 return std::move(computation_status).value();
84 }
85
86 // Creates a function that multiplies its scalar argument by the constant 2.0
87 //
88 // x {R0F32} ----> (mul)
89 // /
90 // 2.0f ---------/
CreateMulByTwo()91 XlaComputation CreateMulByTwo() {
92 XlaBuilder mapped_builder(TestName());
93 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
94 auto two = ConstantR0<float>(&mapped_builder, 2.0);
95 Mul(x, two);
96 auto computation_status = mapped_builder.Build();
97 TF_CHECK_OK(computation_status.status());
98 return std::move(computation_status).value();
99 }
100
101 // Creates a function that adds its scalar argument with the constant 1.0 and
102 // then multiplies by the original element.
103 //
104 // /------------------|
105 // / |
106 // x {R0F32} ----> (add) ----> (mul)
107 // /
108 // 1.0f ---------/
CreateAdderToOneTimesItself()109 XlaComputation CreateAdderToOneTimesItself() {
110 XlaBuilder mapped_builder(TestName());
111 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
112 auto one = ConstantR0<float>(&mapped_builder, 1.0);
113 auto adder_to_one = Add(x, one);
114 Mul(x, adder_to_one);
115 auto computation_status = mapped_builder.Build();
116 TF_CHECK_OK(computation_status.status());
117 return std::move(computation_status).value();
118 }
119
120 // Creates a function that takes a single parameter and calls map with
121 // "embedded_computation" on it, and then adds "n" to the result.
122 //
123 // x {R0F32} -----------> (map) ----> (add)
124 // / /
125 // embedded_computation --/ n --/
CreateMapPlusN(const XlaComputation & embedded_computation,float n)126 XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation,
127 float n) {
128 XlaBuilder builder(TestName());
129 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
130 auto map = Map(&builder, {x}, embedded_computation, {});
131 auto constant_n = ConstantR0<float>(&builder, n);
132 Add(map, constant_n);
133 auto computation_status = builder.Build();
134 TF_CHECK_OK(computation_status.status());
135 return std::move(computation_status).value();
136 }
137
138 // Creates a binary function with signature (F32, F32) -> Pred
139 // defined by (x, y) -> x > y.
CreateGt()140 XlaComputation CreateGt() {
141 XlaBuilder b("Gt");
142 auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
143 auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
144 Gt(x, y);
145 auto computation_status = b.Build();
146 TF_CHECK_OK(computation_status.status());
147 return std::move(computation_status).value();
148 }
149
150 // Creates a function that adds three scalar arguments
151 //
152 // x {R0F32} -------|
153 // |
154 // y {R0F32} ----> (add) ---> (add)
155 // /
156 // z {R0F32} ---------------/
CreateTernaryAdder()157 XlaComputation CreateTernaryAdder() {
158 XlaBuilder mapped_builder("TernaryAdder");
159 auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
160 auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
161 auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z");
162 auto xy = Add(x, y);
163 Add(xy, z);
164 auto computation_status = mapped_builder.Build();
165 TF_CHECK_OK(computation_status.status());
166 return std::move(computation_status).value();
167 }
168 };
169
TEST_F(MapTest, MapEachElemPlusOneR0) {
  // Applies (lambda (x) (+ x 1)) to a single scalar: 42 -> 43.
  XlaBuilder builder(TestName());
  const Literal scalar_input = LiteralUtil::CreateR0<float>(42.0);
  auto scalar_handle = client_->TransferToServer(scalar_input).value();

  auto operand = Parameter(&builder, 0, scalar_input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {});

  ComputeAndCompareR0<float>(&builder, 43.0, {scalar_handle.get()},
                             ErrorSpec(0.01f));
}
183
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
  // Mapping (lambda (x) (+ x 1)) over an empty R1F32 vector must yield an
  // empty vector.
  XlaBuilder builder(TestName());
  const Literal empty_input = LiteralUtil::CreateR1<float>({});
  auto empty_handle = client_->TransferToServer(empty_input).value();

  auto operand = Parameter(&builder, 0, empty_input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {empty_handle.get()},
                             ErrorSpec(0.01f));
}
197
TEST_F(MapTest, MapEachElemPlusOneR1S4) {
  // Maps (lambda (x) (+ x 1)) over a length-4 R1F32 vector.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_handle = client_->TransferToServer(input).value();

  auto operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {0});

  ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
                             {input_handle.get()}, ErrorSpec(0.01f));
}
212
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
  // Maps a computation that ignores its F32 argument and returns S32 1. The
  // result element type (S32) therefore differs from the input's (F32).
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_handle = client_->TransferToServer(input).value();

  auto operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateScalarOne<int32_t>(), {0});

  ComputeAndCompareR1<int32_t>(&builder, {1, 1, 1, 1}, {input_handle.get()});
}
225
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
  // Maps a computation that ignores its F32 argument and returns U32 1. The
  // result element type (U32) therefore differs from the input's (F32).
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_handle = client_->TransferToServer(input).value();

  auto operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateScalarOne<uint32_t>(), {0});

  ComputeAndCompareR1<uint32_t>(&builder, {1, 1, 1, 1}, {input_handle.get()});
}
238
TEST_F(MapTest, MapEachElemLongerChainR1) {
  // Maps (lambda (x) (* (+ x 1) x)), i.e. x * (x + 1), over an R1F32 vector.
  XlaBuilder builder(TestName());
  const Literal input =
      LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
  auto input_handle = client_->TransferToServer(input).value();

  auto operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOneTimesItself(), {0});

  ComputeAndCompareR1<float>(
      &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
      {input_handle.get()}, ErrorSpec(0.01f));
}
254
XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
  // Chains two maps over an empty R1F32 vector: first (+ x 1), then (* x 2).
  // The result must still be empty.
  XlaBuilder builder(TestName());
  const Literal empty_input = LiteralUtil::CreateR1<float>({});
  auto empty_handle = client_->TransferToServer(empty_input).value();

  auto operand = Parameter(&builder, 0, empty_input.shape(), "param0");
  auto plus_one = Map(&builder, {operand}, CreateAdderToOne(), {0});
  Map(&builder, {plus_one}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {}, {empty_handle.get()},
                             ErrorSpec(0.01f));
}
270
TEST_F(MapTest, MapMultipleMapsR1S4) {
  // Chains two maps over a length-4 R1F32 vector: first (+ x 1), then (* x 2),
  // so each element becomes (x + 1) * 2.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto input_handle = client_->TransferToServer(input).value();

  auto operand = Parameter(&builder, 0, input.shape(), "param0");
  auto plus_one = Map(&builder, {operand}, CreateAdderToOne(), {0});
  Map(&builder, {plus_one}, CreateMulByTwo(), {0});

  ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
                             {input_handle.get()}, ErrorSpec(0.01f));
}
287
TEST_F(MapTest, MapEachElemPlusOneR2) {
  // Maps (lambda (x) (+ x 1)) over both dimensions of a 3x2 R2F32 array.
  XlaBuilder builder(TestName());
  const Literal input = LiteralUtil::CreateR2<float>(
      {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
  auto input_handle = client_->TransferToServer(input).value();

  auto operand = Parameter(&builder, 0, input.shape(), "param0");
  Map(&builder, {operand}, CreateAdderToOne(), {0, 1});

  const Array2D<float> expected(
      {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
  ComputeAndCompareR2<float>(&builder, expected, {input_handle.get()},
                             ErrorSpec(0.01f));
}
304
XLA_TEST_F(MapTest, ComplexNestedMaps) {
  // Builds a DAG of embedded computations, several of which are shared by more
  // than one caller, to exercise the order in which computations are lowered.
  // Python equivalent:
  //
  //   embed1 = lambda x: x + 1                  # x + 1
  //   embed2 = lambda x: embed1(x) + 2          # x + 3
  //   embed3 = lambda x: embed1(x) + 4          # x + 5
  //   embed4 = lambda x: embed2(x) + embed3(x)  # 2x + 8
  //   embed5 = lambda x: embed2(x) + 6          # x + 9
  //   result = embed5(42) + embed4(7)           # (42 + 9) + (2 * 7 + 8) = 73
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});

  const XlaComputation embed1 = CreateAdderToOne();
  const XlaComputation embed2 = CreateMapPlusN(embed1, 2.0);
  const XlaComputation embed3 = CreateMapPlusN(embed1, 4.0);

  // embed4(x) = map(x, embed2) + map(x, embed3).
  XlaComputation embed4;
  {
    XlaBuilder embed4_builder("embed4");
    auto x = Parameter(&embed4_builder, 0, scalar_shape, "x");
    auto via_embed2 = Map(&embed4_builder, {x}, embed2, {});
    auto via_embed3 = Map(&embed4_builder, {x}, embed3, {});
    Add(via_embed2, via_embed3);
    auto built = embed4_builder.Build();
    ASSERT_IS_OK(built.status());
    embed4 = std::move(built).value();
  }

  const XlaComputation embed5 = CreateMapPlusN(embed2, 6.0);

  XlaBuilder builder(TestName());
  auto forty_two = ConstantR0<float>(&builder, 42.0);
  auto seven = ConstantR0<float>(&builder, 7.0);
  auto lhs = Map(&builder, {forty_two}, embed5, {});
  auto rhs = Map(&builder, {seven}, embed4, {});
  Add(lhs, rhs);

  ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
}
342
TEST_F(MapTest, MapBinaryAdder) {
  // Maps (lambda (x y) (+ x y)) element-wise over two R1F32 vectors.
  XlaBuilder builder(TestName());
  Literal param0_literal =
      LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).value();
  Literal param1_literal =
      LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).value();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
      {0});

  // Expected values are all written as float literals to match the F32
  // result. Previously a double (7.7) and an int (0) were mixed in and relied
  // on implicit conversion.
  ComputeAndCompareR1<float>(&builder, {7.3f, 7.7f, 4.3f, 0.0f},
                             {param0_data.get(), param1_data.get()},
                             ErrorSpec(0.01f));
}
364
365 // Adds two rank-2 arrays with different layouts. This test exercises a path
366 // for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
  XlaBuilder builder(TestName());

  // Left operand: minor-to-major layout {1, 0}.
  const Literal lhs = LiteralUtil::CreateR2WithLayout(
      {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
  auto lhs_handle = client_->TransferToServer(lhs).value();

  // Right operand: minor-to-major layout {0, 1}, the opposite of the left.
  const Literal rhs = LiteralUtil::CreateR2WithLayout(
      {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
  auto rhs_handle = client_->TransferToServer(rhs).value();

  auto lhs_param = Parameter(&builder, 0, lhs.shape(), "param0");
  auto rhs_param = Parameter(&builder, 1, rhs.shape(), "param1");
  Map(&builder, {lhs_param, rhs_param},
      CreateScalarAddComputation(S32, &builder), {0, 1});

  const Array2D<int32_t> expected({{11, 22}, {33, 44}});
  ComputeAndCompareR2<int32_t>(&builder, expected,
                               {lhs_handle.get(), rhs_handle.get()});
}
392
XLA_TEST_F(MapTest, AddR3_3x0x2) {
  // Element-wise add of two empty 3x0x2 S32 arrays. The result must be the
  // same empty 3x0x2 array.
  const Array3D<int32_t> empty_3x0x2(3, 0, 2);
  XlaBuilder builder(TestName());

  const Literal lhs = LiteralUtil::CreateR3FromArray3D<int32_t>(empty_3x0x2);
  auto lhs_handle = client_->TransferToServer(lhs).value();

  const Literal rhs = LiteralUtil::CreateR3FromArray3D<int32_t>(empty_3x0x2);
  auto rhs_handle = client_->TransferToServer(rhs).value();

  auto lhs_param = Parameter(&builder, 0, lhs.shape(), "param0");
  auto rhs_param = Parameter(&builder, 1, rhs.shape(), "param1");
  Map(&builder, {lhs_param, rhs_param},
      CreateScalarAddComputation(S32, &builder), {0, 1, 2});

  ComputeAndCompareR3<int32_t>(&builder, empty_3x0x2,
                               {lhs_handle.get(), rhs_handle.get()});
}
413
TEST_F(MapTest, MapTernaryAdder) {
  // Maps (lambda (x y z) (+ x y z)) element-wise over three R1F32 vectors.
  XlaBuilder builder(TestName());
  const Literal xs = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto xs_handle = client_->TransferToServer(xs).value();
  const Literal ys = LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  auto ys_handle = client_->TransferToServer(ys).value();
  const Literal zs =
      LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
  auto zs_handle = client_->TransferToServer(zs).value();

  auto x = Parameter(&builder, 0, xs.shape(), "param0");
  auto y = Parameter(&builder, 1, ys.shape(), "param1");
  auto z = Parameter(&builder, 2, zs.shape(), "param2");
  Map(&builder, {x, y, z}, CreateTernaryAdder(), {0});

  ComputeAndCompareR1<float>(
      &builder, {-2.7f, -92.3f, -895.7f, -400.0f},
      {xs_handle.get(), ys_handle.get(), zs_handle.get()}, ErrorSpec(0.01f));
}
440
TEST_F(MapTest, MapGt) {
  // Maps (x, y) -> x > y over two R1F32 constants, producing an R1 PRED.
  XlaBuilder b(TestName());
  auto greater_than = CreateGt();
  auto lhs = ConstantR1<float>(&b, {1, 20});
  auto rhs = ConstantR1<float>(&b, {10, 2});
  Map(&b, {lhs, rhs}, greater_than, {0});
  ComputeAndCompareR1<bool>(&b, {false, true}, {});
}
449
TEST_F(MapTest, NestedBinaryMap) {
  // max_with_square(x) = max(x, x * x). It is itself written as a binary map
  // of CreateMax() over {x, x * x}, and is then mapped over an R1F32 vector.
  XlaComputation max_with_square;
  {
    XlaBuilder inner("max_with_square");
    auto x = Parameter(&inner, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto x_squared = Mul(x, x);
    Map(&inner, {x, x_squared}, CreateMax(), {});
    auto built = inner.Build();
    ASSERT_IS_OK(built.status());
    max_with_square = std::move(built).value();
  }

  XlaBuilder b(TestName());
  auto input = ConstantR1<float>(&b, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
  Map(&b, {input}, max_with_square, {0});
  ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
}
466
TEST_F(MapTest, MapOperationWithBuildError) {
  // The mapped computation adds an F32 scalar to a U16 scalar, an unsupported
  // element-type combination. Checks that the error recorded by the
  // "ErrorAdd" sub-builder is reported by the outermost builder's Build().
  XlaBuilder builder(TestName());

  auto add_builder = builder.CreateSubBuilder("ErrorAdd");
  auto f32_operand =
      Parameter(add_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto u16_operand =
      Parameter(add_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y");
  Add(f32_operand, u16_operand);
  auto error_add = add_builder->BuildAndNoteError();

  const Literal lhs = LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
  auto lhs_handle = client_->TransferToServer(lhs).value();
  const Literal rhs = LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
  auto rhs_handle = client_->TransferToServer(rhs).value();

  auto lhs_param = Parameter(&builder, 0, lhs.shape(), "param0");
  auto rhs_param = Parameter(&builder, 1, rhs.shape(), "param1");
  Map(&builder, {lhs_param, rhs_param}, error_add, {0});

  StatusOr<XlaComputation> computation_status = builder.Build();
  ASSERT_FALSE(computation_status.ok());
  EXPECT_THAT(computation_status.status().ToString(),
              ::testing::HasSubstr("error from: ErrorAdd: Binary op add with "
                                   "different element types: f32[] and u16[]"));
}
498
// Fixture for map tests written directly as HLO text and run through
// HloTestBase, rather than built with XlaBuilder.
class MapHloTest : public HloTestBase {};
500
501 // TODO(b/230123847): Enable this on GPU once mhlo allows mixed-type map.
XLA_TEST_F(MapHloTest, DISABLED_ON_GPU(MapWithMixedInputTypes)) {
  // Maps a computation whose two parameters have different element types
  // (f32 and s32) over two f32[10,3] / s32[10,3] arrays. The s32 operand is
  // converted to f32 inside the mapped computation before the add, so the
  // result is f32[10,3].
  absl::string_view hlo_string = R"(
HloModule MapMixedInputTypes

add {
op0 = f32[] parameter(0)
op1 = s32[] parameter(1)
cop1 = f32[] convert(op1)
ROOT result = f32[] add(op0, cop1)
}

ENTRY main {
in0 = f32[10,3] parameter(0)
in1 = s32[10,3] parameter(1)

ROOT out = f32[10,3] map(in0, in1), to_apply=add
}
)";

  // No inputs are passed explicitly here; RunAndCompare is presumably
  // responsible for supplying them (see HloTestBase) — NOTE(review): confirm.
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
}
523
// MapTest disables the "inline" and "algsimp" HLO passes. MapTestWithFullOpt
// is the unmodified ClientLibraryTestBase fixture, so all optimizations run.
// The regression tests below rely on that to exercise the inliner.
using MapTestWithFullOpt = ClientLibraryTestBase;
527
528 // Regression test for b/31466798. The inliner simplifies map(param0, param1,
529 // power) to power(param0, param1) without deleting the old subcomputation which
530 // is the same as the new entry computation. HloSubcomputationUnification used
// to have issues with such patterns and could invalidate the pointer to the
// entry computation.
TEST_F(MapTestWithFullOpt, MapScalarPower) {
  XlaBuilder builder(TestName());

  // power(x, y) = pow(x, y) on F32 scalars.
  auto power_builder = builder.CreateSubBuilder("power");
  const auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto base = Parameter(power_builder.get(), 0, scalar_f32, "x");
  auto exponent = Parameter(power_builder.get(), 1, scalar_f32, "y");
  Pow(base, exponent);
  auto power = power_builder->BuildAndNoteError();

  const Literal base_literal = LiteralUtil::CreateR0<float>(2.0f);
  const Literal exponent_literal = LiteralUtil::CreateR0<float>(5.0f);
  auto base_handle = client_->TransferToServer(base_literal).value();
  auto exponent_handle = client_->TransferToServer(exponent_literal).value();

  auto base_param = Parameter(&builder, 0, base_literal.shape(), "param0");
  auto exponent_param =
      Parameter(&builder, 1, exponent_literal.shape(), "param1");
  Map(&builder, {base_param, exponent_param}, power, {});

  // pow(2, 5) == 32.
  ComputeAndCompareR0<float>(&builder, 32.0f,
                             {base_handle.get(), exponent_handle.get()},
                             ErrorSpec(0.01f));
}
557
558 // Regression test for b/35786417, where the inliner would not notice the change
559 // of parameter order inside the map.
TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
  XlaBuilder builder(TestName());

  // The mapped computation is (x, y) -> y - x. It uses its parameters in the
  // reverse of their declaration order, which is the pattern the inliner
  // failed to handle. The sub-builder is named after what it computes.
  auto sub_builder = builder.CreateSubBuilder("sub_opposite");
  auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
  Sub(y, x);  // note that this is y - x, not x - y
  auto sub_opposite = sub_builder->BuildAndNoteError();

  Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
  Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).value();
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(param1_literal).value();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
  Map(&builder, {param0, param1}, sub_opposite, {});

  // (x, y) = (2, 5), so y - x == 3. A swapped order would give -3.
  ComputeAndCompareR0<float>(
      &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
}
583
584 // Regression test for b/35786417, where the inliner would CHECK-fail due to the
585 // mul inside the map having more parameters than the map does.
TEST_F(MapTestWithFullOpt, MapSquare) {
  XlaBuilder builder(TestName());

  // The mapped computation is x -> x * x. Its single parameter is used twice
  // by the mul, which is the pattern that made the inliner CHECK-fail. The
  // sub-builder is named after what it computes.
  auto sub_builder = builder.CreateSubBuilder("square");
  auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
  Mul(x, x);
  auto square = sub_builder->BuildAndNoteError();

  Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(param0_literal).value();

  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
  Map(&builder, {param0}, square, {});

  ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
                             ErrorSpec(0.01f));
}
604
605 } // namespace
606 } // namespace xla
607