1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/literal.h"
17
18 #include <limits>
19 #include <memory>
20 #include <vector>
21
22 #include "absl/base/casts.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33
34 namespace xla {
35 namespace {
36
37 using ::testing::ElementsAre;
38 using ::testing::HasSubstr;
39
40 class LiteralUtilTest : public ::testing::Test {
41 protected:
LiteralUtilTest()42 LiteralUtilTest() {
43 Array4D<float> arr4d({
44 // clang-format off
45 { // i0=0
46 { // i1=0
47 {1, 2, 3}, // i2=0
48 {4, 5, 6}, // i2=1
49 {7, 8, 9}, // i2=2
50 },
51 { // i1=1
52 {11, 12, 13},
53 {14, 15, 16},
54 {17, 18, 19},
55 },
56 },
57 { // i0=1
58 { // i1=0
59 {101, 102, 103},
60 {104, 105, 106},
61 {107, 108, 109},
62 },
63 { // i1=1
64 {201, 202, 203}, // i2=0
65 {204, 205, 206}, // i2=1
66 {207, 208, 209}, // i2=2
67 },
68 },
69 // clang-format on
70 });
71
72 layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
73 layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
74 layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
75 layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
76 layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
77 layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
78
79 literal_r4_2x2x3x3_dim0major_ =
80 LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
81 layout_r4_dim0major_);
82 literal_r4_2x2x3x3_dim0minor_ =
83 LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
84 layout_r4_dim0minor_);
85 }
86
87 Layout layout_r2_dim0major_;
88 Layout layout_r2_dim0minor_;
89 Layout layout_r3_dim0major_;
90 Layout layout_r3_dim0minor_;
91 Layout layout_r4_dim0major_;
92 Layout layout_r4_dim0minor_;
93 Literal literal_r4_2x2x3x3_dim0major_;
94 Literal literal_r4_2x2x3x3_dim0minor_;
95 };
96
// Textual form of rank-0 literals: "<type>[] <value>". PRED prints as
// true/false, complex values print as "(real, imag)".
TEST_F(LiteralUtilTest, LiteralScalarToString) {
  auto true_lit = LiteralUtil::CreateR0<bool>(true);
  EXPECT_EQ("pred[] true", true_lit.ToString());

  auto false_lit = LiteralUtil::CreateR0<bool>(false);
  EXPECT_EQ("pred[] false", false_lit.ToString());

  auto u32_lit = LiteralUtil::CreateR0<uint32_t>(42);
  EXPECT_EQ("u32[] 42", u32_lit.ToString());

  auto s32_lit = LiteralUtil::CreateR0<int32_t>(-999);
  EXPECT_EQ("s32[] -999", s32_lit.ToString());

  auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
  EXPECT_EQ("f32[] 3.14", f32_lit.ToString());

  auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
  EXPECT_EQ("f16[] 0.5", f16_lit.ToString());

  auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
  EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());

  auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14, 2.78});
  EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());

  auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
  EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());

  // 3.14 will be rounded to 3.140625 in bfloat16 format, which is printed as
  // "3.141".
  auto bf16_lit_truncated =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
  // Non-fatal, like every other check here, so a mismatch does not hide the
  // independent bfloat16 check below.
  EXPECT_EQ("bf16[] 3.141", bf16_lit_truncated.ToString());

  // 9.001 is not representable in bfloat16 and rounds to exactly 9.
  auto bf16_lit_truncated2 =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
  EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString());
}
134
// Inside a PRED vector, elements print as 1/0 rather than true/false.
TEST_F(LiteralUtilTest, LiteralVectorToString) {
  const auto flags = LiteralUtil::CreateR1<bool>({true, false, true});
  const std::string want = "pred[3] {1, 0, 1}";
  EXPECT_EQ(want, flags.ToString());
}
139
// A rank-2 literal prints one row per line inside an outer brace pair.
TEST_F(LiteralUtilTest, R2ToString) {
  const auto matrix = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
  const std::string want = R"(s32[3,2] {
{ 1, 2 },
{ 3, 4 },
{ 5, 6 }
})";
  EXPECT_EQ(want, matrix.ToString());
}
149
// Dynamic literals print their bound ("<=N") and runtime sizes, and only the
// in-bounds elements.
TEST_F(LiteralUtilTest, R2DynamicToString) {
  // Case 1: the major dimension is dynamic; only the first two rows show.
  auto dynamic_rows = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
  dynamic_rows.SetDynamicSize(0, {}, 2);
  const std::string want_rows = R"(s32[<=3,2](2,2) {
{ 1, 2 },
{ 3, 4 }
})";
  EXPECT_EQ(want_rows, dynamic_rows.ToString());

  // Case 2: the minor dimension is dynamic, so the visible elements are not
  // consecutive in memory and the trailing column of each row is skipped.
  auto dynamic_cols = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}});
  dynamic_cols.SetDynamicSize(1, {}, 2);
  const std::string want_cols = R"(s32[2,<=3](2,2) {
{ 1, 2 },
{ 4, 5 }
})";
  EXPECT_EQ(want_cols, dynamic_cols.ToString());
}
168
// Textual form of a rank-3 S32 literal with a size-1 minor dimension.
TEST_F(LiteralUtilTest, R3ToString) {
  const auto cube =
      LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
  const std::string want = R"(s32[3,2,1] {
{
{1},
{2}
},
{
{3},
{4}
},
{
{5},
{6}
}
})";
  EXPECT_EQ(want, cube.ToString());
}
188
// High-rank literals annotate each outer brace with an /*iK=...*/ index
// comment. Every element in the expected text is 0.
TEST_F(LiteralUtilTest, R6ToString) {
  const auto rank6 =
      LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2});
  const std::string want = R"(s32[2,2,1,1,1,2] {
{ /*i0=0*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
{ 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
{ 0, 0 }
}
}
}
},
{ /*i0=1*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
{ 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
{ 0, 0 }
}
}
}
}
})";
  EXPECT_EQ(want, rank6.ToString());
}
228
// A tuple prints its elements comma-separated inside parentheses.
TEST_F(LiteralUtilTest, TupleToString) {
  auto one = LiteralUtil::CreateR0<float>(1.0);
  auto two_by_two = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto pair = LiteralUtil::MakeTuple({&one, &two_by_two});
  const std::string want = R"((
f32[] 1,
f32[2,2] {
{ 1, 2 },
{ 3, 4 }
}
))";
  EXPECT_EQ(want, pair.ToString());
}
242
// CreateR3FromArray3D keeps the Array3D's dimensions and element order.
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
  // clang-format off
  Array3D<float> source({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on

  auto cube = LiteralUtil::CreateR3FromArray3D(source);
  EXPECT_THAT(cube.shape().dimensions(), ElementsAre(2, 3, 2));

  const std::string want = R"(f32[2,3,2] {
{
{ 1, 2 },
{ 3, 4 },
{ 5, 6 }
},
{
{ 7, 8 },
{ 9, 10 },
{ 11, 12 }
}
})";
  const std::string got = cube.ToString();
  EXPECT_EQ(want, got);
}
272
// CreateR4Projected with p=1, z=2 gives shape 1x2x3x2, and the 3x2 input
// matrix appears under both i1 indices.
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
  // clang-format off
  auto projected = LiteralUtil::CreateR4Projected<float>({
    {1, 2},
    {1001, 1002},
    {2001, 2002},
  }, /*projection_p=*/1, /*projection_z=*/2);
  // clang-format on
  EXPECT_THAT(projected.shape().dimensions(), ElementsAre(1, 2, 3, 2));

  const std::string want = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
{ 1, 2 },
{ 1001, 1002 },
{ 2001, 2002 }
},
{ /*i1=1*/
{ 1, 2 },
{ 1001, 1002 },
{ 2001, 2002 }
}
}
})";
  const std::string got = projected.ToString();
  EXPECT_EQ(want, got);
}
299
// Prints the fixture's dim0-major 2x2x3x3 literal in logical index order.
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
  const auto& subject = literal_r4_2x2x3x3_dim0major_;
  EXPECT_THAT(subject.shape().dimensions(), ElementsAre(2, 2, 3, 3));

  const std::string want = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
{ 1, 2, 3 },
{ 4, 5, 6 },
{ 7, 8, 9 }
},
{ /*i1=1*/
{ 11, 12, 13 },
{ 14, 15, 16 },
{ 17, 18, 19 }
}
},
{ /*i0=1*/
{ /*i1=0*/
{ 101, 102, 103 },
{ 104, 105, 106 },
{ 107, 108, 109 }
},
{ /*i1=1*/
{ 201, 202, 203 },
{ 204, 205, 206 },
{ 207, 208, 209 }
}
}
})";
  const std::string got = subject.ToString();
  EXPECT_EQ(want, got);
}
332
// EachCellAsString visits every cell of a 2x2 literal, row by row, passing
// the multi-index and the formatted value.
TEST_F(LiteralUtilTest, EachCellR2F32) {
  using Cell = std::tuple<int64_t, int64_t, std::string>;

  // clang-format off
  auto literal = LiteralUtil::CreateR2<float>({
    {3.1f, 4.2f},
    {9.3f, 12.4f},
  });
  // clang-format on

  // Record (row, column, text) for every visited cell.
  std::vector<Cell> visited;
  literal.EachCellAsString(
      [&visited](absl::Span<const int64_t> index, const std::string& text) {
        visited.emplace_back(index[0], index[1], text);
      });

  const std::vector<Cell> want = {Cell(0, 0, "3.1"), Cell(0, 1, "4.2"),
                                  Cell(1, 0, "9.3"), Cell(1, 1, "12.4")};
  EXPECT_EQ(want, visited);
}
351
// Scalar equality depends on both value and element type.
TEST_F(LiteralUtilTest, ScalarEquality) {
  const auto forty_two = LiteralUtil::CreateR0<float>(42.0);
  const auto forty_two_again = LiteralUtil::CreateR0<float>(42.0);
  const auto one_twenty_three = LiteralUtil::CreateR0<float>(123.0);
  const auto forty_two_f64 = LiteralUtil::CreateR0<double>(42.0);

  // Reflexive, and equal to an independently built copy.
  EXPECT_EQ(forty_two, forty_two);
  EXPECT_EQ(forty_two, forty_two_again);

  // A different value, or the same value as F64, is not equal.
  EXPECT_NE(forty_two, one_twenty_three);
  EXPECT_NE(forty_two, forty_two_f64);
}
366
// Equality of arrays depends on shape and contents. The nil (empty tuple)
// literal equals only itself.
TEST_F(LiteralUtilTest, NonScalarEquality) {
  auto m = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto m_copy = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto m_shuffled = LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
  auto flat = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
  auto single = LiteralUtil::CreateR0<float>(1.0);
  Literal empty_tuple(ShapeUtil::MakeNil());

  EXPECT_EQ(m, m);
  EXPECT_EQ(m, m_copy);

  // Different contents, or different shapes holding overlapping values.
  EXPECT_NE(m, m_shuffled);
  EXPECT_NE(m, flat);
  EXPECT_NE(m, single);
  EXPECT_NE(m, empty_tuple);

  EXPECT_EQ(empty_tuple, empty_tuple);
}
385
// Any two tokens are equal, a token never equals an array, and tuples
// containing tokens compare position by position.
TEST_F(LiteralUtilTest, TokenEquality) {
  auto first_token = LiteralUtil::CreateToken();
  auto second_token = LiteralUtil::CreateToken();
  auto one = LiteralUtil::CreateR0<float>(1.0);

  EXPECT_EQ(first_token, second_token);
  EXPECT_NE(first_token, one);

  EXPECT_EQ(LiteralUtil::MakeTuple({&first_token}),
            LiteralUtil::MakeTuple({&first_token}));
  EXPECT_EQ(LiteralUtil::MakeTuple({&first_token, &one}),
            LiteralUtil::MakeTuple({&second_token, &one}));
  // Same elements, swapped positions.
  EXPECT_NE(LiteralUtil::MakeTuple({&first_token, &one}),
            LiteralUtil::MakeTuple({&one, &second_token}));
}
401
// Equality compares logical contents, so the same matrix stored column-major
// and row-major is equal.
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
  // Writes the logical matrix {{1, 2}, {3, 4}} regardless of layout.
  auto fill = [](Literal& target) {
    target.Set<float>({0, 0}, 1.0);
    target.Set<float>({0, 1}, 2.0);
    target.Set<float>({1, 0}, 3.0);
    target.Set<float>({1, 1}, 4.0);
  };

  Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
  fill(colmajor);

  Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
  fill(rowmajor);

  EXPECT_EQ(rowmajor, colmajor);
}
418
// Tuples are equal when their elements are equal in order.
TEST_F(LiteralUtilTest, TupleEquality) {
  auto one = LiteralUtil::CreateR0<float>(1.0);
  auto m = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto base = LiteralUtil::MakeTuple({&one, &m});

  // The matrix element is shared with `base`; the scalar is rebuilt.
  auto one_again = LiteralUtil::CreateR0<float>(1.0);
  auto same = LiteralUtil::MakeTuple({&one_again, &m});
  EXPECT_EQ(base, same);

  // Same elements, opposite order.
  auto swapped = LiteralUtil::MakeTuple({&m, &one});
  EXPECT_NE(base, swapped);

  // The first element holds a different value.
  auto forty_two = LiteralUtil::CreateR0<float>(42.0);
  auto changed = LiteralUtil::MakeTuple({&forty_two, &m});
  EXPECT_NE(base, changed);
}
440
// Equality of tuples whose elements have dynamic dimensions. Data outside
// the dynamic size is ignored, but changing a dynamic size makes the
// tuples unequal.
TEST_F(LiteralUtilTest, DynamicShapeEquality) {
  // r1 = [1, 2] with dynamic size 1; r2 = 2x2 with a dynamic row count of 1.
  auto r1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
  r1.SetDynamicSize(0, {}, 1);
  auto r2 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  r2.SetDynamicSize(0, {}, 1);
  auto tuple1 = LiteralUtil::MakeTuple({&r1, &r2});

  // r1_clone differs from r1 only in its second element, which lies beyond
  // the dynamic size of 1, so the tuples still compare equal. r2 is shared.
  auto r1_clone = LiteralUtil::CreateR1<float>({1.0, 3.0});
  r1_clone.SetDynamicSize(0, {}, 1);
  auto tuple2 = LiteralUtil::MakeTuple({&r1_clone, &r2});
  EXPECT_EQ(tuple1, tuple2);

  // Same data as r2 but with a dynamic row count of 2 instead of 1.
  auto r2_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  r2_clone.SetDynamicSize(0, {}, 2);
  auto tuple3 = LiteralUtil::MakeTuple({&r1_clone, &r2_clone});
  EXPECT_NE(tuple1, tuple3);
}
462
// Equality of rank-1 complex64 literals. Identical element sequences are
// equal; the same elements in a different order are not.
TEST_F(LiteralUtilTest, C64Equality) {
  auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});

  // An independently constructed vector holding the same values.
  auto vector_clone =
      LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same two elements in reverse order.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
477
// Equality of rank-1 complex128 literals. Identical element sequences are
// equal; the same elements in a different order are not.
TEST_F(LiteralUtilTest, C128Equality) {
  auto vector = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});

  // An independently constructed vector holding the same values.
  auto vector_clone =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same two elements in reverse order.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
492
// Tuples should always return false for IsAll, even when every leaf of the
// tuple holds the queried value.
TEST_F(LiteralUtilTest, IsAllTuple) {
  auto element1 = LiteralUtil::CreateR0<float>(0.0);
  auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
  // A scalar and a matrix, both all zeros. An element-wise reading of
  // IsAll(0) would therefore be true; the tuple itself must still say false.
  auto tuple = LiteralUtil::MakeTuple({&element1, &element2});

  EXPECT_FALSE(tuple.IsAll(0));
  EXPECT_FALSE(tuple.IsAll(1));
}
502
503 // Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest,CreateFromShapeTuple)504 TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
505 auto scalar = LiteralUtil::CreateR0<float>(0.0);
506 auto matrix = LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}});
507 auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
508
509 auto x = Literal::CreateFromShape(tuple.shape());
510 EXPECT_EQ(tuple, x);
511 }
512
// IsAll(v) holds only when every element of an array literal equals v, across
// PRED, integer, floating-point and complex element types.
TEST_F(LiteralUtilTest, IsAll) {
  // Scalar PRED literals.
  const auto pred = [](bool value) {
    return LiteralUtil::CreateR0<bool>(value);
  };
  EXPECT_TRUE(pred(false).IsAll(0));
  EXPECT_TRUE(pred(true).IsAll(1));
  EXPECT_FALSE(pred(false).IsAll(1));
  EXPECT_FALSE(pred(false).IsAll(2));
  EXPECT_FALSE(pred(true).IsAll(0));
  EXPECT_FALSE(pred(true).IsAll(2));
  EXPECT_FALSE(pred(true).IsAll(-1));

  // A U8 literal holding 255 must not match int8_t's minimum just because
  // of a signed/unsigned reinterpretation.
  const auto int8_min = std::numeric_limits<int8_t>::min();
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8_t>(255).IsAll(int8_min));

  // Floating-point values must match exactly.
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));

  EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
  EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));

  // A single mismatching element anywhere breaks the property.
  EXPECT_TRUE(LiteralUtil::CreateR2<uint64_t>({{8, 8}, {8, 8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64_t>({{8, 8}, {8, 9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64_t>({{9, 8}, {8, 8}}).IsAll(8));

  // Half-precision types.
  half eight_f16(8.0f);
  half nine_f16(9.0f);
  EXPECT_TRUE(LiteralUtil::CreateR2<half>({{eight_f16}, {eight_f16}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{eight_f16}, {nine_f16}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{nine_f16}, {eight_f16}}).IsAll(8));

  bfloat16 eight_bf16(8.0f);
  bfloat16 nine_bf16(9.0f);
  EXPECT_TRUE(
      LiteralUtil::CreateR2<bfloat16>({{eight_bf16}, {eight_bf16}}).IsAll(8));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<bfloat16>({{eight_bf16}, {nine_bf16}}).IsAll(8));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<bfloat16>({{nine_bf16}, {eight_bf16}}).IsAll(8));

  // 9.001 will be truncated to 9.0 in bfloat16.
  bfloat16 nine_point_001(9.001f);
  bfloat16 nine_point_0(9.00f);
  EXPECT_TRUE(
      LiteralUtil::CreateR2<bfloat16>({{nine_point_001}, {nine_point_0}})
          .IsAll(9.0));

  // A complex value with a nonzero imaginary part never matches a real value.
  complex64 eight_plus_nine_i = {8, 9};
  EXPECT_FALSE(
      LiteralUtil::CreateR2<complex64>({{eight_plus_nine_i}, {eight_plus_nine_i}})
          .IsAll(8));

  // uint64 max is not considered equal to -1.
  const auto u64_max = std::numeric_limits<uint64_t>::max();
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64_t>(
                   {{u64_max, u64_max}, {u64_max, u64_max}})
                   .IsAll(-1));
}
563
// IsAllFloat matches only floating-point literals whose every element equals
// the given value.
TEST_F(LiteralUtilTest, IsAllFloat) {
  // Non-floating-point literals never satisfy IsAllFloat.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8_t>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8_t>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));

  // F32.
  const auto f32 = [](float v) { return LiteralUtil::CreateR0<float>(v); };
  EXPECT_TRUE(f32(0).IsAllFloat(0));
  EXPECT_TRUE(f32(.5).IsAllFloat(.5));
  EXPECT_TRUE(f32(-.5).IsAllFloat(-.5));
  EXPECT_FALSE(f32(-.5).IsAllFloat(-.49));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
  EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
                  .IsAllFloat(.5));

  // F64.
  const auto f64 = [](double v) { return LiteralUtil::CreateR0<double>(v); };
  EXPECT_TRUE(f64(0).IsAllFloat(0));
  EXPECT_TRUE(f64(.5).IsAllFloat(.5));
  EXPECT_TRUE(f64(-.5).IsAllFloat(-.5));
  EXPECT_FALSE(f64(-.5).IsAllFloat(-.49));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
}
587
// IsAllComplex matches only complex literals whose every element equals the
// given complex value.
TEST_F(LiteralUtilTest, IsAllComplex) {
  // Non-complex literals never satisfy IsAllComplex.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8_t>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8_t>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));

  const complex64 target = {8.0f, 9.0f};
  complex64 match = {8, 9};
  complex64 other = {7, 9};
  EXPECT_TRUE(
      LiteralUtil::CreateR2<complex64>({{match}, {match}}).IsAllComplex(target));
  // One mismatching element, in either position, breaks the property.
  EXPECT_FALSE(
      LiteralUtil::CreateR2<complex64>({{other}, {match}}).IsAllComplex(target));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<complex64>({{match}, {other}}).IsAllComplex(target));
}
606
// IsAllFirst is true exactly when every element equals the literal's first
// element, for PRED, integer and complex element types.
TEST_F(LiteralUtilTest, IsAllFirst) {
  EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int8_t>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int8_t>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint8_t>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int32_t>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int32_t>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<uint32_t>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint32_t>({1, 1, 2}).IsAllFirst());

  complex64 c8_9 = {8, 9};
  complex64 c7_9 = {7, 9};
  EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
624
// IsZero(index) checks a single element, for scalar, array and complex
// literals.
TEST_F(LiteralUtilTest, IsZero) {
  auto f32_zero = LiteralUtil::CreateR0<float>(0.0f);
  auto f32_one = LiteralUtil::CreateR0<float>(1.0f);
  EXPECT_TRUE(f32_zero.IsZero({}));
  EXPECT_FALSE(f32_one.IsZero({}));

  auto grid = LiteralUtil::CreateR2<uint32_t>({{1, 2, 0, 3}, {1, 0, 1, 2}});
  EXPECT_FALSE(grid.IsZero({0, 1}));
  EXPECT_TRUE(grid.IsZero({0, 2}));
  EXPECT_TRUE(grid.IsZero({1, 1}));
  EXPECT_FALSE(grid.IsZero({1, 2}));

  auto c64_zero = LiteralUtil::CreateR0<complex64>(0.0f);
  auto c64_half = LiteralUtil::CreateR0<complex64>(0.5f);
  EXPECT_TRUE(c64_zero.IsZero({}));
  EXPECT_FALSE(c64_half.IsZero({}));
}
642
643 template <typename T>
644 class LiteralUtilTestTemplated : public ::testing::Test {};
645
646 using TestedTypes = ::testing::Types<float, int32_t, uint32_t, complex64>;
647 TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes);
648
// Relayout changes only the physical layout. The result carries the
// requested layout and still compares equal to the source literal.
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
  // 1/2 is a non-integer for floating-point types. Named so it does not
  // shadow the `half` element type.
  const TypeParam one_half = TypeParam(1) / TypeParam(2);
  auto data = LiteralUtil::CreateR2<TypeParam>({{one_half, 2}, {3, 4}});
  const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
  const Layout layout10 = LayoutUtil::MakeLayout({1, 0});

  auto data01 = data.Relayout(layout01);
  EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
  EXPECT_EQ(data, data01);

  auto data10 = data.Relayout(layout10);
  EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
  EXPECT_EQ(data, data10);
}
664
// Reshaping a scalar to the empty dimension list is the identity.
TEST_F(LiteralUtilTest, ReshapeR0) {
  auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  auto same_scalar = scalar.Reshape(/*dimensions=*/{}).value();
  EXPECT_EQ(scalar, same_scalar);
}
670
// Reshaping a dim0-major F32[1x3x2x4] literal to [3, 4, 2] keeps the
// elements in logical row-major order.
TEST_F(LiteralUtilTest, ReshapeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // F32[3x4x2]
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).value();

  EXPECT_EQ(expected, reshape);
}
690
// Same as ReshapeR4, but the source is stored dim0-minor. The reshaped
// result is still compared against the dim0-major expectation, so Reshape
// must follow logical order rather than physical storage order.
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0minor_);
  // F32[3x4x2]
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).value();

  EXPECT_EQ(expected, reshape);
}
710
// Transposing a scalar with the empty permutation is the identity.
TEST_F(LiteralUtilTest, TransposeR0) {
  auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  auto transposed = scalar.Transpose(/*permutation=*/{});
  EXPECT_EQ(scalar, transposed);
}
716
// Transpose with permutation {2, 3, 0, 1}. Element (a, b, c, d) of the result
// must equal element (c, d, a, b) of the input.
TEST_F(LiteralUtilTest, TransposeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }});
  // clang-format on
  auto transposed = original.Transpose(/*permutation=*/{2, 3, 0, 1});

  // Pin the result shape first; without this, a wrongly shaped (or empty)
  // result would let the per-cell check below pass vacuously.
  EXPECT_THAT(transposed.shape().dimensions(), ElementsAre(2, 4, 1, 3));

  transposed.EachCell<float>(
      [&](absl::Span<const int64_t> indices, float value) {
        EXPECT_EQ(value, original.Get<float>(
                             {indices[2], indices[3], indices[0], indices[1]}));
      });
}
733
// Transposing a dynamic literal moves each dynamic size along with its
// dimension, and element (i, j) of the result equals element (j, i) of the
// input.
TEST_F(LiteralUtilTest, TransposeDynamicR2) {
  // F32[2, <=3] (2, 1)
  auto original = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
  original.SetDynamicSize(1, 1);
  // F32[<=3, 2] (1, 2)
  auto transposed = original.Transpose(/*permutation=*/{1, 0});

  // The dynamic size of 1 now belongs to dimension 0.
  EXPECT_EQ(transposed.GetDynamicSize(0), 1);

  transposed.EachCell<float>(
      [&](absl::Span<const int64_t> indices, float value) {
        EXPECT_EQ(value, original.Get<float>({indices[1], indices[0]}));
      });
}
745
// ToStatic turns a dynamic literal into a static one whose dimensions are
// the former runtime sizes.
TEST_F(LiteralUtilTest, ToStaticR2) {
  // F32[2, <=3] (2, 1)
  auto dynamic_literal = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
  dynamic_literal.SetDynamicSize(1, 1);

  // F32[2, 1]
  auto static_literal = dynamic_literal.ToStatic();
  EXPECT_EQ(static_literal.shape(), ShapeUtil::MakeShape(F32, {2, 1}));
  EXPECT_TRUE(static_literal.shape().is_static());

  static_literal.EachCell<float>(
      [&](absl::Span<const int64_t> idx, float cell) {
        EXPECT_EQ(cell, dynamic_literal.Get<float>({idx[0], idx[1]}));
      });
}
760
// ToBoundedDynamic embeds a static literal into a larger bounded-dynamic
// shape while keeping its element values.
TEST_F(LiteralUtilTest, ToBoundedDynamicR2) {
  // F32[2, 1]
  auto source = LiteralUtil::CreateR2<float>({{1}, {4}});

  // F32[2, <=3] (2, 1)
  auto bounded_shape = ShapeUtil::MakeShape(F32, {2, 3}, {false, true});
  auto bounded = source.ToBoundedDynamic(bounded_shape);
  EXPECT_EQ(bounded.shape(), bounded_shape);

  bounded.EachCell<float>([&](absl::Span<const int64_t> idx, float cell) {
    EXPECT_EQ(cell, source.Get<float>({idx[0], idx[1]}));
  });
}
774
// Relayout to a target layout must give the same literal as building the
// data directly in that layout, in both directions.
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
  const auto& built_major = literal_r4_2x2x3x3_dim0major_;
  const auto& built_minor = literal_r4_2x2x3x3_dim0minor_;

  // dim0-minor -> dim0-major.
  auto converted_to_major = built_minor.Relayout(layout_r4_dim0major_);
  EXPECT_EQ(built_major, converted_to_major);

  // dim0-major -> dim0-minor.
  auto converted_to_minor = built_major.Relayout(layout_r4_dim0minor_);
  EXPECT_EQ(built_minor, converted_to_minor);
}
786
// Checks the physical (linear) element order of a 2x3 S32 matrix in
// column-major and row-major layouts, both when created in a layout and
// after Relayout.
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
  const std::vector<int32_t> column_major_order = {1, 4, 2, 5, 3, 6};
  const std::vector<int32_t> row_major_order = {1, 2, 3, 4, 5, 6};

  // Created dim0-minor (column-major), then relaid out to row-major.
  auto col_major = LiteralUtil::CreateR2WithLayout<int32_t>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
  EXPECT_EQ(col_major.element_count(), 6);
  EXPECT_THAT(col_major.data<int32_t>(),
              testing::ElementsAreArray(column_major_order));
  auto col_to_row = col_major.Relayout(layout_r2_dim0major_);
  EXPECT_THAT(col_to_row.data<int32_t>(),
              testing::ElementsAreArray(row_major_order));

  // Created dim0-major (row-major), then relaid out to column-major.
  auto row_major = LiteralUtil::CreateR2WithLayout<int32_t>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
  EXPECT_EQ(row_major.element_count(), 6);
  EXPECT_THAT(row_major.data<int32_t>(),
              testing::ElementsAreArray(row_major_order));
  auto row_to_col = row_major.Relayout(layout_r2_dim0minor_);
  EXPECT_THAT(row_to_col.data<int32_t>(),
              testing::ElementsAreArray(column_major_order));
}
810
// Checks the physical (linear) element order of a 2x2x3 literal in the
// dim0-minor and dim0-major layouts, both when created in a layout and
// after Relayout.
TEST_F(LiteralUtilTest, TestR3LinearLayout) {
  Array3D<int> values(
      // clang-format off
      {
        {
          {1, 2, 3},
          {4, 5, 6},
        },
        {
          {7, 8, 9},
          {10, 11, 12},
        },
      });  // clang-format on

  std::vector<int> minor_order{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
  std::vector<int> major_order{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

  // Created dim0-minor, then relaid out to dim0-major.
  auto as_minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      values, layout_r3_dim0minor_);
  EXPECT_EQ(as_minor.element_count(), 12);
  EXPECT_THAT(as_minor.data<int32_t>(),
              testing::ElementsAreArray(minor_order));
  auto minor_to_major = as_minor.Relayout(layout_r3_dim0major_);
  EXPECT_THAT(minor_to_major.data<int32_t>(),
              testing::ElementsAreArray(major_order));

  // Created dim0-major, then relaid out to dim0-minor.
  auto as_major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      values, layout_r3_dim0major_);
  EXPECT_EQ(as_major.element_count(), 12);
  EXPECT_THAT(as_major.data<int32_t>(),
              testing::ElementsAreArray(major_order));
  auto major_to_minor = as_major.Relayout(layout_r3_dim0minor_);
  EXPECT_THAT(major_to_minor.data<int32_t>(),
              testing::ElementsAreArray(minor_order));
}
851
// Slicing a scalar with empty bounds returns the scalar unchanged.
TEST_F(LiteralUtilTest, SliceR0S32) {
  auto scalar = LiteralUtil::CreateR0<int32_t>(1);
  auto sliced = scalar.Slice({}, {});
  EXPECT_EQ(scalar, sliced);
}
857
// Slice [3, 4) of a length-5 vector yields its fourth element.
TEST_F(LiteralUtilTest, SliceR1F32) {
  auto five = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
  auto want = LiteralUtil::CreateR1<float>({4.0});
  EXPECT_EQ(want, five.Slice({3}, {4}));
}
864
// Rows [0, 2) and columns [2, 4) of a 3x4 matrix form a 2x2 block.
TEST_F(LiteralUtilTest, SliceR2U32) {
  auto grid = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  auto want = LiteralUtil::CreateR2<uint32_t>({{3, 4}, {7, 8}});
  EXPECT_EQ(want, grid.Slice({0, 2}, {2, 4}));
}
872
TEST_F(LiteralUtilTest, SliceR3U32Full) {
  // A slice covering every dimension end-to-end reproduces the input.
  const Literal cube_2x3x2 = LiteralUtil::CreateR3<uint32_t>(
      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
  const Literal whole = cube_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
  EXPECT_EQ(cube_2x3x2, whole);
}
879
TEST_F(LiteralUtilTest, SliceR2Dynamic) {
  Literal matrix_3x4 = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  // Dimension 1 starts with dynamic size 3; slicing columns [1, 2) must leave
  // a dynamic size of 1 on that dimension.
  matrix_3x4.SetDynamicSize(1, 3);
  const Literal column = matrix_3x4.Slice({0, 1}, {2, 2});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32_t>({{2}, {6}}), column);
  EXPECT_EQ(column.GetDynamicSize(1), 1);
}
890
TEST_F(LiteralUtilTest, SliceR2DynamicInBound) {
  Literal matrix_3x4 = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  // Dynamic size 1 on dimension 1. The requested column range [0, 2) is wider,
  // yet the result only holds the single in-bound column.
  matrix_3x4.SetDynamicSize(1, 1);
  const Literal sliced = matrix_3x4.Slice({0, 0}, {2, 2});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32_t>({{1}, {5}}), sliced);
  EXPECT_EQ(sliced.GetDynamicSize(1), 1);
}
900
TEST_F(LiteralUtilTest, SliceR2DynamicOutOfBound) {
  Literal matrix_3x4 = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  matrix_3x4.SetDynamicSize(1, 1);
  // Columns [1, 3) start past the dynamic size of 1, so dimension 1 of the
  // result is clamped to a dynamic size of 0.
  const Literal sliced = matrix_3x4.Slice({0, 1}, {2, 3});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32_t>({{}, {}}), sliced);
  EXPECT_EQ(sliced.GetDynamicSize(1), 0);
}
911
TEST_F(LiteralUtilTest, PopulateR1S64) {
  // PopulateR1 on a one-element S64 vector matches CreateR1 with that value.
  Literal populated(ShapeUtil::MakeShape(S64, {1}));
  populated.PopulateR1<int64_t>({77});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<int64_t>({77}));
}
918
TEST_F(LiteralUtilTest, PopulateR1U64) {
  // Two U64 elements, 77 and 88, written via PopulateR1.
  Literal populated(ShapeUtil::MakeShape(U64, {2}));
  populated.PopulateR1<uint64_t>({77, 88});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<uint64_t>({77, 88}));
}
925
TEST_F(LiteralUtilTest, PopulateR1C64) {
  // A single complex64 element (real 77, imaginary 88).
  const complex64 value(77, 88);
  Literal populated(ShapeUtil::MakeShape(C64, {1}));
  populated.PopulateR1<complex64>({value});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<complex64>({value}));
}
932
TEST_F(LiteralUtilTest, PopulateR1C128) {
  // A single complex128 element (real 77, imaginary 88).
  const complex128 value(77, 88);
  Literal populated(ShapeUtil::MakeShape(C128, {1}));
  populated.PopulateR1<complex128>({value});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<complex128>({value}));
}
939
TEST_F(LiteralUtilTest, PopulateR2C64) {
  // PopulateR2 on a 2x2 C64 literal must agree with CreateR2 on the same data.
  Literal populated(ShapeUtil::MakeShape(C64, {2, 2}));
  populated.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
  const Literal expected =
      LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
  EXPECT_EQ(populated, expected);
}
947
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
  // Filling a BF16 scalar is equivalent to creating it with CreateR0.
  const bfloat16 quarter(0.25f);
  Literal filled(ShapeUtil::MakeShape(BF16, {}));
  filled.PopulateWithValue<bfloat16>(quarter);
  EXPECT_EQ(filled, LiteralUtil::CreateR0<bfloat16>(quarter));
}
955
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
  // Every element of a length-3 BF16 vector becomes 0.5.
  const bfloat16 half_value(0.5f);
  Literal filled(ShapeUtil::MakeShape(BF16, {3}));
  filled.PopulateWithValue<bfloat16>(half_value);
  EXPECT_EQ(filled,
            LiteralUtil::CreateR1<bfloat16>({half_value, half_value, half_value}));
}
963
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
  // Every element of a 2x2 BF16 matrix becomes 2.0.
  const bfloat16 two(2.0f);
  Literal filled(ShapeUtil::MakeShape(BF16, {2, 2}));
  filled.PopulateWithValue<bfloat16>(two);
  const Literal expected = LiteralUtil::CreateR2<bfloat16>({{two, two}, {two, two}});
  EXPECT_EQ(filled, expected);
}
971
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
  // Filling an F32 scalar is equivalent to creating it with CreateR0.
  constexpr float kValue = 2.5f;
  Literal filled(ShapeUtil::MakeShape(F32, {}));
  filled.PopulateWithValue<float>(kValue);
  EXPECT_EQ(filled, LiteralUtil::CreateR0<float>(kValue));
}
978
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
  // A negative fill value is written to all three S64 elements.
  Literal filled(ShapeUtil::MakeShape(S64, {3}));
  filled.PopulateWithValue<int64_t>(-7);
  EXPECT_EQ(filled, LiteralUtil::CreateR1<int64_t>({-7, -7, -7}));
}
985
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
  // Every element of a 2x2 U64 matrix becomes 42.
  Literal filled(ShapeUtil::MakeShape(U64, {2, 2}));
  filled.PopulateWithValue<uint64_t>(42);
  EXPECT_EQ(filled, LiteralUtil::CreateR2<uint64_t>({{42, 42}, {42, 42}}));
}
992
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
  // Every element of a 2x2 C64 matrix becomes (4, 2).
  const complex64 value(4, 2);
  Literal filled(ShapeUtil::MakeShape(C64, {2, 2}));
  filled.PopulateWithValue<complex64>(value);
  const Literal expected =
      LiteralUtil::CreateR2<complex64>({{value, value}, {value, value}});
  EXPECT_EQ(filled, expected);
}
1000
TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
  // Every element of a 2x2 C128 matrix becomes (4, 2).
  const complex128 value(4, 2);
  Literal filled(ShapeUtil::MakeShape(C128, {2, 2}));
  filled.PopulateWithValue<complex128>(value);
  const Literal expected =
      LiteralUtil::CreateR2<complex128>({{value, value}, {value, value}});
  EXPECT_EQ(filled, expected);
}
1008
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
  // Filling an F16 scalar is equivalent to creating it with CreateR0.
  const half quarter(0.25f);
  Literal filled(ShapeUtil::MakeShape(F16, {}));
  filled.PopulateWithValue<half>(quarter);
  EXPECT_EQ(filled, LiteralUtil::CreateR0<half>(quarter));
}
1016
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
  // Every element of a length-3 F16 vector becomes 0.5.
  const half half_value(0.5f);
  Literal filled(ShapeUtil::MakeShape(F16, {3}));
  filled.PopulateWithValue<half>(half_value);
  EXPECT_EQ(filled,
            LiteralUtil::CreateR1<half>({half_value, half_value, half_value}));
}
1024
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
  // Every element of a 2x2 F16 matrix becomes 2.0.
  const half two(2.0f);
  Literal filled(ShapeUtil::MakeShape(F16, {2, 2}));
  filled.PopulateWithValue<half>(two);
  const Literal expected = LiteralUtil::CreateR2<half>({{two, two}, {two, two}});
  EXPECT_EQ(filled, expected);
}
1032
TEST_F(LiteralUtilTest, ReplicateR2U32) {
  // Replicate(3) on a 3x4 matrix yields an R3 literal holding three copies of
  // the matrix along a new leading dimension.
  Literal matrix = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  const Literal replicated = matrix.Replicate<uint32_t>(3);
  const Literal expected = LiteralUtil::CreateR3<uint32_t>(
      {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
  EXPECT_EQ(replicated, expected);
}
1043
TEST_F(LiteralUtilTest, CopySliceFrom) {
  // Geometry of the rank-4 copy; identical for every layout tried below.
  const int64_t dimensions[] = {17, 15, 34, 21};
  const int64_t zero_base[] = {0, 0, 0, 0};
  const int64_t step[] = {1, 1, 1, 1};
  const int64_t src_base[] = {3, 1, 5, 7};
  const int64_t dest_base[] = {6, 4, 12, 2};
  const int64_t copy_size[] = {7, 8, 11, 9};
  const int64_t layouts[][4] = {
      {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};

  for (const auto& layout : layouts) {
    const Shape shape = ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32_t>(), dimensions, layout);

    // Source: each element gets a distinct sequence number starting at 1, so
    // no source element is zero.
    Literal source = Literal::CreateFromShape(shape);
    uint32_t sequence = 0;
    ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
                            [&](absl::Span<const int64_t> index) {
                              source.Set(index, ++sequence);
                              return true;
                            });

    Literal destination = Literal::CreateFromShape(shape);
    TF_EXPECT_OK(
        destination.CopySliceFrom(source, src_base, dest_base, copy_size));

    // Walk the copied window. Each destination element must be non-zero (it
    // was written) and equal to the source element at the matching offset.
    // Returning false from the visitor ends the walk at the first mismatch.
    std::vector<int64_t> source_index(TF_ARRAYSIZE(dimensions), 0);
    std::vector<int64_t> dest_index(TF_ARRAYSIZE(dimensions), 0);
    bool matched = true;
    ShapeUtil::ForEachIndex(
        source.shape(), zero_base, copy_size, step,
        [&](absl::Span<const int64_t> offset) {
          for (size_t dim = 0; dim < offset.size(); ++dim) {
            source_index[dim] = offset[dim] + src_base[dim];
            dest_index[dim] = offset[dim] + dest_base[dim];
          }
          const uint32_t copied = destination.Get<uint32_t>(dest_index);
          matched =
              copied != 0 && copied == source.Get<uint32_t>(source_index);
          return matched;
        });
    EXPECT_TRUE(matched);
  }
}
1089
TEST_F(LiteralUtilTest, CopyFromScalars) {
  // Whole-literal copy between two scalars.
  Literal scalar = LiteralUtil::CreateR0<uint32_t>(0);
  const Literal nine = LiteralUtil::CreateR0<uint32_t>(9);
  TF_EXPECT_OK(scalar.CopyFrom(nine));
  EXPECT_EQ(scalar, nine);

  // Scalar <- vector[5], then vector[4] <- scalar (scalar side uses empty
  // index spans).
  Literal vect = LiteralUtil::CreateR1<uint32_t>({3, 4, 9, 12, 5, 17, 21});
  TF_EXPECT_OK(scalar.CopySliceFrom(vect, /*src_base=*/{5}, /*dest_base=*/{},
                                    /*copy_size=*/{}));
  EXPECT_EQ(scalar.Get<uint32_t>({}), 17);
  TF_EXPECT_OK(vect.CopySliceFrom(scalar, /*src_base=*/{}, /*dest_base=*/{4},
                                  /*copy_size=*/{}));
  EXPECT_EQ(vect.Get<uint32_t>({4}), 17);
}
1102
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
  const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
  const Literal const_nine = LiteralUtil::CreateR1<float>({9});
  const Literal const_empty = Literal::CreateFromShape(empty_r1_shape);

  {
    // A zero-sized copy out of an empty source succeeds and leaves the
    // destination unchanged.
    const Literal empty_source = Literal::CreateFromShape(empty_r1_shape);
    Literal nine_dest = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(nine_dest.CopySliceFrom(empty_source, {0}, {0}, {0}));
    EXPECT_EQ(nine_dest, const_nine);
  }

  {
    // A zero-sized copy into an empty destination succeeds and leaves it
    // empty.
    Literal empty_dest = Literal::CreateFromShape(empty_r1_shape);
    const Literal nine_source = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(empty_dest.CopySliceFrom(nine_source, {0}, {0}, {0}));
    EXPECT_EQ(empty_dest, const_empty);
  }
}
1126
TEST_F(LiteralUtilTest, CopyFromNilShape) {
  // Copying between two nil-shaped literals transfers no data but must still
  // return an OK status.
  Literal destination(ShapeUtil::MakeNil());
  const Literal source(ShapeUtil::MakeNil());
  TF_ASSERT_OK(destination.CopyFrom(source));
}
1133
TEST_F(LiteralUtilTest, CopyFromArrays) {
  // Scalar case: differing literals become equal after CopyFrom at root.
  Literal scalar_42 = LiteralUtil::CreateR0<float>(42.0);
  const Literal scalar_123 = LiteralUtil::CreateR0<float>(123.0);
  EXPECT_NE(scalar_42, scalar_123);
  TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
                                  /*src_shape_index=*/{}));
  EXPECT_EQ(scalar_42, scalar_123);
  EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);

  // Matrix case: same check on 2x2 arrays, spot-checking element {0, 0}.
  Literal matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const Literal matrix_5678 =
      LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
  EXPECT_NE(matrix_1234, matrix_5678);
  EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
  TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
                                    /*src_shape_index=*/{}));
  EXPECT_EQ(matrix_1234, matrix_5678);
  EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
}
1152
TEST_F(LiteralUtilTest, CopyFromTuples) {
  // nested_tuple = (matrix, (s32 42, f64[2] {23, 44}, nil)).
  const Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const Literal nil_literal(ShapeUtil::MakeNil());
  const Literal inner_elements[] = {LiteralUtil::CreateR0<int32_t>(42),
                                    LiteralUtil::CreateR1<double>({23.0, 44.0})};
  const Literal inner_tuple = LiteralUtil::MakeTuple(
      {&inner_elements[0], &inner_elements[1], &nil_literal});
  Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});

  // A tuple shaped like element {1} of nested_tuple, but with different values.
  const Literal int32_minus5 = LiteralUtil::CreateR0<int32_t>(-5);
  const Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
  const Literal replacement =
      LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});

  // Sanity-check the initial contents.
  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
  EXPECT_EQ(nested_tuple.Get<int32_t>({}, {1, 0}), 42);
  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);

  // Replace element {1} of nested_tuple with `replacement`.
  TF_ASSERT_OK(nested_tuple.CopyFrom(replacement, /*dest_shape_index=*/{1},
                                     /*src_shape_index=*/{}));

  // Element {0} is untouched; element {1} now holds the replacement values.
  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
  EXPECT_EQ(nested_tuple.Get<int32_t>({}, {1, 0}), -5);
  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
  // A literal may be both source and destination of CopyFrom when the two
  // shape indices differ: copy element {0} over element {1}.
  const Literal elements[] = {LiteralUtil::CreateR0<int32_t>(-2),
                              LiteralUtil::CreateR0<int32_t>(4)};
  Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
  const ShapeIndex kFirst = {0};
  const ShapeIndex kSecond = {1};

  EXPECT_EQ(tuple.Get<int32_t>({}, kFirst), -2);
  EXPECT_EQ(tuple.Get<int32_t>({}, kSecond), 4);

  TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/kSecond,
                              /*src_shape_index=*/kFirst));

  EXPECT_EQ(tuple.Get<int32_t>({}, kFirst), -2);
  EXPECT_EQ(tuple.Get<int32_t>({}, kSecond), -2);
}
1201
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
  // Copying an R1 literal into an R2 literal must fail with a shape error.
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const Literal row = LiteralUtil::CreateR1<float>({5.0, 7.0});
  const Status copy_status = matrix.CopyFrom(row);
  ASSERT_FALSE(copy_status.ok());
  EXPECT_THAT(copy_status.error_message(),
              HasSubstr("Destination subshape incompatible"));
}
1210
TEST_F(LiteralUtilTest, F16) {
  // Verify that the internal data views are consistent and that they
  // are in little endian format
  // TODO - modify if we make the data format machine endianness dependent

  // A freshly created 2x2 F16 literal has 8 zero bytes.
  const Literal zeros =
      Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
  const char* zero_bytes =
      reinterpret_cast<const char*>(zeros.data<half>().data());
  for (int byte = 0; byte < 8; ++byte) {
    EXPECT_EQ(zero_bytes[byte], 0) << "byte " << byte;
  }

  // 1.0 and 2.0 have IEEE half bit patterns 0x3C00 and 0x4000.
  const half one(1.0f);
  const half two(2.0f);
  const Literal matrix = LiteralUtil::CreateR2<half>({{one, two}, {two, one}});
  const uint16_t* words =
      reinterpret_cast<const uint16_t*>(matrix.data<half>().data());
  const uint16_t kExpectedWords[] = {0x3C00, 0x4000, 0x4000, 0x3C00};
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(words[i], kExpectedWords[i]) << "word " << i;
  }
}
1236
TEST_F(LiteralUtilTest, Populate) {
  // Shapes to populate: scalar, empty, 1D, and 2D-4D arrays with varied
  // minor-to-major layouts.
  struct PopulateData {
    std::vector<int64_t> dimensions;
    std::vector<int64_t> layout;
  };
  const PopulateData kTestCases[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };

  for (const PopulateData& test_case : kTestCases) {
    Literal literal(ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32_t>(), test_case.dimensions,
        test_case.layout));

    // Value = linear index + 17; the offset keeps an R0 literal from being
    // populated with zero.
    auto expected_value = [&](absl::Span<const int64_t> index) -> uint32_t {
      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                           index) +
             17;
    };
    TF_EXPECT_OK(literal.Populate<uint32_t>(expected_value));

    // Visit every element; stop at the first mismatch.
    const std::vector<int64_t> origin(test_case.dimensions.size(), 0);
    const std::vector<int64_t> stride(test_case.dimensions.size(), 1);
    bool all_match = true;
    ShapeUtil::ForEachIndex(literal.shape(), origin, test_case.dimensions,
                            stride, [&](absl::Span<const int64_t> index) {
                              const uint32_t actual =
                                  literal.Get<uint32_t>(index);
                              all_match =
                                  all_match && actual == expected_value(index);
                              return all_match;
                            });
    EXPECT_TRUE(all_match);
  }
}
1278
TEST_F(LiteralUtilTest, PopulateParallel) {
  // Same shapes as the Populate test, filled through PopulateParallel.
  struct PopulateData {
    std::vector<int64_t> dimensions;
    std::vector<int64_t> layout;
  };
  const PopulateData kTestCases[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };

  for (const PopulateData& test_case : kTestCases) {
    Literal literal(ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32_t>(), test_case.dimensions,
        test_case.layout));

    // Value = linear index + 17 (non-zero even for R0). The thread id is
    // ignored, so the value depends only on the index.
    auto expected_value = [&](absl::Span<const int64_t> index,
                              int /*thread_id*/) -> uint32_t {
      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                           index) +
             17;
    };
    TF_EXPECT_OK(literal.PopulateParallel<uint32_t>(expected_value));

    // Visit every element; stop at the first mismatch.
    const std::vector<int64_t> origin(test_case.dimensions.size(), 0);
    const std::vector<int64_t> stride(test_case.dimensions.size(), 1);
    bool all_match = true;
    ShapeUtil::ForEachIndex(
        literal.shape(), origin, test_case.dimensions, stride,
        [&](absl::Span<const int64_t> index) {
          const uint32_t actual = literal.Get<uint32_t>(index);
          all_match = all_match &&
                      actual == expected_value(index, /*thread_id=*/-1);
          return all_match;
        });
    EXPECT_TRUE(all_match);
  }
}
1321
TEST_F(LiteralUtilTest, ConvertR4) {
  // S8 -> U32 conversion of an R4 literal whose values all fit both types.
  // clang-format off
  auto original = LiteralUtil::CreateR4WithLayout<int8_t>({{
    {{10, 11, 12, 13}, {14, 15, 16, 17}},
    {{18, 19, 20, 21}, {22, 23, 24, 25}},
    {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // clang-format on
  TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));

  // clang-format off
  const Literal expected = LiteralUtil::CreateR4WithLayout<uint32_t>({{
    {{10, 11, 12, 13}, {14, 15, 16, 17}},
    {{18, 19, 20, 21}, {22, 23, 24, 25}},
    {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // clang-format on
  EXPECT_EQ(expected, converted);
}
1339
// Converts literals holding the same (small, non-negative) values between many
// element types. Each result must equal the literal created directly in the
// target type. Conversions to TUPLE, and from complex to real/integer types,
// must report UNIMPLEMENTED.
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
  // clang-format off
  auto s8 = LiteralUtil::CreateR4WithLayout<int8_t>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto s16 = LiteralUtil::CreateR4WithLayout<int16_t>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto s32 = LiteralUtil::CreateR4WithLayout<int32_t>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto u16 = LiteralUtil::CreateR4WithLayout<uint16_t>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto u32 = LiteralUtil::CreateR4WithLayout<uint32_t>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto s64 = LiteralUtil::CreateR4WithLayout<int64_t>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto u64 = LiteralUtil::CreateR4WithLayout<uint64_t>({{
    {{10, 0, 12, 0}, {0, 15, 0, 17}},
    {{0, 19, 0, 21}, {22, 0, 24, 0}},
    {{26, 0, 28, 0}, {0, 31, 0, 33}},
  }}, layout_r4_dim0major_);
  auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
    {{true, false, true, false}, {false, true, false, true}},
    {{false, true, false, true}, {true, false, true, false}},
    {{true, false, true, false}, {false, true, false, true}},
  }}, layout_r4_dim0major_);
  auto int32_pred = LiteralUtil::CreateR4WithLayout<int32_t>({{
    {{1, 0, 1, 0}, {0, 1, 0, 1}},
    {{0, 1, 0, 1}, {1, 0, 1, 0}},
    {{1, 0, 1, 0}, {0, 1, 0, 1}},
  }}, layout_r4_dim0major_);
  auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
    {{half(10.0), half(0.0), half(12.0), half(0.0)},
     {half(0.0), half(15.0), half(0.0), half(17.0)}},
    {{half(0.0), half(19.0), half(0.0), half(21.0)},
     {half(22.0), half(0.0), half(24.0), half(0.0)}},
    {{half(26.0), half(0.0), half(28.0), half(0.0)},
     {half(0.0), half(31.0), half(0.0), half(33.0)}},
  }}, layout_r4_dim0major_);
  auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
    {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
     {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
    {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
     {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
    {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
     {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
  }}, layout_r4_dim0major_);
  auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
    {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
    {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
    {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
  }}, layout_r4_dim0major_);
  auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
    {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
    {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
    {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
  }}, layout_r4_dim0major_);
  auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
    {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
    {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
    {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
  }}, layout_r4_dim0major_);
  auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
    {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
    {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
    {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
  }}, layout_r4_dim0major_);  // clang-format on

  // Each entry is (source literal, target element type, expected result).
  struct ConversionCase {
    const Literal* source;
    PrimitiveType target_type;
    const Literal* expected;
  };
  const ConversionCase kCases[] = {
      {&s8, U16, &u16},    {&s8, S16, &s16},
      {&s8, U32, &u32},    {&s8, S32, &s32},
      {&s8, U64, &u64},    {&s8, S64, &s64},
      {&s8, PRED, &pred},  {&bf16, S32, &s32},
      {&bf16, F32, &f32},  {&pred, S32, &int32_pred},
      {&f32, S32, &s32},   {&f64, S32, &s32},
      {&s32, F32, &f32},   {&f32, F16, &f16},
      {&f64, F16, &f16},   {&s32, F16, &f16},
      {&u32, F16, &f16},   {&s32, C64, &c64},
      {&f16, C64, &c64},   {&s32, S16, &s16},
      {&s32, U16, &u16},   {&s32, C128, &c128},
      {&f16, C128, &c128},
  };
  for (const ConversionCase& test_case : kCases) {
    SCOPED_TRACE(absl::StrCat(
        PrimitiveType_Name(test_case.source->shape().element_type()), " -> ",
        PrimitiveType_Name(test_case.target_type)));
    // A failed conversion is reported as a test failure for this case.
    // StatusOr::value() would abort the whole test binary instead.
    TF_ASSERT_OK_AND_ASSIGN(Literal converted,
                            test_case.source->Convert(test_case.target_type));
    EXPECT_EQ(converted, *test_case.expected);
  }

  EXPECT_EQ(s32.Convert(TUPLE).status().code(),
            tensorflow::error::UNIMPLEMENTED);
  EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
  EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
  EXPECT_EQ(c128.Convert(F32).status().code(),
            tensorflow::error::UNIMPLEMENTED);
  EXPECT_EQ(c128.Convert(S32).status().code(),
            tensorflow::error::UNIMPLEMENTED);
}
1503
TEST_F(LiteralUtilTest, BitcastConvert) {
  // The U32 source holds the bit patterns of known floats, plus the arbitrary
  // pattern 0xbeef. Reinterpreting it as F32 must yield exactly those floats.
  Literal original = LiteralUtil::CreateR1<uint32_t>(
      {absl::bit_cast<uint32_t>(2.5f), absl::bit_cast<uint32_t>(-42.25f),
       absl::bit_cast<uint32_t>(100.f), 0xbeef});
  Literal expected = LiteralUtil::CreateR1<float>(
      {2.5f, -42.25f, 100.0f, absl::bit_cast<float>(0xbeef)});
  TF_ASSERT_OK_AND_ASSIGN(Literal converted,
                          original.BitcastConvert(ShapeUtil::ChangeElementType(
                              original.shape(), F32)));
  // Check the converted values, not only that the conversion returned OK.
  EXPECT_EQ(expected, converted);
}
1514
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
  // U32 -> F64 changes the element byte width, so the bitcast must be rejected.
  const Literal u32_scalar = LiteralUtil::CreateR0<uint32_t>(1234);
  const Shape f64_shape = ShapeUtil::ChangeElementType(u32_scalar.shape(), F64);
  const Status status = u32_scalar.BitcastConvert(f64_shape).status();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("to a shape of different size"));
}
1524
1525 // Sets the layout of the given ShapeProto to the default.
SetDefaultLayoutOnProto(ShapeProto * shape_proto)1526 void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
1527 CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
1528 auto* minor_to_major =
1529 shape_proto->mutable_layout()->mutable_minor_to_major();
1530 minor_to_major->Resize(shape_proto->dimensions_size(), 0);
1531 const int64_t size = minor_to_major->size();
1532 for (int64_t i = 0; i < size; ++i) {
1533 minor_to_major->Set(i, size - 1 - i);
1534 }
1535 }
1536
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
  LiteralProto proto;
  proto.mutable_shape()->set_element_type(PRED);
  for (int len = 0; len < 25; ++len) {
    // Alternating PRED values whose phase depends on the parity of len.
    auto expected_at = [len](int i) { return (i % 2) == (len % 2); };

    ShapeProto* shape = proto.mutable_shape();
    shape->clear_dimensions();
    shape->add_dimensions(len);
    SetDefaultLayoutOnProto(shape);
    proto.clear_preds();
    for (int i = 0; i < len; ++i) {
      proto.add_preds(expected_at(i));
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
    const auto values = literal.data<bool>();
    ASSERT_EQ(static_cast<size_t>(len), values.size());
    for (int i = 0; i < len; ++i) {
      EXPECT_EQ(expected_at(i), values[i]);
    }
  }
}
1558
1559 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,ToProto_f16)1560 TEST_F(LiteralUtilTest, ToProto_f16) {
1561 half h1(1.0f);
1562 half h2(2.0f);
1563
1564 auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
1565 EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
1566 EXPECT_EQ(4, m.data<half>().size());
1567
1568 LiteralProto p = m.ToProto();
1569 EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
1570 EXPECT_EQ(8, p.f16s().size());
1571 const char* d = p.f16s().data();
1572 EXPECT_EQ(d[0], 0);
1573 EXPECT_EQ(d[1], 0x3C);
1574 EXPECT_EQ(d[2], 0);
1575 EXPECT_EQ(d[3], 0x40);
1576 EXPECT_EQ(d[4], 0);
1577 EXPECT_EQ(d[5], 0x40);
1578 EXPECT_EQ(d[6], 0);
1579 EXPECT_EQ(d[7], 0x3C);
1580 }
1581
1582 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,CopyFromProto_f16)1583 TEST_F(LiteralUtilTest, CopyFromProto_f16) {
1584 half h1(1.0f);
1585 half h2(2.0f);
1586
1587 const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C};
1588 LiteralProto p;
1589 p.mutable_shape()->set_element_type(F16);
1590 p.mutable_shape()->clear_dimensions();
1591 p.mutable_shape()->add_dimensions(4);
1592 SetDefaultLayoutOnProto(p.mutable_shape());
1593 p.clear_f16s();
1594 p.set_f16s(half_vals, 8);
1595 TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
1596 auto r = literal.data<half>();
1597 ASSERT_EQ(4, r.size());
1598 EXPECT_EQ(h1, r[0]);
1599 EXPECT_EQ(h2, r[1]);
1600 EXPECT_EQ(h2, r[2]);
1601 EXPECT_EQ(h1, r[3]);
1602 }
1603
TEST_F(LiteralUtilTest, CopyFromProto_u16) {
  const uint16_t kFirst = 0xabcd;
  const uint16_t kSecond = 0x1234;
  // Little-endian bytes of {kFirst, kSecond, kSecond, kFirst}.
  const unsigned char kBytes[8] = {0xcd, 0xab, 0x34, 0x12,
                                   0x34, 0x12, 0xcd, 0xab};

  LiteralProto proto;
  ShapeProto* shape = proto.mutable_shape();
  shape->set_element_type(U16);
  shape->clear_dimensions();
  shape->add_dimensions(4);
  SetDefaultLayoutOnProto(shape);
  proto.clear_u16s();
  proto.set_u16s(kBytes, sizeof(kBytes));

  TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
  const auto values = literal.data<uint16_t>();
  ASSERT_EQ(4, values.size());
  EXPECT_EQ(kFirst, values[0]);
  EXPECT_EQ(kSecond, values[1]);
  EXPECT_EQ(kSecond, values[2]);
  EXPECT_EQ(kFirst, values[3]);
}
1625
TEST_F(LiteralUtilTest, LiteralDynamicSliceTest) {
  const Literal scalar = LiteralUtil::CreateR0<float>(1.0);
  const Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const Literal tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  const Literal nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
  const Literal nil(ShapeUtil::MakeNil());

  // A view rooted at the empty shape index sees the whole literal.
  for (const Literal* literal : {&scalar, &matrix, &tuple, &nested_tuple, &nil}) {
    EXPECT_EQ(LiteralSlice(*literal, {}), *literal);
  }

  // Views rooted at tuple elements see exactly those elements.
  EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
  EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);

  EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
  EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
1647
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
  const Literal scalar = LiteralUtil::CreateR0<float>(1.0);
  const Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const Literal tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  Literal nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});

  // A LiteralSlice reads through to the underlying literal: after a write to
  // the literal, the view reports the new value.
  const LiteralSlice nested_tuple_view(nested_tuple);
  const ShapeIndex kInnerScalar = {0, 0};

  EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, kInnerScalar), 1.0f);
  EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{}, kInnerScalar),
            1.0f);

  nested_tuple.Set<float>(/*multi_index=*/{}, kInnerScalar, 555.0f);

  EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, kInnerScalar), 555.0f);
  EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{}, kInnerScalar),
            555.0f);
}
1668
// A LiteralSlice can itself be sliced: taking {0} and then {1} from
// ((leaf, grid), leaf) must land on the grid matrix.
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
  Literal leaf = LiteralUtil::CreateR0<float>(1.0);
  Literal grid = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal inner = LiteralUtil::MakeTuple({&leaf, &grid});
  Literal outer = LiteralUtil::MakeTuple({&inner, &leaf});

  const LiteralSlice whole_view(outer);
  const LiteralSlice inner_view(whole_view, /*view_root=*/{0});
  const LiteralSlice grid_view(inner_view, /*view_root=*/{1});

  const Literal expected =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(grid_view, expected);
}
1681
// A BorrowingLiteral built over a single caller-owned buffer exposes that
// buffer's contents as an s64[3] array.
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
  const std::vector<int64_t> backing = {1, 2, 3};
  const Shape s64_vector_shape = ShapeUtil::MakeShape(S64, {3});

  BorrowingLiteral borrowed(reinterpret_cast<const char*>(backing.data()),
                            s64_vector_shape);

  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(borrowed.Get<int64_t>({i}), backing[i]);
  }
}
1693
// A tuple-shaped BorrowingLiteral takes one buffer pointer per tuple element,
// in tuple-element order.
TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
  std::vector<int64_t> first_leaf = {1, 2, 3};
  std::vector<int64_t> second_leaf = {100};
  const Shape first_shape = ShapeUtil::MakeShape(S64, {3});
  const Shape second_shape = ShapeUtil::MakeShape(S64, {1});

  const std::vector<const char*> leaf_buffers = {
      reinterpret_cast<const char*>(first_leaf.data()),
      reinterpret_cast<const char*>(second_leaf.data()),
  };
  BorrowingLiteral borrowed_tuple(
      leaf_buffers, ShapeUtil::MakeTupleShape({first_shape, second_shape}));

  // The single element of tuple element {1} comes from the second buffer.
  EXPECT_EQ(borrowed_tuple.Get<int64_t>(/*multi_index=*/{0},
                                        /*shape_index=*/{1}),
            100);

  // Every element of tuple element {0} comes from the first buffer.
  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(borrowed_tuple.Get<int64_t>(/*multi_index=*/{i},
                                          /*shape_index=*/{0}),
              first_leaf[i]);
  }
}
1721
// Move-constructing a Literal carries over both its shape and its elements.
TEST_F(LiteralUtilTest, LiteralMove) {
  Literal source = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal moved_to(std::move(source));

  const Shape expected_shape = ShapeUtil::MakeShape(F32, {2, 2});
  EXPECT_TRUE(ShapeUtil::Equal(expected_shape, moved_to.shape()));

  // Elements were 1, 2, 3, 4 in row-major order.
  float expected_value = 1.0f;
  for (int64_t row = 0; row < 2; ++row) {
    for (int64_t col = 0; col < 2; ++col) {
      EXPECT_EQ(moved_to.Get<float>({row, col}), expected_value);
      expected_value += 1.0f;
    }
  }
}
1733
// DecomposeTuple splits a tuple into its top-level elements, keeping nested
// tuples intact, and leaves the source literal as an empty tuple.
TEST_F(LiteralUtilTest, DecomposeTuple) {
  Literal empty_tuple(ShapeUtil::MakeNil());
  Literal leaves[] = {
      LiteralUtil::CreateR0<int32_t>(42),
      LiteralUtil::CreateR1<double>({23.0, 44.0}),
  };
  Literal members[] = {
      LiteralUtil::CreateR2<int32_t>({{1, 2}, {3, 4}}),
      LiteralUtil::MakeTuple({&leaves[0], &leaves[1], &empty_tuple}),
  };
  Literal source =
      LiteralUtil::MakeTuple({&members[0], &members[1], &empty_tuple});

  EXPECT_FALSE(ShapeUtil::IsEmptyTuple(source.shape()));
  std::vector<Literal> parts = source.DecomposeTuple();
  EXPECT_TRUE(ShapeUtil::IsEmptyTuple(source.shape()));

  ASSERT_EQ(parts.size(), 3);

  // parts[0]: the s32[2,2] matrix holding 1..4 in row-major order.
  const Literal& matrix_part = parts[0];
  EXPECT_TRUE(ShapeUtil::Compatible(matrix_part.shape(),
                                    ShapeUtil::MakeShape(S32, {2, 2})));
  int32_t expected_entry = 1;
  for (int64_t row = 0; row < 2; ++row) {
    for (int64_t col = 0; col < 2; ++col) {
      EXPECT_EQ(matrix_part.Get<int32_t>({row, col}), expected_entry);
      ++expected_entry;
    }
  }

  // parts[1]: the (s32[], f64[2], ()) tuple, still tuple-shaped.
  const Literal& tuple_part = parts[1];
  const Shape expected_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F64, {2}),
       ShapeUtil::MakeNil()});
  EXPECT_TRUE(
      ShapeUtil::Compatible(tuple_part.shape(), expected_tuple_shape));
  EXPECT_EQ(tuple_part.Get<int32_t>({}, /*shape_index=*/{0}), 42);
  EXPECT_EQ(tuple_part.Get<double>({0}, /*shape_index=*/{1}), 23.0);
  EXPECT_EQ(tuple_part.Get<double>({1}, /*shape_index=*/{1}), 44.0);

  // parts[2]: the trailing empty tuple.
  EXPECT_TRUE(ShapeUtil::Compatible(parts[2].shape(), ShapeUtil::MakeNil()));
}
1772
// Decomposing an empty tuple produces no elements.
TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
  Literal empty_tuple(ShapeUtil::MakeNil());
  const std::vector<Literal> parts = empty_tuple.DecomposeTuple();
  EXPECT_TRUE(parts.empty());
}
1778
// MoveIntoTuple builds a tuple by moving the given literals in (one of which is
// already a tuple) and leaves every source literal as an empty tuple.
TEST_F(LiteralUtilTest, MoveIntoTuple) {
  std::vector<Literal> inner_parts;
  inner_parts.emplace_back(LiteralUtil::CreateR0<int32_t>(42));
  inner_parts.emplace_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));

  std::vector<Literal> parts;
  parts.emplace_back(LiteralUtil::CreateR0<float>(1.0));
  parts.emplace_back(LiteralUtil::CreateR1<int32_t>({4, 8}));
  parts.emplace_back(
      LiteralUtil::MakeTuple({&inner_parts[0], &inner_parts[1]}));

  Literal combined = Literal::MoveIntoTuple(absl::MakeSpan(parts));
  ASSERT_TRUE(combined.shape().IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(combined.shape()), 3);

  // Element {0}: the f32 scalar.
  EXPECT_EQ(combined.Get<float>({}, /*shape_index=*/{0}), 1.0);
  // Element {1}: the s32[2] vector.
  EXPECT_EQ(combined.Get<int32_t>({0}, /*shape_index=*/{1}), 4);
  EXPECT_EQ(combined.Get<int32_t>({1}, /*shape_index=*/{1}), 8);
  // Element {2}: the nested (s32[], f64[2]) tuple.
  EXPECT_EQ(combined.Get<int32_t>({}, /*shape_index=*/{2, 0}), 42);
  EXPECT_EQ(combined.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
  EXPECT_EQ(combined.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);

  // Every moved-from source is left as an empty tuple.
  for (const Literal& part : parts) {
    EXPECT_TRUE(ShapeUtil::IsEmptyTuple(part.shape()));
  }
}
1804
// MoveIntoTuple with no elements produces a tuple with zero elements.
TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
  const Literal empty = Literal::MoveIntoTuple({});
  const Shape& empty_shape = empty.shape();
  ASSERT_TRUE(empty_shape.IsTuple());
  EXPECT_EQ(ShapeUtil::TupleElementCount(empty_shape), 0);
}
1810
// A default-constructed Literal has the nil shape; move-assigning into it
// replaces both its shape and its elements.
TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
  Literal target;
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), target.shape()));

  Literal source = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  target = std::move(source);

  EXPECT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), target.shape()));

  const float expected[2][2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  for (int64_t row = 0; row < 2; ++row) {
    for (int64_t col = 0; col < 2; ++col) {
      EXPECT_EQ(target.Get<float>({row, col}), expected[row][col]);
    }
  }
}
1825
// A copy-constructed LiteralSlice reads the same values as the slice it was
// copied from.
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
  Literal grid = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const LiteralSlice original_view(grid);
  LiteralSlice copied_view(original_view);

  const float expected[2][2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  for (int64_t row = 0; row < 2; ++row) {
    for (int64_t col = 0; col < 2; ++col) {
      EXPECT_EQ(copied_view.Get<float>({row, col}), expected[row][col]);
    }
  }
}
1836
// Get/Set with a shape_index read and overwrite individual elements inside a
// tuple literal.
TEST_F(LiteralUtilTest, GetSetTuple) {
  Literal members[] = {
      LiteralUtil::CreateR0<float>(42.0),
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
  };
  Literal pair_tuple = LiteralUtil::MakeTuple({&members[0], &members[1]});

  // The scalar at tuple element {0}.
  auto scalar_at = [&pair_tuple]() {
    return pair_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0});
  };
  // Entry [1, 0] of the matrix at tuple element {1}.
  auto matrix_entry = [&pair_tuple]() {
    return pair_tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1});
  };

  EXPECT_EQ(scalar_at(), 42.0);
  pair_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
  EXPECT_EQ(scalar_at(), -5.0);

  EXPECT_EQ(matrix_entry(), 3.0);
  pair_tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
  EXPECT_EQ(matrix_entry(), -4.0);
}
1852
// Literals produced by Literal::CreateFromShape are expected to be
// zero-initialized: scalars, vectors, and every element of a tuple, including
// PRED and complex element types.
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
  Literal f32_scalar = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
  EXPECT_EQ(f32_scalar.Get<float>({}), 0.0);
  EXPECT_TRUE(f32_scalar.IsAll(0));

  Literal s32_vector =
      Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(s32_vector.Get<int32_t>({i}), 0);
  }
  EXPECT_TRUE(s32_vector.IsAll(0));

  const Shape mixed_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
       ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
       ShapeUtil::MakeShape(C128, {})});
  Literal zero_tuple = Literal::CreateFromShape(mixed_tuple_shape);

  EXPECT_EQ(zero_tuple.Get<double>({}, {0}), 0.0);
  // Elements {1} (pred[2]) and {2} (u64[2,1]) both have two entries.
  for (int64_t i = 0; i < 2; ++i) {
    EXPECT_EQ(zero_tuple.Get<bool>({i}, {1}), false);
    EXPECT_EQ(zero_tuple.Get<uint64_t>({i, 0}, {2}), 0);
  }
  EXPECT_EQ(zero_tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
  EXPECT_EQ(zero_tuple.Get<complex128>({}, {4}), complex128(0.0, 0.0));
}
1878
// Serializing a Literal with ToProto() and parsing it back with
// CreateFromProto() must reproduce an equal Literal. This covers scalars,
// vectors of several element types, a PRED matrix, flat and nested tuples, and
// the empty tuple.
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
  auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
  auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
  auto vector_int8 = LiteralUtil::CreateR1<int8_t>({-128, 0, 2, 4, 7, 56, 127});
  auto vector_uint8 = LiteralUtil::CreateR1<uint8_t>({128, 0, 2, 56, 127, 255});
  auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_c128 =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
  auto vector_half =
      LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
  auto matrix_pred =
      LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
  auto tuple = LiteralUtil::MakeTuple(
      {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
  Literal nil_literal(ShapeUtil::MakeNil());
  auto nested_tuple =
      LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});

  // Serialize to a LiteralProto and parse it back.
  auto to_from_proto = [](const Literal& literal) -> Literal {
    return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
  };

  EXPECT_EQ(one_f32, to_from_proto(one_f32));
  EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
  EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
  EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
  EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
  EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
  // vector_half is also covered inside `tuple`. Checking it on its own makes an
  // F16 round-trip failure show up directly rather than as a tuple mismatch.
  EXPECT_EQ(vector_half, to_from_proto(vector_half));
  EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
  EXPECT_EQ(tuple, to_from_proto(tuple));
  EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
  EXPECT_EQ(nil_literal, to_from_proto(nil_literal));

  // Distinct values must still compare unequal after a round trip.
  EXPECT_NE(one_f32, two_f32);
  EXPECT_NE(one_f32, to_from_proto(two_f32));
}
1918
// When called without prohibit_empty_literal=false, CreateFromProto rejects a
// proto that declares an f32[3] shape but carries no values.
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
  LiteralProto shape_only;
  *shape_only.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();

  const Status result = Literal::CreateFromProto(shape_only).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1928
TEST_F(LiteralUtilTest, ValidProtoNoValues) {
  // The proto carries an f32[3] shape but no values. With
  // prohibit_empty_literal=false, CreateFromProto is expected to accept it.
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
  Status status =
      Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
          .status();
  // TF_EXPECT_OK reports the returned Status on failure. A bare
  // EXPECT_TRUE(status.ok()) would only say "false".
  TF_EXPECT_OK(status);
}
1938
TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) {
  // Start from a real pred[3] literal's proto, then strip its values.
  auto literal = LiteralUtil::CreateR1<bool>({true, false, true});
  LiteralProto proto = literal.ToProto();
  EXPECT_EQ(proto.preds_size(), 3);

  // Clear values.
  proto.clear_preds();
  EXPECT_EQ(proto.preds_size(), 0);
  // With prohibit_empty_literal=false, a shape with no values is accepted.
  Status status =
      Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
          .status();
  // TF_EXPECT_OK reports the returned Status on failure. A bare
  // EXPECT_TRUE(status.ok()) would only say "false".
  TF_EXPECT_OK(status);
}
1952
// A proto that carries values but no shape is rejected.
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
  LiteralProto shapeless;
  for (bool value : {false, true, false}) {
    shapeless.add_preds(value);
  }

  const Status result = Literal::CreateFromProto(shapeless).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("LiteralProto has no shape"));
}
1963
// The shape says f32[3], but the values were written to the preds field, so
// the f32 container is empty and parsing fails.
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
  LiteralProto mismatched;
  *mismatched.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
  for (bool value : {false, true, false}) {
    mismatched.add_preds(value);
  }

  const Status result = Literal::CreateFromProto(mismatched).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1976
// An f32[42,2] shape needs 84 values; supplying only three is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
  LiteralProto underfilled;
  *underfilled.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
  for (float value : {1.0f, 2.0f, 3.0f}) {
    underfilled.add_f32s(value);
  }

  const Status result = Literal::CreateFromProto(underfilled).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 84 elements in LiteralProto"));
}
1989
// An s32[2] shape needs exactly two values; supplying three is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
  LiteralProto overfilled;
  *overfilled.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
  for (int32_t value : {42, -10, 100}) {
    overfilled.add_s32s(value);
  }

  const Status result = Literal::CreateFromProto(overfilled).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected 2 elements in LiteralProto"));
}
2002
// An array shape whose layout has been cleared is rejected, even when the
// number of values matches.
TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
  LiteralProto layoutless;
  *layoutless.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
  layoutless.mutable_shape()->clear_layout();
  for (bool value : {true, false, true, false}) {
    layoutless.add_preds(value);
  }

  const Status result = Literal::CreateFromProto(layoutless).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("LiteralProto has no layout"));
}
2016
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
  // The proto's shape declares a 2-element tuple (pred[2], f32[]), but only
  // the first element literal is supplied, so CreateFromProto must fail.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
  LiteralProto proto;
  *proto.mutable_shape() = tuple_shape.ToProto();

  // Element {0}: a well-formed pred[2] literal.
  LiteralProto* element0 = proto.add_tuple_literals();
  *element0->mutable_shape() =
      ShapeUtil::GetTupleElementShape(tuple_shape, 0).ToProto();
  element0->add_preds(false);
  element0->add_preds(true);

  Status status = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
}
2034
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
  // The proto's shape declares a 2-element tuple (pred[2], f32[]), but three
  // element literals are supplied, so CreateFromProto must fail.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
  LiteralProto proto;
  *proto.mutable_shape() = tuple_shape.ToProto();

  // Element {0}: a well-formed pred[2] literal.
  LiteralProto* element0 = proto.add_tuple_literals();
  *element0->mutable_shape() =
      ShapeUtil::GetTupleElementShape(tuple_shape, 0).ToProto();
  element0->add_preds(false);
  element0->add_preds(true);
  // Element {1}: a well-formed f32[] literal.
  LiteralProto* element1 = proto.add_tuple_literals();
  *element1->mutable_shape() =
      ShapeUtil::GetTupleElementShape(tuple_shape, 1).ToProto();
  element1->add_f32s(42.0);
  // Element {2}: an extra literal the tuple shape does not declare.
  LiteralProto* element2 = proto.add_tuple_literals();
  *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
  element2->add_f32s(123.0);

  Status status = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements"));
}
2059
// Broadcasting s64[2] into s64[2,2] with the operand mapped to dimension 0
// repeats each operand element along its row.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
  Literal operand = LiteralUtil::CreateR1<int64_t>({1, 2});
  const Shape target_shape = ShapeUtil::MakeShape(S64, {2, 2});
  const Literal expected = LiteralUtil::CreateR2<int64_t>({{1, 1}, {2, 2}});

  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          operand.Broadcast(/*result_shape=*/target_shape,
                                            /*dimensions=*/{0}));
  EXPECT_EQ(result, expected);
}
2069
// Broadcasting s64[2] into s64[2,2] with the operand mapped to dimension 1
// repeats the whole operand as every row.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
  Literal operand = LiteralUtil::CreateR1<int64_t>({1, 2});
  const Shape target_shape = ShapeUtil::MakeShape(S64, {2, 2});
  const Literal expected = LiteralUtil::CreateR2<int64_t>({{1, 2}, {1, 2}});

  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          operand.Broadcast(/*result_shape=*/target_shape,
                                            /*dimensions=*/{1}));
  EXPECT_EQ(result, expected);
}
2079
// Broadcasting an s32 scalar (no mapped dimensions) into s32[2,2] fills every
// entry with the scalar.
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
  Literal operand = LiteralUtil::CreateR0<int32_t>(9);
  const Shape target_shape = ShapeUtil::MakeShape(S32, {2, 2});
  const Literal expected = LiteralUtil::CreateR2<int32_t>({{9, 9}, {9, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          operand.Broadcast(/*result_shape=*/target_shape,
                                            /*dimensions=*/{}));
  EXPECT_EQ(result, expected);
}
2089
// Broadcast an s64[2] operand whose dimension 0 has dynamic size 1 into
// s64[2,2] along dimension 1. The result is expected to compare equal to
// {{1}, {1}} and to report dynamic size 1 for dimension 1.
TEST_F(LiteralUtilTest, DynamicBroadcast) {
  Literal operand = LiteralUtil::CreateR1<int64_t>({1, 2});
  operand.SetDynamicSize(0, 1);

  const Shape target_shape = ShapeUtil::MakeShape(S64, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          operand.Broadcast(/*result_shape=*/target_shape,
                                            /*dimensions=*/{1}));

  const Literal expected = LiteralUtil::CreateR2<int64_t>({{1}, {1}});
  EXPECT_EQ(result, expected);
  EXPECT_EQ(result.GetDynamicSize(1), 1);
}
2100
// GetAsComplex128 widens complex128, f64 and complex64 scalars to complex128,
// and returns no value for an s64 scalar.
TEST_F(LiteralUtilTest, GetAsComplex128) {
  const complex128 one = {1, 0};
  const complex128 one_plus_two_i = {1, 2};

  Literal c128_one = LiteralUtil::CreateR0<complex128>(one);
  EXPECT_EQ(*c128_one.GetAsComplex128({}), one);

  Literal f64_one = LiteralUtil::CreateR0<double>(1);
  EXPECT_EQ(*f64_one.GetAsComplex128({}), one);

  const complex64 one_c64 = {1, 0};
  Literal c64_one = LiteralUtil::CreateR0<complex64>(one_c64);
  EXPECT_EQ(*c64_one.GetAsComplex128({}), one);

  Literal c128_nonreal = LiteralUtil::CreateR0<complex128>(one_plus_two_i);
  EXPECT_EQ(*c128_nonreal.GetAsComplex128({}), one_plus_two_i);

  Literal s64_one = LiteralUtil::CreateR0<int64_t>(1);
  EXPECT_FALSE(s64_one.GetAsComplex128({}).has_value());
}
2116
// Slicing a pred[3] literal over its full extent [0, 3) returns an equal literal.
TEST_F(LiteralUtilTest, SliceOnBool) {
  Literal flags = LiteralUtil::CreateR1<bool>({true, true, false});
  Literal full_slice = flags.Slice(/*start_indices=*/{0},
                                   /*limit_indices=*/{3});
  EXPECT_EQ(flags, full_slice);
}
2121
// IsEqualAt compares a scalar element against a value of a possibly different
// numeric type (int, double, complex64, complex128).
TEST_F(LiteralUtilTest, IsEqualAt) {
  double ten_f64 = 10.0;
  int ten_int = 10;
  complex128 ten_c128 = {10, 0};

  // Integer literal against double, int and complex128 values.
  Literal int_lit = LiteralUtil::CreateR0<int>(10);
  EXPECT_TRUE(int_lit.IsEqualAt({}, ten_f64));
  EXPECT_TRUE(int_lit.IsEqualAt({}, ten_int));
  EXPECT_TRUE(int_lit.IsEqualAt({}, ten_c128));

  // Double literal against the same three values.
  Literal f64_lit = LiteralUtil::CreateR0<double>(10);
  EXPECT_TRUE(f64_lit.IsEqualAt({}, ten_f64));
  EXPECT_TRUE(f64_lit.IsEqualAt({}, ten_int));
  EXPECT_TRUE(f64_lit.IsEqualAt({}, ten_c128));

  // complex128 literal with zero imaginary part.
  Literal c128_lit = LiteralUtil::CreateR0<complex128>(ten_c128);
  EXPECT_TRUE(c128_lit.IsEqualAt({}, ten_f64));
  EXPECT_TRUE(c128_lit.IsEqualAt({}, ten_int));
  EXPECT_TRUE(c128_lit.IsEqualAt({}, ten_c128));
  EXPECT_FALSE(
      c128_lit.IsEqualAt({}, std::numeric_limits<double>::infinity()));

  // complex128 literal with a non-zero imaginary part, against complex128 and
  // complex64 values.
  complex128 ten_plus_3i_c128 = {10, 3};
  complex64 ten_plus_3i_c64 = {10, 3};
  Literal c128_nonreal = LiteralUtil::CreateR0<complex128>(ten_plus_3i_c128);
  EXPECT_TRUE(c128_nonreal.IsEqualAt({}, ten_plus_3i_c128));
  EXPECT_TRUE(c128_nonreal.IsEqualAt({}, ten_plus_3i_c64));
}
2145
// A literal built by CreateFromShapeWithUnknownLeafArrays reports IsKnown()
// as false.
TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) {
  const Shape f32_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
  Literal unknown = Literal::CreateFromShapeWithUnknownLeafArrays(f32_4x4);
  EXPECT_FALSE(unknown.IsKnown());
}
2151
// An outer tuple whose nested tuple contains an unknown leaf reports IsKnown()
// as false.
TEST_F(LiteralUtilTest, CreatePartiallyKnownTuple) {
  Literal unknown_leaf = Literal::CreateFromShapeWithUnknownLeafArrays(
      ShapeUtil::MakeShape(F32, {4, 4}));
  Literal ten = LiteralUtil::CreateR0<int>(10);
  Literal hundred = LiteralUtil::CreateR0<int>(100);

  Literal inner = LiteralUtil::MakeTuple({&unknown_leaf, &ten});
  Literal outer = LiteralUtil::MakeTuple({&hundred, &inner});
  EXPECT_FALSE(outer.IsKnown());
}
2161
// Using CopyFrom to copy a subtree that contains an unknown leaf into a
// freshly created literal leaves the destination reporting IsKnown() false.
TEST_F(LiteralUtilTest, CopyFromPartiallyKnownTuple) {
  Literal unknown_leaf = Literal::CreateFromShapeWithUnknownLeafArrays(
      ShapeUtil::MakeShape(F32, {4, 4}));
  Literal ten = LiteralUtil::CreateR0<int>(10);
  Literal hundred = LiteralUtil::CreateR0<int>(100);
  Literal inner = LiteralUtil::MakeTuple({&unknown_leaf, &ten});
  Literal outer = LiteralUtil::MakeTuple({&hundred, &inner});

  Literal destination = Literal::CreateFromShape(outer.shape());
  TF_ASSERT_OK(destination.CopyFrom(outer, /*dest_shape_index=*/{1},
                                    /*src_shape_index=*/{1}));
  EXPECT_FALSE(destination.IsKnown());
}
2174
// The unknown part here is itself a tuple. After copying the containing
// subtree with CopyFrom and then extracting its two halves, the destination
// and the extracted unknown tuple report IsKnown() false. The extracted
// scalar reports true.
TEST_F(LiteralUtilTest, CopyFromPartiallyKnownTupleUnknownTupleElement) {
  const Shape f32_4x4 = ShapeUtil::MakeShape(F32, {4, 4});
  Literal unknown_pair = Literal::CreateFromShapeWithUnknownLeafArrays(
      ShapeUtil::MakeTupleShape({f32_4x4, f32_4x4}));
  Literal ten = LiteralUtil::CreateR0<int>(10);
  Literal hundred = LiteralUtil::CreateR0<int>(100);
  Literal inner = LiteralUtil::MakeTuple({&unknown_pair, &ten});
  Literal outer = LiteralUtil::MakeTuple({&hundred, &inner});

  // Copy the (unknown_pair, ten) subtree into a literal of outer's shape.
  Literal destination = Literal::CreateFromShape(outer.shape());
  TF_ASSERT_OK(destination.CopyFrom(outer, /*dest_shape_index=*/{1},
                                    /*src_shape_index=*/{1}));

  // Pull each half of that subtree back out into standalone literals.
  Literal pair_copy = Literal::CreateFromShape(unknown_pair.shape());
  Literal ten_copy = Literal::CreateFromShape(ten.shape());
  TF_ASSERT_OK(pair_copy.CopyFrom(destination, /*dest_shape_index=*/{},
                                  /*src_shape_index=*/{1, 0}));
  TF_ASSERT_OK(ten_copy.CopyFrom(destination, /*dest_shape_index=*/{},
                                 /*src_shape_index=*/{1, 1}));

  EXPECT_FALSE(destination.IsKnown());
  EXPECT_FALSE(pair_copy.IsKnown());
  EXPECT_TRUE(ten_copy.IsKnown());
}
2196
2197 } // namespace
2198 } // namespace xla
2199