1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <initializer_list>
17 #include <memory>
18
19 #include "tensorflow/compiler/xla/array2d.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace xla {
37 namespace {
38
39 class TupleTest : public ClientLibraryTestBase {
40 public:
41 ErrorSpec error_spec_{0.0001};
42 };
43
44 // Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest,TupleConstant)45 XLA_TEST_F(TupleTest, TupleConstant) {
46 XlaBuilder builder(TestName());
47
48 const float constant_scalar = 7.3f;
49 std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
50 std::initializer_list<std::initializer_list<float>> constant_matrix = {
51 {1.1f, 2.2f, 3.5f}, // row 0
52 {4.8f, 5.0f, 6.7f}, // row 1
53 };
54 auto value = LiteralUtil::MakeTupleFromSlices(
55 {LiteralUtil::CreateR0<float>(constant_scalar),
56 LiteralUtil::CreateR1<float>(constant_vector),
57 LiteralUtil::CreateR2<float>(constant_matrix)});
58
59 ConstantLiteral(&builder, value);
60 ComputeAndCompareTuple(&builder, value, {}, error_spec_);
61 }
62
63 // Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest,TupleScalarConstant)64 XLA_TEST_F(TupleTest, TupleScalarConstant) {
65 XlaBuilder builder(TestName());
66
67 const float constant_scalar1 = 7.3f;
68 const float constant_scalar2 = 1.2f;
69 auto value = LiteralUtil::MakeTupleFromSlices(
70 {LiteralUtil::CreateR0<float>(constant_scalar1),
71 LiteralUtil::CreateR0<float>(constant_scalar2)});
72
73 ConstantLiteral(&builder, value);
74 ComputeAndCompareTuple(&builder, value, {}, error_spec_);
75 }
76
77 // Tests the creation of tuple data.
XLA_TEST_F(TupleTest,TupleCreate)78 XLA_TEST_F(TupleTest, TupleCreate) {
79 XlaBuilder builder(TestName());
80
81 const float constant_scalar = 7.3f;
82 std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
83 std::initializer_list<std::initializer_list<float>> constant_matrix = {
84 {1.1f, 2.2f, 3.5f}, // row 0
85 {4.8f, 5.0f, 6.7f}, // row 1
86 };
87 Tuple(&builder, {ConstantR0<float>(&builder, constant_scalar),
88 ConstantR1<float>(&builder, constant_vector),
89 ConstantR2<float>(&builder, constant_matrix)});
90
91 auto expected = LiteralUtil::MakeTupleFromSlices(
92 {LiteralUtil::CreateR0<float>(constant_scalar),
93 LiteralUtil::CreateR1<float>(constant_vector),
94 LiteralUtil::CreateR2<float>(constant_matrix)});
95 ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
96 }
97
98 // Tests the creation of tuple data.
XLA_TEST_F(TupleTest,TupleCreateWithZeroElementEntry)99 XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
100 XlaBuilder builder(TestName());
101
102 Tuple(&builder,
103 {ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
104
105 auto expected = LiteralUtil::MakeTupleFromSlices(
106 {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
107 ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
108 }
109
110 // Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest,EmptyTupleCreate)111 XLA_TEST_F(TupleTest, EmptyTupleCreate) {
112 XlaBuilder builder(TestName());
113 Tuple(&builder, {});
114 auto expected = LiteralUtil::MakeTuple({});
115 ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
116 }
117
118 // Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest,GetTupleElement)119 XLA_TEST_F(TupleTest, GetTupleElement) {
120 XlaBuilder builder(TestName());
121 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
122 std::initializer_list<std::initializer_list<float>> constant_matrix = {
123 {1.f, 2.f, 3.f}, // row 0
124 {4.f, 5.f, 6.f}, // row 1
125 };
126 auto tuple_data =
127 Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
128 ConstantR2<float>(&builder, constant_matrix)});
129 GetTupleElement(tuple_data, 1);
130 ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
131 error_spec_);
132 }
133
134 // Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest,GetTupleElementWithZeroElements)135 XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
136 XlaBuilder builder(TestName());
137 auto tuple_data =
138 Tuple(&builder,
139 {ConstantR1<float>(&builder, {}),
140 ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 101))});
141 GetTupleElement(tuple_data, 1);
142 ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
143 }
144
// Applying GetTupleElement to a non-tuple (an f32 vector) must make Build()
// return an error status with a descriptive message rather than crash.
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
  XlaBuilder builder(TestName());
  auto not_a_tuple = ConstantR1<float>(&builder, {4.5f});
  GetTupleElement(not_a_tuple, 1);

  const auto build_result = builder.Build();
  EXPECT_FALSE(build_result.ok());
  const auto& failure_message = build_result.status().error_message();
  EXPECT_THAT(
      failure_message,
      ::testing::HasSubstr("Operand to GetTupleElement() is not a tuple"));
}
155
156 // Extracts both elements from a tuple with GetTupleElement and then adds them
157 // together.
XLA_TEST_F(TupleTest,AddTupleElements)158 XLA_TEST_F(TupleTest, AddTupleElements) {
159 XlaBuilder builder(TestName());
160 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
161 std::initializer_list<std::initializer_list<float>> constant_matrix = {
162 {1.f, 2.f, 3.f}, // row 0
163 {4.f, 5.f, 6.f}, // row 1
164 };
165 auto tuple_data =
166 Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
167 ConstantR2<float>(&builder, constant_matrix)});
168 auto vector_element = GetTupleElement(tuple_data, 0);
169 auto matrix_element = GetTupleElement(tuple_data, 1);
170 auto vector_shape = builder.GetShape(vector_element).value();
171 auto matrix_shape = builder.GetShape(matrix_element).value();
172 Add(matrix_element, vector_element,
173 /*broadcast_dimensions=*/{1});
174
175 Array2D<float> expected({
176 {2.f, 4.f, 6.f}, // row 0
177 {5.f, 7.f, 9.f}, // row 1
178 });
179 ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3})));
180 ASSERT_TRUE(ShapeUtil::Equal(matrix_shape,
181 ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3})));
182 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
183 }
184
185 // Extracts both elements from a tuple and then puts them into a new tuple in
186 // the opposite order.
XLA_TEST_F(TupleTest,TupleGTEToTuple)187 XLA_TEST_F(TupleTest, TupleGTEToTuple) {
188 XlaBuilder builder(TestName());
189 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
190 std::initializer_list<std::initializer_list<float>> constant_matrix = {
191 {1.f, 2.f, 3.f}, // row 0
192 {4.f, 5.f, 6.f}, // row 1
193 };
194 auto tuple_data =
195 Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
196 ConstantR2<float>(&builder, constant_matrix)});
197 Tuple(&builder,
198 {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
199 auto expected = LiteralUtil::MakeTupleFromSlices(
200 {LiteralUtil::CreateR2<float>(constant_matrix),
201 LiteralUtil::CreateR1<float>(constant_vector)});
202 ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
203 }
204
205
206 // Builds two new tuples from an existing tuple (by means of GetTupleElement),
207 // then adds up the components of the new tuples.
XLA_TEST_F(TupleTest,TupleGTEToTupleToGTEAdd)208 XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
209 //
210 // v------ --(GTE 0)-- --(GTE 0)----------
211 // \ / \ / \
212 // (tuple)-- (tuple01)-- \
213 // / | \ / \ \
214 // m------ | --(GTE 1)-- --(GTE 1)------------ \
215 // | \ \
216 // | (add)
217 // | / /
218 // |--------(GTE 1)-- --(GTE 0)------------ /
219 // \ \ / /
220 // \ (tuple10)-- /
221 // \ / \ /
222 // -----(GTE 0)-- --(GTE 1)----------
223 XlaBuilder builder(TestName());
224 std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
225 std::initializer_list<std::initializer_list<float>> constant_matrix = {
226 {1.f, 2.f, 3.f}, // row 0
227 {4.f, 5.f, 6.f}, // row 1
228 };
229 auto tuple_data =
230 Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
231 ConstantR2<float>(&builder, constant_matrix)});
232 auto new_tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0),
233 GetTupleElement(tuple_data, 1)});
234 auto new_tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1),
235 GetTupleElement(tuple_data, 0)});
236 auto vector_from_01 = GetTupleElement(new_tuple01, 0);
237 auto vector_from_10 = GetTupleElement(new_tuple10, 1);
238 auto matrix_from_01 = GetTupleElement(new_tuple01, 1);
239 auto matrix_from_10 = GetTupleElement(new_tuple10, 0);
240
241 auto addvectors = Add(vector_from_01, vector_from_10);
242 auto addmatrices = Add(matrix_from_01, matrix_from_10);
243
244 Add(addmatrices, addvectors,
245 /*broadcast_dimensions=*/{1});
246
247 Array2D<float> expected({
248 {4.f, 8.f, 12.f}, // row 0
249 {10.f, 14.f, 18.f}, // row 1
250 });
251 ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
252 }
253
// Builds ((f32[2], f32[]), f32[2]) from constants and compares it against a
// nested tuple literal assembled from the same values.
XLA_TEST_F(TupleTest, NestedTuples) {
  XlaBuilder builder(TestName());
  auto inner = Tuple(&builder, {ConstantR1<float>(&builder, {1.0, 2.0}),
                                ConstantR0<float>(&builder, 42.0)});
  auto trailing_vector = ConstantR1<float>(&builder, {22.0, 44.0});
  Tuple(&builder, {inner, trailing_vector});

  // MakeTuple copies from the pointed-to literals, so they are kept as locals.
  auto inner_vector = LiteralUtil::CreateR1<float>({1.0, 2.0});
  auto inner_scalar = LiteralUtil::CreateR0<float>(42.0);
  auto inner_literal = LiteralUtil::MakeTuple({&inner_vector, &inner_scalar});
  auto outer_vector = LiteralUtil::CreateR1<float>({22.0, 44.0});
  auto nested_literal = LiteralUtil::MakeTuple({&inner_literal, &outer_vector});

  ComputeAndCompareTuple(&builder, nested_literal, {}, error_spec_);
}
269
// Feeds a ((f32[3], f32[3]), f32[3]) parameter, reaches into the inner tuple's
// second element with two GetTupleElement ops, and adds a constant vector.
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
  XlaBuilder builder(TestName());

  const Shape element_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_shape =
      ShapeUtil::MakeTupleShape({element_shape, element_shape});
  const Shape outer_shape =
      ShapeUtil::MakeTupleShape({inner_shape, element_shape});

  auto param = Parameter(&builder, 0, outer_shape, "input");
  auto inner_tuple = GetTupleElement(param, 0);
  auto inner_second = GetTupleElement(inner_tuple, 1);
  Add(inner_second, ConstantR1<float>(&builder, {10.0, 11.0, 12.0}));

  // Argument value: ((1, 2, 3), (4, 5, 6)), (7, 8, 9).
  auto argument_literal = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::MakeTupleFromSlices({
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
          LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
      }),
      LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
  });
  std::unique_ptr<GlobalData> device_argument =
      client_->TransferToServer(argument_literal).value();

  std::vector<GlobalData*> arguments = {device_argument.get()};
  const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0};
  ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
}
298
// Runs complex64 data through nested tuple parameters and a nested tuple
// result:
//   sum  = input0{1}{0} (mapped onto dim 1) + input0{1}{1} + input0{0}
//   prod = input1 (mapped onto dim 1) * sum
//   out  = ((prod, sum), 123+456i)
XLA_TEST_F(TupleTest, ComplexTuples) {
  XlaBuilder builder(TestName());
  {
    const Shape scalar_c64 = ShapeUtil::MakeShape(C64, {});
    const Shape vector_c64 = ShapeUtil::MakeShape(C64, {2});
    const Shape matrix_c64 = ShapeUtil::MakeShape(C64, {3, 2});
    const Shape nested_shape = ShapeUtil::MakeTupleShape(
        {scalar_c64, ShapeUtil::MakeTupleShape({vector_c64, matrix_c64})});

    auto nested_param = Parameter(&builder, 0, nested_shape, "input0");
    auto scalar_part = GetTupleElement(nested_param, 0);
    auto inner_pair = GetTupleElement(nested_param, 1);
    auto vector_part = GetTupleElement(inner_pair, 0);
    auto matrix_part = GetTupleElement(inner_pair, 1);
    auto vector_plus_matrix = Add(vector_part, matrix_part, {1});
    auto total = Add(vector_plus_matrix, scalar_part);
    auto scale_param = Parameter(&builder, 1, vector_c64, "input1");
    auto scaled = Mul(scale_param, total, {1});
    auto result_pair = Tuple(&builder, {scaled, total});
    auto tail_constant = ConstantR0<complex64>(&builder, {123, 456});
    Tuple(&builder, {result_pair, tail_constant});
  }

  std::unique_ptr<GlobalData> arg0 =
      client_
          ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
              LiteralUtil::CreateR0<complex64>({1, 2}),
              LiteralUtil::MakeTupleFromSlices({
                  LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
                  LiteralUtil::CreateR2<complex64>(
                      {{{100, 200}, {300, 400}},
                       {{1000, 2000}, {3000, 4000}},
                       {{10000, 20000}, {30000, 40000}}}),
              }),
          }))
          .value();
  std::unique_ptr<GlobalData> arg1 =
      client_
          ->TransferToServer(
              LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
          .value();

  auto sum =
      LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
                                        {{1011, 2022}, {3031, 4042}},
                                        {{10011, 20022}, {30031, 40042}}});

  // prod[i][j] = sum[i][j] * input1[j], where input1 = {1+2i, 1-2i}.
  Literal prod(sum.shape());
  const auto populate_status =
      prod.Populate<complex64>([&sum](absl::Span<const int64_t> index) {
        const complex64 column_scale =
            index.back() == 0 ? complex64(1, 2) : complex64(1, -2);
        return sum.Get<complex64>(index) * column_scale;
      });
  ASSERT_TRUE(populate_status.ok());

  auto expected = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::MakeTupleFromSlices({prod, sum}),
      LiteralUtil::CreateR0<complex64>({123, 456}),
  });
  ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
                         error_spec_);
}
354
355 class TupleHloTest : public HloTestBase {};
356
// A tuple declared as (f32[3], f32[3]) but built from a single operand must be
// rejected by the HLO verifier with a shape-mismatch message.
XLA_TEST_F(TupleHloTest, BadTupleShapeFailsGracefully) {
  const char* hlo_text = R"(
HloModule m, is_scheduled=true

ENTRY test {
parameter = f32[3]{0} parameter(0)
ROOT tuple = (f32[3]{0}, f32[3]{0}) tuple(parameter)
}
)";

  // Parse without verification so that the verifier below sees the bad shape.
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_FALSE(verify_status.ok());

  const auto& diagnostic = verify_status.error_message();
  EXPECT_THAT(diagnostic, ::testing::HasSubstr(
                              "Expected instruction to have shape equal to"));
  EXPECT_THAT(diagnostic, ::testing::HasSubstr("actual shape is"));
}
376
// Runs, without HLO passes, a module that takes the f32[3] element out of a
// one-element tuple parameter, bitcasts it to f32[1,3], copies it, and returns
// it wrapped in a tuple. The values must pass through unchanged:
// ({1, 2, 3}) -> ({{1, 2, 3}}).
XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
  const char* testcase = R"(
HloModule m, is_scheduled=true

ENTRY test {
name.1 = (f32[3]{0}) parameter(0)
get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0
bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1)
copy = f32[1,3]{1,0} copy(bitcast)
ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
}
)";
  // Fail the test (rather than crash) if the module does not parse/verify.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(testcase));
  auto param =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
  // Arguments are passed as pointers to the literals.
  auto result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
      result));
}
397
398 } // namespace
399 } // namespace xla
400