xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/tuple_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <initializer_list>
17 #include <memory>
18 
19 #include "tensorflow/compiler/xla/array2d.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_parser.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/tests/test_macros.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace xla {
37 namespace {
38 
39 class TupleTest : public ClientLibraryTestBase {
40  public:
41   ErrorSpec error_spec_{0.0001};
42 };
43 
44 // Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest,TupleConstant)45 XLA_TEST_F(TupleTest, TupleConstant) {
46   XlaBuilder builder(TestName());
47 
48   const float constant_scalar = 7.3f;
49   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
50   std::initializer_list<std::initializer_list<float>> constant_matrix = {
51       {1.1f, 2.2f, 3.5f},  // row 0
52       {4.8f, 5.0f, 6.7f},  // row 1
53   };
54   auto value = LiteralUtil::MakeTupleFromSlices(
55       {LiteralUtil::CreateR0<float>(constant_scalar),
56        LiteralUtil::CreateR1<float>(constant_vector),
57        LiteralUtil::CreateR2<float>(constant_matrix)});
58 
59   ConstantLiteral(&builder, value);
60   ComputeAndCompareTuple(&builder, value, {}, error_spec_);
61 }
62 
63 // Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest,TupleScalarConstant)64 XLA_TEST_F(TupleTest, TupleScalarConstant) {
65   XlaBuilder builder(TestName());
66 
67   const float constant_scalar1 = 7.3f;
68   const float constant_scalar2 = 1.2f;
69   auto value = LiteralUtil::MakeTupleFromSlices(
70       {LiteralUtil::CreateR0<float>(constant_scalar1),
71        LiteralUtil::CreateR0<float>(constant_scalar2)});
72 
73   ConstantLiteral(&builder, value);
74   ComputeAndCompareTuple(&builder, value, {}, error_spec_);
75 }
76 
77 // Tests the creation of tuple data.
XLA_TEST_F(TupleTest,TupleCreate)78 XLA_TEST_F(TupleTest, TupleCreate) {
79   XlaBuilder builder(TestName());
80 
81   const float constant_scalar = 7.3f;
82   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
83   std::initializer_list<std::initializer_list<float>> constant_matrix = {
84       {1.1f, 2.2f, 3.5f},  // row 0
85       {4.8f, 5.0f, 6.7f},  // row 1
86   };
87   Tuple(&builder, {ConstantR0<float>(&builder, constant_scalar),
88                    ConstantR1<float>(&builder, constant_vector),
89                    ConstantR2<float>(&builder, constant_matrix)});
90 
91   auto expected = LiteralUtil::MakeTupleFromSlices(
92       {LiteralUtil::CreateR0<float>(constant_scalar),
93        LiteralUtil::CreateR1<float>(constant_vector),
94        LiteralUtil::CreateR2<float>(constant_matrix)});
95   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
96 }
97 
98 // Tests the creation of tuple data.
XLA_TEST_F(TupleTest,TupleCreateWithZeroElementEntry)99 XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
100   XlaBuilder builder(TestName());
101 
102   Tuple(&builder,
103         {ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
104 
105   auto expected = LiteralUtil::MakeTupleFromSlices(
106       {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
107   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
108 }
109 
110 // Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest,EmptyTupleCreate)111 XLA_TEST_F(TupleTest, EmptyTupleCreate) {
112   XlaBuilder builder(TestName());
113   Tuple(&builder, {});
114   auto expected = LiteralUtil::MakeTuple({});
115   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
116 }
117 
118 // Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest,GetTupleElement)119 XLA_TEST_F(TupleTest, GetTupleElement) {
120   XlaBuilder builder(TestName());
121   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
122   std::initializer_list<std::initializer_list<float>> constant_matrix = {
123       {1.f, 2.f, 3.f},  // row 0
124       {4.f, 5.f, 6.f},  // row 1
125   };
126   auto tuple_data =
127       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
128                        ConstantR2<float>(&builder, constant_matrix)});
129   GetTupleElement(tuple_data, 1);
130   ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
131                              error_spec_);
132 }
133 
134 // Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest,GetTupleElementWithZeroElements)135 XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
136   XlaBuilder builder(TestName());
137   auto tuple_data =
138       Tuple(&builder,
139             {ConstantR1<float>(&builder, {}),
140              ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 101))});
141   GetTupleElement(tuple_data, 1);
142   ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
143 }
144 
// Applying GetTupleElement to a plain f32[1] array must surface as an error
// from XlaBuilder::Build() carrying a descriptive message.
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
  XlaBuilder builder(TestName());
  auto not_a_tuple = ConstantR1<float>(&builder, {4.5f});
  GetTupleElement(not_a_tuple, 1);

  auto computation_or = builder.Build();
  EXPECT_FALSE(computation_or.ok());
  EXPECT_THAT(
      computation_or.status().error_message(),
      ::testing::HasSubstr("Operand to GetTupleElement() is not a tuple"));
}
155 
156 // Extracts both elements from a tuple with GetTupleElement and then adds them
157 // together.
XLA_TEST_F(TupleTest,AddTupleElements)158 XLA_TEST_F(TupleTest, AddTupleElements) {
159   XlaBuilder builder(TestName());
160   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
161   std::initializer_list<std::initializer_list<float>> constant_matrix = {
162       {1.f, 2.f, 3.f},  // row 0
163       {4.f, 5.f, 6.f},  // row 1
164   };
165   auto tuple_data =
166       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
167                        ConstantR2<float>(&builder, constant_matrix)});
168   auto vector_element = GetTupleElement(tuple_data, 0);
169   auto matrix_element = GetTupleElement(tuple_data, 1);
170   auto vector_shape = builder.GetShape(vector_element).value();
171   auto matrix_shape = builder.GetShape(matrix_element).value();
172   Add(matrix_element, vector_element,
173       /*broadcast_dimensions=*/{1});
174 
175   Array2D<float> expected({
176       {2.f, 4.f, 6.f},  // row 0
177       {5.f, 7.f, 9.f},  // row 1
178   });
179   ASSERT_TRUE(ShapeUtil::Equal(vector_shape, ShapeUtil::MakeShape(F32, {3})));
180   ASSERT_TRUE(ShapeUtil::Equal(matrix_shape,
181                                ShapeUtil::MakeShape(F32, {/*y=*/2, /*x=*/3})));
182   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
183 }
184 
185 // Extracts both elements from a tuple and then puts them into a new tuple in
186 // the opposite order.
XLA_TEST_F(TupleTest,TupleGTEToTuple)187 XLA_TEST_F(TupleTest, TupleGTEToTuple) {
188   XlaBuilder builder(TestName());
189   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
190   std::initializer_list<std::initializer_list<float>> constant_matrix = {
191       {1.f, 2.f, 3.f},  // row 0
192       {4.f, 5.f, 6.f},  // row 1
193   };
194   auto tuple_data =
195       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
196                        ConstantR2<float>(&builder, constant_matrix)});
197   Tuple(&builder,
198         {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
199   auto expected = LiteralUtil::MakeTupleFromSlices(
200       {LiteralUtil::CreateR2<float>(constant_matrix),
201        LiteralUtil::CreateR1<float>(constant_vector)});
202   ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
203 }
204 
205 
206 // Builds two new tuples from an existing tuple (by means of GetTupleElement),
207 // then adds up the components of the new tuples.
XLA_TEST_F(TupleTest,TupleGTEToTupleToGTEAdd)208 XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
209   //
210   // v------           --(GTE 0)--             --(GTE 0)----------
211   //        \         /           \           /                   \
212   //         (tuple)--             (tuple01)--                     \
213   //        /   |     \           /           \                     \
214   // m------    |      --(GTE 1)--             --(GTE 1)------------ \
215   //            |                                                   \ \
216   //            |                                                    (add)
217   //            |                                                   / /
218   //            |--------(GTE 1)--             --(GTE 0)------------ /
219   //             \                \           /                     /
220   //              \                (tuple10)--                     /
221   //               \              /           \                   /
222   //                -----(GTE 0)--             --(GTE 1)----------
223   XlaBuilder builder(TestName());
224   std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
225   std::initializer_list<std::initializer_list<float>> constant_matrix = {
226       {1.f, 2.f, 3.f},  // row 0
227       {4.f, 5.f, 6.f},  // row 1
228   };
229   auto tuple_data =
230       Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
231                        ConstantR2<float>(&builder, constant_matrix)});
232   auto new_tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0),
233                                       GetTupleElement(tuple_data, 1)});
234   auto new_tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1),
235                                       GetTupleElement(tuple_data, 0)});
236   auto vector_from_01 = GetTupleElement(new_tuple01, 0);
237   auto vector_from_10 = GetTupleElement(new_tuple10, 1);
238   auto matrix_from_01 = GetTupleElement(new_tuple01, 1);
239   auto matrix_from_10 = GetTupleElement(new_tuple10, 0);
240 
241   auto addvectors = Add(vector_from_01, vector_from_10);
242   auto addmatrices = Add(matrix_from_01, matrix_from_10);
243 
244   Add(addmatrices, addvectors,
245       /*broadcast_dimensions=*/{1});
246 
247   Array2D<float> expected({
248       {4.f, 8.f, 12.f},    // row 0
249       {10.f, 14.f, 18.f},  // row 1
250   });
251   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
252 }
253 
// Builds ((f32[2], f32[]), f32[2]) from nested Tuple ops and compares it with
// the same nesting assembled from literals.
XLA_TEST_F(TupleTest, NestedTuples) {
  XlaBuilder builder(TestName());
  auto inner = Tuple(&builder, {ConstantR1<float>(&builder, {1.0, 2.0}),
                                ConstantR0<float>(&builder, 42.0)});
  Tuple(&builder, {inner, ConstantR1<float>(&builder, {22.0, 44.0})});

  // Element literals must outlive the MakeTuple calls that point at them.
  auto inner_vector = LiteralUtil::CreateR1<float>({1.0, 2.0});
  auto inner_scalar = LiteralUtil::CreateR0<float>(42.0);
  auto outer_vector = LiteralUtil::CreateR1<float>({22.0, 44.0});
  auto inner_literal = LiteralUtil::MakeTuple({&inner_vector, &inner_scalar});
  auto nested_literal = LiteralUtil::MakeTuple({&inner_literal, &outer_vector});

  ComputeAndCompareTuple(&builder, nested_literal, {}, error_spec_);
}
269 
// Reads element 1 of element 0 of a ((f32[3], f32[3]), f32[3]) parameter and
// adds a constant to it, checking GetTupleElement through two nesting levels.
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
  XlaBuilder builder(TestName());

  Shape leaf = ShapeUtil::MakeShape(F32, {3});
  Shape inner_shape = ShapeUtil::MakeTupleShape({leaf, leaf});
  Shape outer_shape = ShapeUtil::MakeTupleShape({inner_shape, leaf});

  auto input = Parameter(&builder, 0, outer_shape, "input");
  auto inner = GetTupleElement(input, 0);
  auto second_of_inner = GetTupleElement(inner, 1);
  Add(second_of_inner, ConstantR1<float>(&builder, {10.0, 11.0, 12.0}));

  auto argument_literal = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::MakeTupleFromSlices({
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
          LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
      }),
      LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
  });
  std::unique_ptr<GlobalData> data =
      client_->TransferToServer(argument_literal).value();

  std::vector<GlobalData*> arguments = {data.get()};
  // {4, 5, 6} is the selected leaf; {10, 11, 12} is the added constant.
  const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0};
  ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
}
298 
// Exercises tuples of c64 values: unpacks a nested tuple parameter, sums its
// elements with broadcasting, multiplies by a second parameter, and returns
// ((prod, sum), constant).
XLA_TEST_F(TupleTest, ComplexTuples) {
  XlaBuilder builder(TestName());
  {
    Shape c64_scalar = ShapeUtil::MakeShape(C64, {});
    Shape c64_vector = ShapeUtil::MakeShape(C64, {2});
    Shape c64_matrix = ShapeUtil::MakeShape(C64, {3, 2});
    Shape arg0_shape = ShapeUtil::MakeTupleShape(
        {c64_scalar, ShapeUtil::MakeTupleShape({c64_vector, c64_matrix})});
    auto input0 = Parameter(&builder, 0, arg0_shape, "input0");
    auto scalar_op = GetTupleElement(input0, 0);
    auto pair_op = GetTupleElement(input0, 1);
    auto vector_op = GetTupleElement(pair_op, 0);
    auto matrix_op = GetTupleElement(pair_op, 1);
    // vector broadcast along dim 1 of the matrix, then the scalar added.
    auto sum_op = Add(Add(vector_op, matrix_op, {1}), scalar_op);
    auto input1 = Parameter(&builder, 1, c64_vector, "input1");
    auto prod_op = Mul(input1, sum_op, {1});
    auto pair_result = Tuple(&builder, {prod_op, sum_op});
    auto constant_op = ConstantR0<complex64>(&builder, {123, 456});
    Tuple(&builder, {pair_result, constant_op});
  }

  auto arg0_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<complex64>({1, 2}),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
            LiteralUtil::CreateR2<complex64>(
                {{{100, 200}, {300, 400}},
                 {{1000, 2000}, {3000, 4000}},
                 {{10000, 20000}, {30000, 40000}}})})});
  std::unique_ptr<GlobalData> arg0 =
      client_->TransferToServer(arg0_literal).value();
  auto arg1_literal = LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}});
  std::unique_ptr<GlobalData> arg1 =
      client_->TransferToServer(arg1_literal).value();

  auto sum =
      LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
                                        {{1011, 2022}, {3031, 4042}},
                                        {{10011, 20022}, {30031, 40042}}});
  // input1 = {1+2i, 1-2i} is broadcast along dimension 1, so column 0 of the
  // sum is scaled by (1+2i) and column 1 by (1-2i).
  auto scaled_entry = [&sum](absl::Span<const int64_t> index) {
    const complex64 factor =
        index.back() == 0 ? complex64(1, 2) : complex64(1, -2);
    return sum.Get<complex64>(index) * factor;
  };
  Literal prod(sum.shape());
  ASSERT_TRUE(prod.Populate<complex64>(scaled_entry).ok());

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices({prod, sum}),
       LiteralUtil::CreateR0<complex64>({123, 456})});
  ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
                         error_spec_);
}
354 
355 class TupleHloTest : public HloTestBase {};
356 
// A root tuple with a single operand but a declared two-element shape parses
// (unverified), and the HLO verifier must then reject it with a shape message.
XLA_TEST_F(TupleHloTest, BadTupleShapeFailsGracefully) {
  const char* hlo_text = R"(
    HloModule m, is_scheduled=true

    ENTRY test {
      parameter = f32[3]{0} parameter(0)
      ROOT tuple = (f32[3]{0}, f32[3]{0}) tuple(parameter)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_text));
  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_FALSE(verify_status.ok());
  const auto& message = verify_status.error_message();
  EXPECT_THAT(message,
              ::testing::HasSubstr("Expected instruction to have shape equal to"));
  EXPECT_THAT(message, ::testing::HasSubstr("actual shape is"));
}
376 
// Executes, without HLO passes, a module that bitcasts a get-tuple-element
// result from f32[3] to f32[1,3] and copies it, checking the values survive.
XLA_TEST_F(TupleHloTest, BitcastAfterGTE) {
  const char* testcase = R"(
    HloModule m, is_scheduled=true

    ENTRY test {
      name.1 = (f32[3]{0}) parameter(0)
      get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0
      bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1)
      copy = f32[1,3]{1,0} copy(bitcast)
      ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy)
    }
  )";
  // Report a parse/verification error as a test failure (like the sibling
  // test) instead of aborting the whole binary through ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(testcase));
  auto param =
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
  auto result = ExecuteNoHloPasses(std::move(module), {&param});
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
      result));
}
397 
398 }  // namespace
399 }  // namespace xla
400