xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/literal_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal.h"
17 
18 #include <limits>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/base/casts.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 
34 namespace xla {
35 namespace {
36 
37 using ::testing::ElementsAre;
38 using ::testing::HasSubstr;
39 
40 class LiteralUtilTest : public ::testing::Test {
41  protected:
LiteralUtilTest()42   LiteralUtilTest() {
43     Array4D<float> arr4d({
44         // clang-format off
45       {  // i0=0
46           {  // i1=0
47               {1, 2, 3},  // i2=0
48               {4, 5, 6},  // i2=1
49               {7, 8, 9},  // i2=2
50           },
51           {  // i1=1
52               {11, 12, 13},
53               {14, 15, 16},
54               {17, 18, 19},
55           },
56       },
57       {  // i0=1
58           {  // i1=0
59               {101, 102, 103},
60               {104, 105, 106},
61               {107, 108, 109},
62           },
63           {  // i1=1
64               {201, 202, 203},  // i2=0
65               {204, 205, 206},  // i2=1
66               {207, 208, 209},  // i2=2
67           },
68       },
69         // clang-format on
70     });
71 
72     layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
73     layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
74     layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
75     layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
76     layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
77     layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
78 
79     literal_r4_2x2x3x3_dim0major_ =
80         LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
81                                                           layout_r4_dim0major_);
82     literal_r4_2x2x3x3_dim0minor_ =
83         LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
84                                                           layout_r4_dim0minor_);
85   }
86 
87   Layout layout_r2_dim0major_;
88   Layout layout_r2_dim0minor_;
89   Layout layout_r3_dim0major_;
90   Layout layout_r3_dim0minor_;
91   Layout layout_r4_dim0major_;
92   Layout layout_r4_dim0minor_;
93   Literal literal_r4_2x2x3x3_dim0major_;
94   Literal literal_r4_2x2x3x3_dim0minor_;
95 };
96 
// Checks the textual form of rank-0 literals across element types.
TEST_F(LiteralUtilTest, LiteralScalarToString) {
  auto true_lit = LiteralUtil::CreateR0<bool>(true);
  EXPECT_EQ("pred[] true", true_lit.ToString());

  auto false_lit = LiteralUtil::CreateR0<bool>(false);
  EXPECT_EQ("pred[] false", false_lit.ToString());

  auto u32_lit = LiteralUtil::CreateR0<uint32_t>(42);
  EXPECT_EQ("u32[] 42", u32_lit.ToString());

  auto s32_lit = LiteralUtil::CreateR0<int32_t>(-999);
  EXPECT_EQ("s32[] -999", s32_lit.ToString());

  auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
  EXPECT_EQ("f32[] 3.14", f32_lit.ToString());

  auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
  EXPECT_EQ("f16[] 0.5", f16_lit.ToString());

  auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
  EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());

  auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14, 2.78});
  EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());

  auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
  EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());

  // 3.14 is rounded to 3.140625 in bfloat16, which prints as "3.141".
  auto bf16_lit_rounded =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
  // EXPECT (not ASSERT): a mismatch here must not skip the remaining check.
  EXPECT_EQ("bf16[] 3.141", bf16_lit_rounded.ToString());

  // 9.001 is rounded to exactly 9 in bfloat16.
  auto bf16_lit_rounded2 =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
  EXPECT_EQ("bf16[] 9", bf16_lit_rounded2.ToString());
}
134 
// A pred vector prints its elements as 1/0 inside braces.
TEST_F(LiteralUtilTest, LiteralVectorToString) {
  const auto flags = LiteralUtil::CreateR1<bool>({true, false, true});
  EXPECT_EQ("pred[3] {1, 0, 1}", flags.ToString());
}
139 
// A rank-2 literal prints one brace-delimited row per line.
TEST_F(LiteralUtilTest, R2ToString) {
  constexpr char kExpected[] = R"(s32[3,2] {
  { 1, 2 },
  { 3, 4 },
  { 5, 6 }
})";
  const auto matrix = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
  EXPECT_EQ(kExpected, matrix.ToString());
}
149 
// A dynamic dimension prints as "<=bound" followed by the runtime sizes, and
// only the in-bounds elements are shown.
TEST_F(LiteralUtilTest, R2DynamicToString) {
  // Dynamic outer dimension: rows are trimmed.
  auto rows_trimmed = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
  rows_trimmed.SetDynamicSize(0, {}, 2);
  const std::string expected_rows = R"(s32[<=3,2](2,2) {
  { 1, 2 },
  { 3, 4 }
})";
  EXPECT_EQ(expected_rows, rows_trimmed.ToString());

  // Dynamic inner dimension: the printed elements are not contiguous in
  // memory, which makes this the less trivial case.
  auto cols_trimmed = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}});
  cols_trimmed.SetDynamicSize(1, {}, 2);
  const std::string expected_cols = R"(s32[2,<=3](2,2) {
  { 1, 2 },
  { 4, 5 }
})";
  EXPECT_EQ(expected_cols, cols_trimmed.ToString());
}
168 
// A rank-3 literal prints each outer slice as a nested brace block.
TEST_F(LiteralUtilTest, R3ToString) {
  constexpr char kExpected[] = R"(s32[3,2,1] {
{
  {1},
  {2}
},
{
  {3},
  {4}
},
{
  {5},
  {6}
}
})";
  const auto cube =
      LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
  EXPECT_EQ(kExpected, cube.ToString());
}
188 
// High-rank literals annotate each outer brace with its index, e.g. /*i0=0*/.
// CreateFromDimensions fills the literal with zeros here.
TEST_F(LiteralUtilTest, R6ToString) {
  constexpr char kExpected[] = R"(s32[2,2,1,1,1,2] {
{ /*i0=0*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
}
},
{ /*i0=1*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
}
}
})";
  const auto rank6 =
      LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2});
  EXPECT_EQ(kExpected, rank6.ToString());
}
228 
// A tuple prints its elements, comma-separated, between parentheses.
TEST_F(LiteralUtilTest, TupleToString) {
  const auto one = LiteralUtil::CreateR0<float>(1.0);
  const auto square = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  constexpr char kExpected[] = R"((
f32[] 1,
f32[2,2] {
  { 1, 2 },
  { 3, 4 }
}
))";
  EXPECT_EQ(kExpected, LiteralUtil::MakeTuple({&one, &square}).ToString());
}
242 
// Building from an Array3D keeps its dimensions and element order.
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
  // clang-format off
  Array3D<float> source({
    {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
    {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}},
  });
  // clang-format on

  const auto lit = LiteralUtil::CreateR3FromArray3D(source);
  EXPECT_THAT(lit.shape().dimensions(), ElementsAre(2, 3, 2));

  constexpr char kExpected[] = R"(f32[2,3,2] {
{
  { 1, 2 },
  { 3, 4 },
  { 5, 6 }
},
{
  { 7, 8 },
  { 9, 10 },
  { 11, 12 }
}
})";
  EXPECT_EQ(kExpected, lit.ToString());
}
272 
// CreateR4Projected replicates a 2D array into p x z copies, producing a
// [p, z, rows, cols] literal.
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
  // clang-format off
  const auto projected = LiteralUtil::CreateR4Projected<float>({
    {1, 2},
    {1001, 1002},
    {2001, 2002},
  }, /*projection_p=*/1, /*projection_z=*/2);
  // clang-format on
  EXPECT_THAT(projected.shape().dimensions(), ElementsAre(1, 2, 3, 2));

  constexpr char kExpected[] = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
  { 1, 2 },
  { 1001, 1002 },
  { 2001, 2002 }
},
{ /*i1=1*/
  { 1, 2 },
  { 1001, 1002 },
  { 2001, 2002 }
}
}
})";
  EXPECT_EQ(kExpected, projected.ToString());
}
299 
// The dim0-major rank-4 fixture prints in logical index order.
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
  const Literal& fixture = literal_r4_2x2x3x3_dim0major_;
  EXPECT_THAT(fixture.shape().dimensions(), ElementsAre(2, 2, 3, 3));

  constexpr char kExpected[] = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
  { 1, 2, 3 },
  { 4, 5, 6 },
  { 7, 8, 9 }
},
{ /*i1=1*/
  { 11, 12, 13 },
  { 14, 15, 16 },
  { 17, 18, 19 }
}
},
{ /*i0=1*/
{ /*i1=0*/
  { 101, 102, 103 },
  { 104, 105, 106 },
  { 107, 108, 109 }
},
{ /*i1=1*/
  { 201, 202, 203 },
  { 204, 205, 206 },
  { 207, 208, 209 }
}
}
})";
  EXPECT_EQ(kExpected, fixture.ToString());
}
332 
// EachCellAsString is expected to visit cells in row-major order, passing
// each cell's multi-index and its printed value.
TEST_F(LiteralUtilTest, EachCellR2F32) {
  using Cell = std::tuple<int64_t, int64_t, std::string>;

  // clang-format off
  auto grid = LiteralUtil::CreateR2<float>({
    {3.1f, 4.2f},
    {9.3f, 12.4f},
  });
  // clang-format on

  std::vector<Cell> visited;
  grid.EachCellAsString(
      [&visited](absl::Span<const int64_t> index, const std::string& text) {
        visited.emplace_back(index[0], index[1], text);
      });

  const std::vector<Cell> want = {Cell(0, 0, "3.1"), Cell(0, 1, "4.2"),
                                  Cell(1, 0, "9.3"), Cell(1, 1, "12.4")};
  EXPECT_EQ(want, visited);
}
351 
// Scalars compare equal only when both value and element type match.
TEST_F(LiteralUtilTest, ScalarEquality) {
  const auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
  const auto f32_42_copy = LiteralUtil::CreateR0<float>(42.0);

  EXPECT_EQ(f32_42, f32_42);
  EXPECT_EQ(f32_42, f32_42_copy);

  const auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
  EXPECT_NE(f32_42, f32_123);

  // Same numeric value, different element type.
  const auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
  EXPECT_NE(f32_42, f64_42);
}
366 
// Array literals are equal only when both shape and contents match; a nil
// literal equals only another nil literal.
TEST_F(LiteralUtilTest, NonScalarEquality) {
  const auto m = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const auto m_copy = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const auto m_other = LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
  const auto flat = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
  const auto one = LiteralUtil::CreateR0<float>(1.0);
  Literal nil(ShapeUtil::MakeNil());

  EXPECT_EQ(m, m);
  EXPECT_EQ(m, m_copy);
  EXPECT_NE(m, m_other);  // Same shape, different values.
  EXPECT_NE(m, flat);     // Same values, different rank.
  EXPECT_NE(m, one);
  EXPECT_NE(m, nil);
  EXPECT_EQ(nil, nil);
}
385 
// Independently created tokens compare equal to each other but not to data
// literals; within tuples, element position still matters.
TEST_F(LiteralUtilTest, TokenEquality) {
  const auto first_token = LiteralUtil::CreateToken();
  const auto second_token = LiteralUtil::CreateToken();
  const auto value = LiteralUtil::CreateR0<float>(1.0);

  EXPECT_EQ(first_token, second_token);
  EXPECT_NE(first_token, value);

  EXPECT_EQ(LiteralUtil::MakeTuple({&first_token}),
            LiteralUtil::MakeTuple({&first_token}));
  EXPECT_EQ(LiteralUtil::MakeTuple({&first_token, &value}),
            LiteralUtil::MakeTuple({&second_token, &value}));
  EXPECT_NE(LiteralUtil::MakeTuple({&first_token, &value}),
            LiteralUtil::MakeTuple({&value, &second_token}));
}
401 
// Equality compares logical elements, so the same 2x2 matrix stored
// column-major and row-major must compare equal.
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
  Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
  Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));

  // Fill both literals with identical logical contents.
  for (Literal* target : {&colmajor, &rowmajor}) {
    target->Set<float>({0, 0}, 1.0);
    target->Set<float>({0, 1}, 2.0);
    target->Set<float>({1, 0}, 3.0);
    target->Set<float>({1, 1}, 4.0);
  }

  EXPECT_EQ(rowmajor, colmajor);
}
418 
// Tuples are equal only when their elements match pairwise, in order.
TEST_F(LiteralUtilTest, TupleEquality) {
  const auto one = LiteralUtil::CreateR0<float>(1.0);
  const auto square = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const auto base = LiteralUtil::MakeTuple({&one, &square});

  // Reuses `square` and pairs it with a separately built copy of `one`.
  const auto one_again = LiteralUtil::CreateR0<float>(1.0);
  EXPECT_EQ(base, LiteralUtil::MakeTuple({&one_again, &square}));

  // Same elements, swapped order.
  EXPECT_NE(base, LiteralUtil::MakeTuple({&square, &one}));

  // First element holds a different value.
  const auto forty_two = LiteralUtil::CreateR0<float>(42.0);
  EXPECT_NE(base, LiteralUtil::MakeTuple({&forty_two, &square}));
}
440 
// Tests equality of tuples whose elements have dynamic dimensions.
TEST_F(LiteralUtilTest, DynamicShapeEquality) {
  auto r1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
  r1.SetDynamicSize(0, {}, 1);
  auto r2 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  r2.SetDynamicSize(0, {}, 1);
  auto tuple1 = LiteralUtil::MakeTuple({&r1, &r2});

  // r1_clone is NOT an exact copy of r1: it differs only in element 1, which
  // lies outside the dynamic size of 1. The tuples are still expected to be
  // equal because only in-bounds elements take part in the comparison. The
  // second tuple element reuses r2 unchanged.
  auto r1_clone = LiteralUtil::CreateR1<float>({1.0, 3.0});
  r1_clone.SetDynamicSize(0, {}, 1);
  auto tuple2 = LiteralUtil::MakeTuple({&r1_clone, &r2});
  EXPECT_EQ(tuple1, tuple2);

  // Same data as r2 but a different dynamic size, so the tuples differ.
  auto r2_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  r2_clone.SetDynamicSize(0, {}, 2);
  auto tuple_3 = LiteralUtil::MakeTuple({&r1_clone, &r2_clone});
  EXPECT_NE(tuple1, tuple_3);
}
462 
// Tests equality of complex64 vectors.
TEST_F(LiteralUtilTest, C64Equality) {
  auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});

  // A separately built vector with the same elements compares equal.
  auto vector_clone =
      LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same elements in the opposite order do not compare equal.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
477 
// Tests equality of complex128 vectors.
TEST_F(LiteralUtilTest, C128Equality) {
  auto vector = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});

  // A separately built vector with the same elements compares equal.
  auto vector_clone =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same elements in the opposite order do not compare equal.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
492 
// IsAll is only meaningful for arrays; a tuple must report false even when
// every leaf value matches.
TEST_F(LiteralUtilTest, IsAllTuple) {
  auto element1 = LiteralUtil::CreateR0<float>(0.0);
  auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
  // Mix a scalar and a matrix; previously element2 was built but never used.
  auto tuple = LiteralUtil::MakeTuple({&element1, &element2});

  // Tuples should always return false for IsAll.
  EXPECT_FALSE(tuple.IsAll(0));
  EXPECT_FALSE(tuple.IsAll(1));
}
502 
503 // Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest,CreateFromShapeTuple)504 TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
505   auto scalar = LiteralUtil::CreateR0<float>(0.0);
506   auto matrix = LiteralUtil::CreateR2<int32_t>({{0, 0}, {0, 0}});
507   auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
508 
509   auto x = Literal::CreateFromShape(tuple.shape());
510   EXPECT_EQ(tuple, x);
511 }
512 
// IsAll(v) is true only when every element of an array literal equals the
// small integer v, converted to the literal's element type.
TEST_F(LiteralUtilTest, IsAll) {
  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));

  // We shouldn't reinterpret int8_min as an unsigned type and then decide that
  // it is equal to 255.
  auto int8_min = std::numeric_limits<int8_t>::min();
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8_t>(255).IsAll(int8_min));

  EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));

  EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
  EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));

  EXPECT_TRUE(LiteralUtil::CreateR2<uint64_t>({{8, 8}, {8, 8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64_t>({{8, 8}, {8, 9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64_t>({{9, 8}, {8, 8}}).IsAll(8));

  half h8(8.0f);
  half h9(9.0f);
  EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));

  bfloat16 b8(8.0f);
  bfloat16 b9(9.0f);

  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));

  // 9.001 is rounded to 9.0 in bfloat16, so both elements equal 9. Pass an
  // integer: IsAll takes a small integer, and a double argument would rely on
  // an implicit floating-to-integer conversion.
  bfloat16 b91(9.001f);
  bfloat16 b90(9.00f);
  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9));

  complex64 c8_9 = {8, 9};
  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));

  auto uint64_max = std::numeric_limits<uint64_t>::max();
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64_t>(
                   {{uint64_max, uint64_max}, {uint64_max, uint64_max}})
                   .IsAll(-1));
}
563 
// IsAllFloat reports true only for floating-point literals whose every
// element equals the given value; other element types always yield false.
TEST_F(LiteralUtilTest, IsAllFloat) {
  // Non-floating-point element types.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8_t>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8_t>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));

  // f32.
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
  EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
  const auto f32_one_off = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}});
  EXPECT_FALSE(f32_one_off.IsAllFloat(0));
  const auto f32_halves =
      LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}});
  EXPECT_TRUE(f32_halves.IsAllFloat(.5));

  // f64.
  EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
  EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
  EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
  EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
  const auto f64_one_off =
      LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}});
  EXPECT_FALSE(f64_one_off.IsAllFloat(0));
}
587 
// IsAllComplex reports true only for complex literals whose every element
// equals the given value; other element types always yield false.
TEST_F(LiteralUtilTest, IsAllComplex) {
  // Non-complex element types.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8_t>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8_t>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));

  const complex64 match = {8, 9};
  const complex64 other = {7, 9};
  const auto all_match = LiteralUtil::CreateR2<complex64>({{match}, {match}});
  const auto first_other = LiteralUtil::CreateR2<complex64>({{other}, {match}});
  const auto last_other = LiteralUtil::CreateR2<complex64>({{match}, {other}});

  EXPECT_TRUE(all_match.IsAllComplex({8.0f, 9.0f}));
  EXPECT_FALSE(first_other.IsAllComplex({8.0f, 9.0f}));
  EXPECT_FALSE(last_other.IsAllComplex({8.0f, 9.0f}));
}
606 
// IsAllFirst is true exactly when every element equals the literal's first
// element, whatever the element type.
TEST_F(LiteralUtilTest, IsAllFirst) {
  EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int8_t>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int8_t>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint8_t>({1, 1, 2}).IsAllFirst());
  // Positive case for uint8_t, matching the other integer types.
  EXPECT_TRUE(LiteralUtil::CreateR1<uint8_t>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int32_t>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int32_t>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<uint32_t>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint32_t>({1, 1, 2}).IsAllFirst());

  complex64 c8_9 = {8, 9};
  complex64 c7_9 = {7, 9};
  EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
624 
// IsZero(index) checks the single element at `index` (empty for scalars).
TEST_F(LiteralUtilTest, IsZero) {
  const auto f32_zero = LiteralUtil::CreateR0<float>(0.0f);
  const auto f32_one = LiteralUtil::CreateR0<float>(1.0f);
  EXPECT_TRUE(f32_zero.IsZero({}));
  EXPECT_FALSE(f32_one.IsZero({}));

  const auto grid = LiteralUtil::CreateR2<uint32_t>({{1, 2, 0, 3}, {1, 0, 1, 2}});
  EXPECT_FALSE(grid.IsZero({0, 1}));
  EXPECT_TRUE(grid.IsZero({0, 2}));
  EXPECT_TRUE(grid.IsZero({1, 1}));
  EXPECT_FALSE(grid.IsZero({1, 2}));

  const auto c_zero = LiteralUtil::CreateR0<complex64>(0.0f);
  const auto c_half = LiteralUtil::CreateR0<complex64>(0.5f);
  EXPECT_TRUE(c_zero.IsZero({}));
  EXPECT_FALSE(c_half.IsZero({}));
}
642 
643 template <typename T>
644 class LiteralUtilTestTemplated : public ::testing::Test {};
645 
646 using TestedTypes = ::testing::Types<float, int32_t, uint32_t, complex64>;
647 TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes);
648 
// Relayout to either 2D layout must set the requested layout while keeping
// the literal's logical contents unchanged.
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
  // 1/2 is a non-integer for floating-point and complex element types; for
  // integer types it is 0. Named one_half rather than `half` so it does not
  // shadow the xla::half type.
  const TypeParam one_half = TypeParam(1) / TypeParam(2);
  auto data = LiteralUtil::CreateR2<TypeParam>({{one_half, 2}, {3, 4}});
  const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
  const Layout layout10 = LayoutUtil::MakeLayout({1, 0});

  auto data01 = data.Relayout(layout01);
  EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
  EXPECT_EQ(data, data01);

  auto data10 = data.Relayout(layout10);
  EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
  EXPECT_EQ(data, data10);
}
664 
// Reshaping a scalar to an empty dimension list leaves it unchanged.
TEST_F(LiteralUtilTest, ReshapeR0) {
  const auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  const auto reshaped = scalar.Reshape(/*dimensions=*/{}).value();
  EXPECT_EQ(scalar, reshaped);
}
670 
// Reshape of a dim0-major F32[1x3x2x4] literal to F32[3x4x2] keeps the
// elements in row-major order.
TEST_F(LiteralUtilTest, ReshapeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // F32[3x4x2] (rank 3: the leading size-1 dimension is dropped)
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).value();

  EXPECT_EQ(expected, reshape);
}
690 
// Same as ReshapeR4, but the source literal uses a dim0-minor layout. The
// expected literal is still dim0-major; literal equality compares logical
// values, not memory order.
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0minor_);
  // F32[3x4x2] (rank 3: the leading size-1 dimension is dropped)
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).value();

  EXPECT_EQ(expected, reshape);
}
710 
// Transposing a scalar with the empty permutation is an identity.
TEST_F(LiteralUtilTest, TransposeR0) {
  const auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  const auto transposed = scalar.Transpose(/*permutation=*/{});
  EXPECT_EQ(scalar, transposed);
}
716 
// Transposing F32[1x3x2x4] by {2, 3, 0, 1} must permute both the shape and
// every element's index.
TEST_F(LiteralUtilTest, TransposeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }});
  // clang-format on
  auto transposed = original.Transpose(/*permutation=*/{2, 3, 0, 1});

  // Check the shape first; otherwise EachCell below would pass vacuously on
  // an unexpectedly shaped result.
  EXPECT_THAT(transposed.shape().dimensions(), ElementsAre(2, 4, 1, 3));

  transposed.EachCell<float>(
      [&](absl::Span<const int64_t> indices, float value) {
        EXPECT_EQ(value, original.Get<float>({indices[2], indices[3],
                                              indices[0], indices[1]}));
      });
}
733 
// Transposing a literal with a dynamic dimension moves elements to their
// swapped indices.
TEST_F(LiteralUtilTest, TransposeDynamicR2) {
  // F32[2, <=3] with dynamic sizes (2, 1).
  auto source = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
  source.SetDynamicSize(1, 1);

  // F32[<=3, 2] with dynamic sizes (1, 2).
  auto swapped = source.Transpose(/*permutation=*/{1, 0});

  swapped.EachCell<float>([&](absl::Span<const int64_t> idx, float elem) {
    EXPECT_EQ(elem, source.Get<float>({idx[1], idx[0]}));
  });
}
745 
// ToStatic converts a bounded-dynamic literal to a static shape of its
// current dynamic sizes, keeping the in-bounds elements.
TEST_F(LiteralUtilTest, ToStaticR2) {
  // F32[2, <=3] with dynamic sizes (2, 1).
  auto dynamic_lit = LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}});
  dynamic_lit.SetDynamicSize(1, 1);

  // Expected result: F32[2, 1].
  auto static_lit = dynamic_lit.ToStatic();
  EXPECT_EQ(static_lit.shape(), ShapeUtil::MakeShape(F32, {2, 1}));
  EXPECT_TRUE(static_lit.shape().is_static());

  static_lit.EachCell<float>([&](absl::Span<const int64_t> idx, float elem) {
    EXPECT_EQ(elem, dynamic_lit.Get<float>({idx[0], idx[1]}));
  });
}
760 
// ToBoundedDynamic re-expresses a static literal under a bounded-dynamic
// shape, preserving its elements.
TEST_F(LiteralUtilTest, ToBoundedDynamicR2) {
  // Source: F32[2, 1].
  auto static_lit = LiteralUtil::CreateR2<float>({{1}, {4}});

  // Target: F32[2, <=3], expected to end up with dynamic sizes (2, 1).
  const auto bounded_shape = ShapeUtil::MakeShape(F32, {2, 3}, {false, true});
  auto bounded_lit = static_lit.ToBoundedDynamic(bounded_shape);
  EXPECT_EQ(bounded_lit.shape(), bounded_shape);

  bounded_lit.EachCell<float>([&](absl::Span<const int64_t> idx, float elem) {
    EXPECT_EQ(elem, static_lit.Get<float>({idx[0], idx[1]}));
  });
}
774 
// Relayout to a target layout must match a literal created directly in that
// layout, in both directions.
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
  EXPECT_EQ(literal_r4_2x2x3x3_dim0major_,
            literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_));
  EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_,
            literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_));
}
786 
// Checks the linear memory order of a 2x3 matrix under each R2 layout, both
// when created directly and after Relayout.
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
  // Column-major (dim0 minor): dimension 0 varies fastest in memory.
  auto col_major = LiteralUtil::CreateR2WithLayout<int32_t>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
  EXPECT_EQ(col_major.element_count(), 6);
  EXPECT_THAT(col_major.data<int32_t>(), ElementsAre(1, 4, 2, 5, 3, 6));

  auto col_to_row = col_major.Relayout(layout_r2_dim0major_);
  EXPECT_THAT(col_to_row.data<int32_t>(), ElementsAre(1, 2, 3, 4, 5, 6));

  // Row-major (dim0 major): the last dimension varies fastest in memory.
  auto row_major = LiteralUtil::CreateR2WithLayout<int32_t>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
  EXPECT_EQ(row_major.element_count(), 6);
  EXPECT_THAT(row_major.data<int32_t>(), ElementsAre(1, 2, 3, 4, 5, 6));

  auto row_to_col = row_major.Relayout(layout_r2_dim0minor_);
  EXPECT_THAT(row_to_col.data<int32_t>(), ElementsAre(1, 4, 2, 5, 3, 6));
}
810 
// Checks the linear memory order of a 2x2x3 array under each R3 layout, both
// when created directly and after Relayout.
TEST_F(LiteralUtilTest, TestR3LinearLayout) {
  // clang-format off
  const Array3D<int> values({
      {{1, 2, 3}, {4, 5, 6}},
      {{7, 8, 9}, {10, 11, 12}},
  });
  // clang-format on

  // Dim0-minor order: dimension 0 varies fastest. Dim0-major order: the last
  // dimension varies fastest.
  const std::vector<int> dim0minor_order{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
  const std::vector<int> dim0major_order{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

  auto minor_lit = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      values, layout_r3_dim0minor_);
  EXPECT_EQ(minor_lit.element_count(), 12);
  EXPECT_THAT(minor_lit.data<int32_t>(),
              testing::ElementsAreArray(dim0minor_order));

  auto minor_to_major = minor_lit.Relayout(layout_r3_dim0major_);
  EXPECT_THAT(minor_to_major.data<int32_t>(),
              testing::ElementsAreArray(dim0major_order));

  auto major_lit = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      values, layout_r3_dim0major_);
  EXPECT_EQ(major_lit.element_count(), 12);
  EXPECT_THAT(major_lit.data<int32_t>(),
              testing::ElementsAreArray(dim0major_order));

  auto major_to_minor = major_lit.Relayout(layout_r3_dim0minor_);
  EXPECT_THAT(major_to_minor.data<int32_t>(),
              testing::ElementsAreArray(dim0minor_order));
}
851 
TEST_F(LiteralUtilTest, SliceR0S32) {
  // Slicing a scalar with empty start and limit indices returns it unchanged.
  Literal scalar = LiteralUtil::CreateR0<int32_t>(1);
  Literal sliced = scalar.Slice({}, {});
  EXPECT_EQ(scalar, sliced);
}
857 
TEST_F(LiteralUtilTest, SliceR1F32) {
  // The range [3, 4) of a five-element vector holds only its fourth element.
  Literal vec = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
  Literal sliced = vec.Slice({3}, {4});
  EXPECT_EQ(LiteralUtil::CreateR1<float>({4.0}), sliced);
}
864 
TEST_F(LiteralUtilTest, SliceR2U32) {
  // Rows [0, 2) and columns [2, 4) of a 3x4 matrix form a 2x2 block.
  Literal matrix = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  Literal block = matrix.Slice({0, 2}, {2, 4});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32_t>({{3, 4}, {7, 8}}), block);
}
872 
TEST_F(LiteralUtilTest, SliceR3U32Full) {
  // A slice spanning every dimension in full equals the original literal.
  Literal cube = LiteralUtil::CreateR3<uint32_t>(
      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
  Literal whole = cube.Slice({0, 0, 0}, {2, 3, 2});
  EXPECT_EQ(cube, whole);
}
879 
TEST_F(LiteralUtilTest, SliceR2Dynamic) {
  Literal matrix = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  // Dimension 1 is marked as dynamically sized 3.
  matrix.SetDynamicSize(1, 3);

  // Slicing columns [1, 2) shrinks dimension 1 from dynamic size 3 to 1.
  Literal sliced = matrix.Slice({0, 1}, {2, 2});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32_t>({{2}, {6}}), sliced);
  EXPECT_EQ(sliced.GetDynamicSize(1), 1);
}
890 
TEST_F(LiteralUtilTest, SliceR2DynamicInBound) {
  Literal matrix = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  // Only the first column of dimension 1 is dynamically valid.
  matrix.SetDynamicSize(1, 1);

  // Requesting columns [0, 2) keeps just the single valid column.
  Literal sliced = matrix.Slice({0, 0}, {2, 2});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32_t>({{1}, {5}}), sliced);
  EXPECT_EQ(sliced.GetDynamicSize(1), 1);
}
900 
TEST_F(LiteralUtilTest, SliceR2DynamicOutOfBound) {
  Literal matrix = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  matrix.SetDynamicSize(1, 1);

  // Columns [1, 3) lie entirely past the dynamic size of 1, so the slice is
  // empty along dimension 1.
  Literal sliced = matrix.Slice({0, 1}, {2, 3});
  EXPECT_EQ(LiteralUtil::CreateR2<uint32_t>({{}, {}}), sliced);
  // Out of bound access clamps into 0 sized dimension.
  EXPECT_EQ(sliced.GetDynamicSize(1), 0);
}
911 
TEST_F(LiteralUtilTest, PopulateR1S64) {
  // PopulateR1 fills a one-element S64 vector the same way CreateR1 builds it.
  Literal populated(ShapeUtil::MakeShape(S64, {1}));
  populated.PopulateR1<int64_t>({77});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<int64_t>({77}));
}
918 
TEST_F(LiteralUtilTest, PopulateR1U64) {
  // PopulateR1 fills a two-element U64 vector the same way CreateR1 builds it.
  Literal populated(ShapeUtil::MakeShape(U64, {2}));
  populated.PopulateR1<uint64_t>({77, 88});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<uint64_t>({77, 88}));
}
925 
TEST_F(LiteralUtilTest, PopulateR1C64) {
  // A single complex64 element (77 + 88i) round-trips through PopulateR1.
  Literal populated(ShapeUtil::MakeShape(C64, {1}));
  populated.PopulateR1<complex64>({{77, 88}});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<complex64>({{77, 88}}));
}
932 
TEST_F(LiteralUtilTest, PopulateR1C128) {
  // A single complex128 element (77 + 88i) round-trips through PopulateR1.
  Literal populated(ShapeUtil::MakeShape(C128, {1}));
  populated.PopulateR1<complex128>({{77, 88}});
  EXPECT_EQ(populated, LiteralUtil::CreateR1<complex128>({{77, 88}}));
}
939 
TEST_F(LiteralUtilTest, PopulateR2C64) {
  // PopulateR2 on a 2x2 complex64 matrix matches the CreateR2 equivalent.
  Literal populated(ShapeUtil::MakeShape(C64, {2, 2}));
  populated.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
  EXPECT_EQ(populated, LiteralUtil::CreateR2<complex64>(
                           {{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}));
}
947 
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
  // Filling a BF16 scalar with 0.25 yields the scalar literal 0.25.
  const bfloat16 fill(0.25f);
  Literal populated(ShapeUtil::MakeShape(BF16, {}));
  populated.PopulateWithValue<bfloat16>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR0<bfloat16>(fill));
}
955 
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
  // Every element of a three-element BF16 vector receives the fill value.
  const bfloat16 fill(0.5f);
  Literal populated(ShapeUtil::MakeShape(BF16, {3}));
  populated.PopulateWithValue<bfloat16>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR1<bfloat16>({fill, fill, fill}));
}
963 
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
  // Every element of a 2x2 BF16 matrix receives the fill value.
  const bfloat16 fill(2.0f);
  Literal populated(ShapeUtil::MakeShape(BF16, {2, 2}));
  populated.PopulateWithValue<bfloat16>(fill);
  EXPECT_EQ(populated,
            LiteralUtil::CreateR2<bfloat16>({{fill, fill}, {fill, fill}}));
}
971 
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
  // Filling an F32 scalar with 2.5 yields the scalar literal 2.5.
  Literal populated(ShapeUtil::MakeShape(F32, {}));
  populated.PopulateWithValue<float>(2.5f);
  EXPECT_EQ(populated, LiteralUtil::CreateR0<float>(2.5f));
}
978 
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
  // A negative fill value is replicated into every S64 element.
  Literal populated(ShapeUtil::MakeShape(S64, {3}));
  populated.PopulateWithValue<int64_t>(-7);
  EXPECT_EQ(populated, LiteralUtil::CreateR1<int64_t>({-7, -7, -7}));
}
985 
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
  // Every element of a 2x2 U64 matrix receives the fill value 42.
  Literal populated(ShapeUtil::MakeShape(U64, {2, 2}));
  populated.PopulateWithValue<uint64_t>(42);
  EXPECT_EQ(populated, LiteralUtil::CreateR2<uint64_t>({{42, 42}, {42, 42}}));
}
992 
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
  // Every element of a 2x2 complex64 matrix becomes 4 + 2i.
  Literal populated(ShapeUtil::MakeShape(C64, {2, 2}));
  populated.PopulateWithValue<complex64>({4, 2});
  EXPECT_EQ(populated, LiteralUtil::CreateR2<complex64>(
                           {{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}));
}
1000 
TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
  // Every element of a 2x2 complex128 matrix becomes 4 + 2i.
  Literal populated(ShapeUtil::MakeShape(C128, {2, 2}));
  populated.PopulateWithValue<complex128>({4, 2});
  EXPECT_EQ(populated, LiteralUtil::CreateR2<complex128>(
                           {{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}));
}
1008 
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
  // Filling an F16 scalar with 0.25 yields the scalar literal 0.25.
  const half fill(0.25f);
  Literal populated(ShapeUtil::MakeShape(F16, {}));
  populated.PopulateWithValue<half>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR0<half>(fill));
}
1016 
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
  // Every element of a three-element F16 vector receives the fill value.
  const half fill(0.5f);
  Literal populated(ShapeUtil::MakeShape(F16, {3}));
  populated.PopulateWithValue<half>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR1<half>({fill, fill, fill}));
}
1024 
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
  // Every element of a 2x2 F16 matrix receives the fill value.
  const half fill(2.0f);
  Literal populated(ShapeUtil::MakeShape(F16, {2, 2}));
  populated.PopulateWithValue<half>(fill);
  EXPECT_EQ(populated, LiteralUtil::CreateR2<half>({{fill, fill}, {fill, fill}}));
}
1032 
TEST_F(LiteralUtilTest, ReplicateR2U32) {
  // Replicate<uint32_t>(3) stacks three copies of the 3x4 input along a new
  // leading dimension, producing a 3x3x4 literal.
  Literal matrix = LiteralUtil::CreateR2<uint32_t>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  Literal stacked = matrix.Replicate<uint32_t>(3);
  EXPECT_EQ(stacked, LiteralUtil::CreateR3<uint32_t>(
                         {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
                          {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
                          {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}));
}
1043 
TEST_F(LiteralUtilTest, CopySliceFrom) {
  // For several minor-to-major layouts, copy a 7x8x11x9 window out of a
  // 17x15x34x21 U32 source literal (at offset src_base) into a zero-initialized
  // literal of the same shape (at offset dest_base). Then check that every
  // element of the window was copied.
  const int64_t dimensions[] = {17, 15, 34, 21};
  const int64_t layouts[][4] = {
      {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
  // These do not depend on the layout, so they are defined once.
  const int64_t zero_base[] = {0, 0, 0, 0};
  const int64_t step[] = {1, 1, 1, 1};
  const int64_t src_base[] = {3, 1, 5, 7};
  const int64_t dest_base[] = {6, 4, 12, 2};
  const int64_t copy_size[] = {7, 8, 11, 9};

  for (const auto& layout : layouts) {
    Shape shape = ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32_t>(), dimensions, layout);

    // Number the source elements 1, 2, 3, ... so that every source value is
    // distinct and non-zero. A zero in the destination therefore means
    // "not copied".
    auto source = Literal::CreateFromShape(shape);
    uint32_t seqnr = 0;
    auto init_proc = [&](absl::Span<const int64_t> indexes) {
      source.Set(indexes, ++seqnr);
      return true;
    };
    ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
                            init_proc);

    auto blank = Literal::CreateFromShape(shape);
    TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));

    std::vector<int64_t> source_indexes(TF_ARRAYSIZE(dimensions), 0);
    std::vector<int64_t> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
    bool matched = true;
    auto check_proc = [&](absl::Span<const int64_t> indexes) {
      // Translate the window-relative index into source and destination
      // coordinates.
      for (size_t d = 0; d < indexes.size(); ++d) {
        source_indexes[d] = indexes[d] + src_base[d];
        blank_indexes[d] = indexes[d] + dest_base[d];
      }
      const uint32_t bval = blank.Get<uint32_t>(blank_indexes);
      // Accumulate with && (as the Populate tests do) so that a mismatch is
      // never overwritten by a later match. This holds whether or not
      // ForEachIndex stops early on a false return.
      matched =
          matched && bval != 0 && bval == source.Get<uint32_t>(source_indexes);
      return matched;
    };
    ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
                            check_proc);
    EXPECT_TRUE(matched);
  }
}
1089 
TEST_F(LiteralUtilTest, CopyFromScalars) {
  // A whole-literal copy between two U32 scalars.
  Literal scalar = LiteralUtil::CreateR0<uint32_t>(0);
  Literal nine = LiteralUtil::CreateR0<uint32_t>(9);
  TF_EXPECT_OK(scalar.CopyFrom(nine));
  EXPECT_EQ(scalar, nine);

  // Copy element {5} of a vector into the scalar (rank-0 destination, empty
  // copy size).
  Literal vec = LiteralUtil::CreateR1<uint32_t>({3, 4, 9, 12, 5, 17, 21});
  TF_EXPECT_OK(scalar.CopySliceFrom(vec, {5}, {}, {}));
  EXPECT_EQ(scalar.Get<uint32_t>({}), 17);

  // Copy the scalar back into element {4} of the vector.
  TF_EXPECT_OK(vec.CopySliceFrom(scalar, {}, {4}, {}));
  EXPECT_EQ(vec.Get<uint32_t>({4}), 17);
}
1102 
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
  // A zero-sized CopySliceFrom must succeed and leave the destination equal to
  // its original contents. This is checked both when the zero-element literal
  // is the source and when it is the destination.
  const Shape zero_len_shape = ShapeUtil::MakeShape(F32, {0});
  const auto reference_nine = LiteralUtil::CreateR1<float>({9});
  const auto reference_empty = Literal::CreateFromShape(zero_len_shape);

  {
    // Empty source, one-element destination.
    const auto empty_src = Literal::CreateFromShape(zero_len_shape);
    auto nine_dest = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(nine_dest.CopySliceFrom(empty_src, {0}, {0}, {0}));
    EXPECT_EQ(nine_dest, reference_nine);
  }

  {
    // One-element source, empty destination.
    auto empty_dest = Literal::CreateFromShape(zero_len_shape);
    auto nine_src = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(empty_dest.CopySliceFrom(nine_src, {0}, {0}, {0}));
    EXPECT_EQ(empty_dest, reference_empty);
  }
}
1126 
TEST_F(LiteralUtilTest, CopyFromNilShape) {
  // Copying one nil literal into another moves no data, but CopyFrom must
  // still report success.
  Literal dest(ShapeUtil::MakeNil());
  Literal src(ShapeUtil::MakeNil());
  TF_ASSERT_OK(dest.CopyFrom(src));
}
1133 
TEST_F(LiteralUtilTest, CopyFromArrays) {
  // CopyFrom with empty dest/src shape indices replaces the whole array,
  // first for a scalar and then for a 2x2 matrix.
  auto scalar_dest = LiteralUtil::CreateR0<float>(42.0);
  auto scalar_src = LiteralUtil::CreateR0<float>(123.0);
  EXPECT_NE(scalar_dest, scalar_src);
  TF_ASSERT_OK(scalar_dest.CopyFrom(scalar_src, /*dest_shape_index=*/{},
                                    /*src_shape_index=*/{}));
  EXPECT_EQ(scalar_dest, scalar_src);
  EXPECT_EQ(scalar_dest.Get<float>({}), 123.0f);

  auto matrix_dest = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto matrix_src = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
  EXPECT_NE(matrix_dest, matrix_src);
  EXPECT_EQ(matrix_dest.Get<float>({0, 0}), 1.0f);
  TF_ASSERT_OK(matrix_dest.CopyFrom(matrix_src, /*dest_shape_index=*/{},
                                    /*src_shape_index=*/{}));
  EXPECT_EQ(matrix_dest, matrix_src);
  EXPECT_EQ(matrix_dest.Get<float>({0, 0}), 5.0f);
}
1152 
TEST_F(LiteralUtilTest, CopyFromTuples) {
  // Build (matrix, (s32, f64[2], nil)) and overwrite only the inner tuple, at
  // shape index {1}, from a separate tuple of the same shape but with
  // different values.
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal nil_literal(ShapeUtil::MakeNil());
  Literal old_s32 = LiteralUtil::CreateR0<int32_t>(42);
  Literal old_f64s = LiteralUtil::CreateR1<double>({23.0, 44.0});
  Literal inner_tuple =
      LiteralUtil::MakeTuple({&old_s32, &old_f64s, &nil_literal});
  Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});

  Literal new_s32 = LiteralUtil::CreateR0<int32_t>(-5);
  Literal new_f64s = LiteralUtil::CreateR1<double>({2.0, 4.0});
  Literal replacement =
      LiteralUtil::MakeTuple({&new_s32, &new_f64s, &nil_literal});

  // Initial contents.
  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
  EXPECT_EQ(nested_tuple.Get<int32_t>({}, {1, 0}), 42);
  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);

  TF_ASSERT_OK(nested_tuple.CopyFrom(replacement, /*dest_shape_index=*/{1},
                                     /*src_shape_index=*/{}));

  // Element {0} (the matrix) is untouched...
  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
  // ...while element {1} now carries the replacement's values.
  EXPECT_EQ(nested_tuple.Get<int32_t>({}, {1, 0}), -5);
  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}

TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
  // The source of CopyFrom may be the destination literal itself. Copying
  // tuple element {0} over element {1} duplicates -2 and leaves {0} intact.
  Literal first = LiteralUtil::CreateR0<int32_t>(-2);
  Literal second = LiteralUtil::CreateR0<int32_t>(4);
  Literal tuple = LiteralUtil::MakeTuple({&first, &second});

  EXPECT_EQ(tuple.Get<int32_t>({}, {0}), -2);
  EXPECT_EQ(tuple.Get<int32_t>({}, {1}), 4);

  TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
                              /*src_shape_index=*/{0}));

  EXPECT_EQ(tuple.Get<int32_t>({}, {0}), -2);
  EXPECT_EQ(tuple.Get<int32_t>({}, {1}), -2);
}
1201 
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
  // Copying an f32[2] vector over an f32[2,2] matrix must be rejected with an
  // incompatible-shape error.
  auto mat = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto vec = LiteralUtil::CreateR1<float>({5.0, 7.0});
  const Status copy_status = mat.CopyFrom(vec);
  ASSERT_FALSE(copy_status.ok());
  EXPECT_THAT(copy_status.error_message(),
              HasSubstr("Destination subshape incompatible"));
}
1210 
TEST_F(LiteralUtilTest, F16) {
  // Checks the raw storage behind F16 literals. A 2x2 literal created from a
  // shape has all 8 bytes zero. Built from {1.0, 2.0} values, the elements
  // hold the binary16 bit patterns 0x3C00 / 0x4000 when read back as host
  // uint16_t values.
  // TODO - modify if we make the data format machine endianness dependent
  Literal zeros = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
  const char* zero_bytes =
      reinterpret_cast<const char*>(zeros.data<half>().data());
  for (int i = 0; i < 8; ++i) {
    EXPECT_EQ(zero_bytes[i], 0) << "byte " << i;
  }

  const half one(1.0f);
  const half two(2.0f);
  auto mat = LiteralUtil::CreateR2<half>({{one, two}, {two, one}});
  const uint16_t* bits =
      reinterpret_cast<const uint16_t*>(mat.data<half>().data());
  const uint16_t kExpectedBits[] = {0x3C00, 0x4000, 0x4000, 0x3C00};
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(bits[i], kExpectedBits[i]) << "element " << i;
  }
}
1236 
TEST_F(LiteralUtilTest, Populate) {
  // Shapes to exercise, each with an explicit minor-to-major layout. The list
  // covers a scalar, zero-sized dimensions and non-default layouts.
  struct PopulateCase {
    std::vector<int64_t> dims;
    std::vector<int64_t> minor_to_major;
  };
  const PopulateCase cases[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };

  for (const PopulateCase& c : cases) {
    Literal literal(ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32_t>(), c.dims,
        c.minor_to_major));
    // Value = linear index + 17. The offset keeps even the single element of
    // an R0 literal from being zero, so "not populated" is detectable.
    auto value_at = [&](absl::Span<const int64_t> index) -> uint32_t {
      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                           index) +
             17;
    };
    TF_EXPECT_OK(literal.Populate<uint32_t>(value_at));

    const std::vector<int64_t> origin(c.dims.size(), 0);
    const std::vector<int64_t> unit_stride(c.dims.size(), 1);
    bool all_match = true;
    auto verify = [&](absl::Span<const int64_t> index) {
      all_match = all_match && (literal.Get<uint32_t>(index) == value_at(index));
      return all_match;
    };
    ShapeUtil::ForEachIndex(literal.shape(), origin, c.dims, unit_stride,
                            verify);
    EXPECT_TRUE(all_match);
  }
}
1278 
TEST_F(LiteralUtilTest, PopulateParallel) {
  // Same shapes as the Populate test, filled through PopulateParallel. The
  // generator ignores its thread id, so results must match the serial case.
  struct PopulateCase {
    std::vector<int64_t> dims;
    std::vector<int64_t> minor_to_major;
  };
  const PopulateCase cases[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };

  for (const PopulateCase& c : cases) {
    Literal literal(ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32_t>(), c.dims,
        c.minor_to_major));
    // Value = linear index + 17, so even an R0 literal gets a non-zero value.
    auto value_at = [&](absl::Span<const int64_t> index,
                        int /*thread_id*/) -> uint32_t {
      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                           index) +
             17;
    };
    TF_EXPECT_OK(literal.PopulateParallel<uint32_t>(value_at));

    const std::vector<int64_t> origin(c.dims.size(), 0);
    const std::vector<int64_t> unit_stride(c.dims.size(), 1);
    bool all_match = true;
    auto verify = [&](absl::Span<const int64_t> index) {
      all_match = all_match && (literal.Get<uint32_t>(index) ==
                                value_at(index, /*thread_id=*/-1));
      return all_match;
    };
    ShapeUtil::ForEachIndex(literal.shape(), origin, c.dims, unit_stride,
                            verify);
    EXPECT_TRUE(all_match);
  }
}
1321 
TEST_F(LiteralUtilTest, ConvertR4) {
  // Converting an S8 R4 literal to U32 must equal the same values built
  // directly as U32 with the same (dim0-major) layout.
  // clang-format off
  Literal s8_input = LiteralUtil::CreateR4WithLayout<int8_t>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  Literal u32_expected = LiteralUtil::CreateR4WithLayout<uint32_t>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // clang-format on
  TF_ASSERT_OK_AND_ASSIGN(Literal u32_actual, s8_input.Convert(U32));
  EXPECT_EQ(u32_expected, u32_actual);
}
1339 
TEST_F(LiteralUtilTest,ConvertIfTypesMatch)1340 TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
1341   // clang-format off
1342   auto s8 = LiteralUtil::CreateR4WithLayout<int8_t>({{
1343     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1344     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1345     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1346   }}, layout_r4_dim0major_);
1347   auto s16 = LiteralUtil::CreateR4WithLayout<int16_t>({{
1348     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1349     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1350     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1351   }}, layout_r4_dim0major_);
1352   auto s32 = LiteralUtil::CreateR4WithLayout<int32_t>({{
1353     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1354     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1355     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1356   }}, layout_r4_dim0major_);
1357   auto u16 = LiteralUtil::CreateR4WithLayout<uint16_t>({{
1358     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1359     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1360     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1361   }}, layout_r4_dim0major_);
1362   auto u32 = LiteralUtil::CreateR4WithLayout<uint32_t>({{
1363     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1364     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1365     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1366   }}, layout_r4_dim0major_);
1367   auto s64 = LiteralUtil::CreateR4WithLayout<int64_t>({{
1368     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1369     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1370     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1371   }}, layout_r4_dim0major_);
1372   auto u64 = LiteralUtil::CreateR4WithLayout<uint64_t>({{
1373     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1374     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1375     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1376   }}, layout_r4_dim0major_);
1377   auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
1378     {{true, false, true, false}, {false, true, false, true}},
1379     {{false, true, false, true}, {true, false, true, false}},
1380     {{true, false, true, false}, {false, true, false, true}},
1381   }}, layout_r4_dim0major_);
1382   auto int32_pred = LiteralUtil::CreateR4WithLayout<int32_t>({{
1383     {{1, 0, 1, 0}, {0, 1, 0, 1}},
1384     {{0, 1, 0, 1}, {1, 0, 1, 0}},
1385     {{1, 0, 1, 0}, {0, 1, 0, 1}},
1386   }}, layout_r4_dim0major_);
1387   auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
1388     {{half(10.0), half(0.0), half(12.0), half(0.0)},
1389      {half(0.0), half(15.0), half(0.0), half(17.0)}},
1390     {{half(0.0), half(19.0), half(0.0), half(21.0)},
1391      {half(22.0), half(0.0), half(24.0), half(0.0)}},
1392     {{half(26.0), half(0.0), half(28.0), half(0.0)},
1393      {half(0.0), half(31.0), half(0.0), half(33.0)}},
1394   }}, layout_r4_dim0major_);
1395   auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
1396     {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
1397      {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
1398     {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
1399      {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
1400     {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
1401      {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
1402   }}, layout_r4_dim0major_);
1403   auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
1404     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
1405     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
1406     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
1407   }}, layout_r4_dim0major_);
1408   auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
1409     {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
1410     {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
1411     {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
1412   }}, layout_r4_dim0major_);
1413   auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
1414     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
1415     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
1416     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
1417   }}, layout_r4_dim0major_);
1418   auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
1419     {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
1420     {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
1421     {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
1422   }}, layout_r4_dim0major_);  // clang-format on
1423   Literal conv;
1424 
1425   conv = s8.Convert(U16).value();
1426   EXPECT_EQ(conv, u16);
1427 
1428   conv = s8.Convert(S16).value();
1429   EXPECT_EQ(conv, s16);
1430 
1431   conv = s8.Convert(U32).value();
1432   EXPECT_EQ(conv, u32);
1433 
1434   conv = s8.Convert(S32).value();
1435   EXPECT_EQ(conv, s32);
1436 
1437   conv = s8.Convert(U64).value();
1438   EXPECT_EQ(conv, u64);
1439 
1440   conv = s8.Convert(S64).value();
1441   EXPECT_EQ(conv, s64);
1442 
1443   conv = s8.Convert(PRED).value();
1444   EXPECT_EQ(conv, pred);
1445 
1446   conv = bf16.Convert(S32).value();
1447   EXPECT_EQ(conv, s32);
1448 
1449   conv = bf16.Convert(F32).value();
1450   EXPECT_EQ(conv, f32);
1451 
1452   conv = pred.Convert(S32).value();
1453   EXPECT_EQ(conv, int32_pred);
1454 
1455   conv = f32.Convert(S32).value();
1456   EXPECT_EQ(conv, s32);
1457 
1458   conv = f64.Convert(S32).value();
1459   EXPECT_EQ(conv, s32);
1460 
1461   conv = s32.Convert(F32).value();
1462   EXPECT_EQ(conv, f32);
1463 
1464   conv = f32.Convert(F16).value();
1465   EXPECT_EQ(conv, f16);
1466 
1467   conv = f64.Convert(F16).value();
1468   EXPECT_EQ(conv, f16);
1469 
1470   conv = s32.Convert(F16).value();
1471   EXPECT_EQ(conv, f16);
1472 
1473   conv = u32.Convert(F16).value();
1474   EXPECT_EQ(conv, f16);
1475 
1476   conv = s32.Convert(C64).value();
1477   EXPECT_EQ(conv, c64);
1478 
1479   conv = f16.Convert(C64).value();
1480   EXPECT_EQ(conv, c64);
1481 
1482   conv = s32.Convert(S16).value();
1483   EXPECT_EQ(conv, s16);
1484 
1485   conv = s32.Convert(U16).value();
1486   EXPECT_EQ(conv, u16);
1487 
1488   conv = s32.Convert(C128).value();
1489   EXPECT_EQ(conv, c128);
1490 
1491   conv = f16.Convert(C128).value();
1492   EXPECT_EQ(conv, c128);
1493 
1494   EXPECT_EQ(s32.Convert(TUPLE).status().code(),
1495             tensorflow::error::UNIMPLEMENTED);
1496   EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
1497   EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
1498   EXPECT_EQ(c128.Convert(F32).status().code(),
1499             tensorflow::error::UNIMPLEMENTED);
1500   EXPECT_EQ(c128.Convert(S32).status().code(),
1501             tensorflow::error::UNIMPLEMENTED);
1502 }
1503 
TEST_F(LiteralUtilTest, BitcastConvert) {
  // A U32 -> F32 BitcastConvert reinterprets each element's bits. The result
  // must therefore equal the F32 values whose bit patterns were stored. This
  // includes 0xbeef, an arbitrary pattern that goes through absl::bit_cast on
  // both sides.
  Literal original = LiteralUtil::CreateR1<uint32_t>(
      {absl::bit_cast<uint32_t>(2.5f), absl::bit_cast<uint32_t>(-42.25f),
       absl::bit_cast<uint32_t>(100.f), 0xbeef});
  Literal expected = LiteralUtil::CreateR1<float>(
      {2.5f, -42.25f, 100.0f, absl::bit_cast<float>(0xbeef)});
  TF_ASSERT_OK_AND_ASSIGN(Literal converted,
                          original.BitcastConvert(ShapeUtil::ChangeElementType(
                              original.shape(), F32)));
  // Without this comparison the test could only fail if BitcastConvert
  // returned an error, never if it produced wrong values.
  EXPECT_EQ(expected, converted);
}
1514 
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
  // Bitcasting a U32 scalar to an F64 shape changes the byte size, so it must
  // fail with a size-mismatch error.
  Literal u32_scalar = LiteralUtil::CreateR0<uint32_t>(1234);
  const Shape f64_shape =
      ShapeUtil::ChangeElementType(u32_scalar.shape(), F64);
  const Status bitcast_status = u32_scalar.BitcastConvert(f64_shape).status();
  EXPECT_NE(OkStatus(), bitcast_status);
  EXPECT_THAT(bitcast_status.error_message(),
              HasSubstr("to a shape of different size"));
}
1524 
1525 // Sets the layout of the given ShapeProto to the default.
SetDefaultLayoutOnProto(ShapeProto * shape_proto)1526 void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
1527   CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
1528   auto* minor_to_major =
1529       shape_proto->mutable_layout()->mutable_minor_to_major();
1530   minor_to_major->Resize(shape_proto->dimensions_size(), 0);
1531   const int64_t size = minor_to_major->size();
1532   for (int64_t i = 0; i < size; ++i) {
1533     minor_to_major->Set(i, size - 1 - i);
1534   }
1535 }
1536 
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
  // Deserializes PRED vectors of every length from 0 to 24. Element i is true
  // exactly when i and the length have the same parity.
  LiteralProto proto;
  proto.mutable_shape()->set_element_type(PRED);
  for (int len = 0; len < 25; ++len) {
    auto expected_at = [len](int i) { return (i % 2) == (len % 2); };

    proto.mutable_shape()->clear_dimensions();
    proto.mutable_shape()->add_dimensions(len);
    SetDefaultLayoutOnProto(proto.mutable_shape());
    proto.clear_preds();
    for (int i = 0; i < len; ++i) {
      proto.add_preds(expected_at(i));
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
    ASSERT_EQ(len, literal.data<bool>().size());
    int position = 0;
    for (bool actual : literal.data<bool>()) {
      EXPECT_EQ(expected_at(position), actual);
      ++position;
    }
  }
}
1558 
1559 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,ToProto_f16)1560 TEST_F(LiteralUtilTest, ToProto_f16) {
1561   half h1(1.0f);
1562   half h2(2.0f);
1563 
1564   auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
1565   EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
1566   EXPECT_EQ(4, m.data<half>().size());
1567 
1568   LiteralProto p = m.ToProto();
1569   EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
1570   EXPECT_EQ(8, p.f16s().size());
1571   const char* d = p.f16s().data();
1572   EXPECT_EQ(d[0], 0);
1573   EXPECT_EQ(d[1], 0x3C);
1574   EXPECT_EQ(d[2], 0);
1575   EXPECT_EQ(d[3], 0x40);
1576   EXPECT_EQ(d[4], 0);
1577   EXPECT_EQ(d[5], 0x40);
1578   EXPECT_EQ(d[6], 0);
1579   EXPECT_EQ(d[7], 0x3C);
1580 }
1581 
1582 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,CopyFromProto_f16)1583 TEST_F(LiteralUtilTest, CopyFromProto_f16) {
1584   half h1(1.0f);
1585   half h2(2.0f);
1586 
1587   const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C};
1588   LiteralProto p;
1589   p.mutable_shape()->set_element_type(F16);
1590   p.mutable_shape()->clear_dimensions();
1591   p.mutable_shape()->add_dimensions(4);
1592   SetDefaultLayoutOnProto(p.mutable_shape());
1593   p.clear_f16s();
1594   p.set_f16s(half_vals, 8);
1595   TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
1596   auto r = literal.data<half>();
1597   ASSERT_EQ(4, r.size());
1598   EXPECT_EQ(h1, r[0]);
1599   EXPECT_EQ(h2, r[1]);
1600   EXPECT_EQ(h2, r[2]);
1601   EXPECT_EQ(h1, r[3]);
1602 }
1603 
// Parses U16 values from LiteralProto::u16s, supplied as raw little-endian
// bytes, two per element.
TEST_F(LiteralUtilTest, CopyFromProto_u16) {
  const uint16_t first = 0xabcd;
  const uint16_t second = 0x1234;

  // {0xabcd, 0x1234, 0x1234, 0xabcd}, low byte first.
  const unsigned char raw_bytes[8] = {0xcd, 0xab, 0x34, 0x12,
                                      0x34, 0x12, 0xcd, 0xab};
  LiteralProto proto;
  ShapeProto* shape = proto.mutable_shape();
  shape->set_element_type(U16);
  shape->clear_dimensions();
  shape->add_dimensions(4);
  SetDefaultLayoutOnProto(shape);
  proto.clear_u16s();
  proto.set_u16s(raw_bytes, sizeof(raw_bytes));

  TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
  auto values = literal.data<uint16_t>();
  ASSERT_EQ(4, values.size());
  const uint16_t expected[] = {first, second, second, first};
  for (int i = 0; i < 4; ++i) {
    EXPECT_EQ(expected[i], values[i]);
  }
}
1625 
// A LiteralSlice with an empty view root equals the whole literal. Non-empty
// view roots select individual (possibly nested) tuple elements.
TEST_F(LiteralUtilTest, LiteralDynamicSliceTest) {
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
  Literal nil(ShapeUtil::MakeNil());

  // Empty view root: the slice is the literal itself.
  for (const Literal* lit : {&scalar, &matrix, &tuple, &nested_tuple, &nil}) {
    EXPECT_EQ(LiteralSlice(*lit, {}), *lit);
  }

  // One-level roots into `tuple`.
  EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
  EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);

  // One- and two-level roots into `nested_tuple`.
  EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
  EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
1647 
// A LiteralSlice is a view onto its literal: a value written through the
// owning Literal is observed when reading through the slice.
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
  const auto nested_tuple_view = LiteralSlice(nested_tuple);

  // Reads the scalar at shape index {0, 0} from either the owning literal or
  // the view.
  const auto read_leaf = [](const LiteralBase& source) {
    return source.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0});
  };

  EXPECT_EQ(read_leaf(nested_tuple), 1.0f);
  EXPECT_EQ(read_leaf(nested_tuple_view), 1.0f);

  nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);

  EXPECT_EQ(read_leaf(nested_tuple), 555.0f);
  EXPECT_EQ(read_leaf(nested_tuple_view), 555.0f);
}
1668 
// A LiteralSlice can be re-rooted from another LiteralSlice. Chaining roots
// {0} then {1} on the nested tuple reaches the matrix element.
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});

  const LiteralSlice outer(nested_tuple);
  const LiteralSlice inner(outer, /*view_root=*/{0});
  const LiteralSlice innermost(inner, /*view_root=*/{1});

  EXPECT_EQ(innermost,
            LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
1681 
// A BorrowingLiteral built over the raw bytes of an int64 vector exposes those
// values through Get<int64_t>.
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
  std::vector<int64_t> int64_values = {1, 2, 3};
  const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});

  const char* raw_buffer = reinterpret_cast<const char*>(int64_values.data());
  BorrowingLiteral literal(raw_buffer, literal_shape);

  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(literal.Get<int64_t>({i}), i + 1);
  }
}
1693 
// A tuple-shaped BorrowingLiteral takes one buffer per tuple element, in tuple
// order, and reads each element's values from its own buffer.
TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
  std::vector<int64_t> one_two_three = {1, 2, 3};
  const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});

  std::vector<int64_t> hundred = {100};
  const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1});

  const std::vector<const char*> src_buf_ptrs = {
      reinterpret_cast<const char*>(one_two_three.data()),
      reinterpret_cast<const char*>(hundred.data())};
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape});
  BorrowingLiteral literal_tuple(src_buf_ptrs, tuple_shape);

  // Index 0 of each tuple element.
  EXPECT_EQ(
      literal_tuple.Get<int64_t>(/*multi_index=*/{0}, /*shape_index=*/{0}), 1);
  EXPECT_EQ(
      literal_tuple.Get<int64_t>(/*multi_index=*/{0}, /*shape_index=*/{1}),
      100);

  // Remaining entries of the first element.
  EXPECT_EQ(
      literal_tuple.Get<int64_t>(/*multi_index=*/{1}, /*shape_index=*/{0}), 2);
  EXPECT_EQ(
      literal_tuple.Get<int64_t>(/*multi_index=*/{2}, /*shape_index=*/{0}), 3);
}
1721 
// Move-constructing a Literal carries over both the shape and the element
// values of the source.
TEST_F(LiteralUtilTest, LiteralMove) {
  Literal source = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal literal(std::move(source));

  EXPECT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));

  const float expected[2][2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  for (int64_t row = 0; row < 2; ++row) {
    for (int64_t col = 0; col < 2; ++col) {
      EXPECT_EQ(literal.Get<float>({row, col}), expected[row][col]);
    }
  }
}
1733 
// DecomposeTuple splits a tuple literal into its top-level elements and leaves
// the source as an empty tuple. Nested tuples come out as single elements.
TEST_F(LiteralUtilTest, DecomposeTuple) {
  Literal nil_literal(ShapeUtil::MakeNil());
  Literal inner_elements[] = {
      LiteralUtil::CreateR0<int32_t>(42),
      LiteralUtil::CreateR1<double>({23.0, 44.0}),
  };
  Literal tuple_elements[] = {
      LiteralUtil::CreateR2<int32_t>({{1, 2}, {3, 4}}),
      LiteralUtil::MakeTuple(
          {&inner_elements[0], &inner_elements[1], &nil_literal}),
  };
  Literal nested_tuple = LiteralUtil::MakeTuple(
      {&tuple_elements[0], &tuple_elements[1], &nil_literal});

  EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
  std::vector<Literal> parts = nested_tuple.DecomposeTuple();
  EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
  ASSERT_EQ(parts.size(), 3);

  // parts[0]: the S32[2,2] matrix.
  const Literal& matrix = parts[0];
  EXPECT_TRUE(ShapeUtil::Compatible(matrix.shape(),
                                    ShapeUtil::MakeShape(S32, {2, 2})));
  EXPECT_EQ(matrix.Get<int32_t>({0, 0}), 1);
  EXPECT_EQ(matrix.Get<int32_t>({0, 1}), 2);
  EXPECT_EQ(matrix.Get<int32_t>({1, 0}), 3);
  EXPECT_EQ(matrix.Get<int32_t>({1, 1}), 4);

  // parts[1]: the inner (S32 scalar, F64[2], nil) tuple.
  const Literal& inner = parts[1];
  const Shape inner_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F64, {2}),
       ShapeUtil::MakeNil()});
  EXPECT_TRUE(ShapeUtil::Compatible(inner.shape(), inner_shape));
  EXPECT_EQ(inner.Get<int32_t>({}, /*shape_index=*/{0}), 42);
  EXPECT_EQ(inner.Get<double>({0}, /*shape_index=*/{1}), 23.0);
  EXPECT_EQ(inner.Get<double>({1}, /*shape_index=*/{1}), 44.0);

  // parts[2]: the trailing empty tuple.
  EXPECT_TRUE(ShapeUtil::Compatible(parts[2].shape(), ShapeUtil::MakeNil()));
}
1772 
// Decomposing the empty tuple yields no elements.
TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
  Literal nil_literal(ShapeUtil::MakeNil());
  EXPECT_TRUE(nil_literal.DecomposeTuple().empty());
}
1778 
// MoveIntoTuple builds a tuple by moving the given literals in. Each source
// literal is left with an empty-tuple shape afterwards.
TEST_F(LiteralUtilTest, MoveIntoTuple) {
  std::vector<Literal> inner_elements;
  inner_elements.emplace_back(LiteralUtil::CreateR0<int32_t>(42));
  inner_elements.emplace_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));

  std::vector<Literal> elements;
  elements.emplace_back(LiteralUtil::CreateR0<float>(1.0));
  elements.emplace_back(LiteralUtil::CreateR1<int32_t>({4, 8}));
  elements.emplace_back(
      LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));

  Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
  ASSERT_TRUE(literal.shape().IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3);

  // Element 0: F32 scalar.
  EXPECT_EQ(literal.Get<float>({}, /*shape_index=*/{0}), 1.0);
  // Element 1: S32[2].
  EXPECT_EQ(literal.Get<int32_t>({0}, /*shape_index=*/{1}), 4);
  EXPECT_EQ(literal.Get<int32_t>({1}, /*shape_index=*/{1}), 8);
  // Element 2: nested (S32 scalar, F64[2]) tuple.
  EXPECT_EQ(literal.Get<int32_t>({}, /*shape_index=*/{2, 0}), 42);
  EXPECT_EQ(literal.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
  EXPECT_EQ(literal.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);

  for (size_t i = 0; i < elements.size(); ++i) {
    EXPECT_TRUE(ShapeUtil::IsEmptyTuple(elements[i].shape()))
        << "moved-from element " << i;
  }
}
1804 
// MoveIntoTuple over an empty span produces a tuple with zero elements.
TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
  std::vector<Literal> no_elements;
  Literal empty_tuple = Literal::MoveIntoTuple(absl::MakeSpan(no_elements));
  ASSERT_TRUE(empty_tuple.shape().IsTuple());
  EXPECT_EQ(ShapeUtil::TupleElementCount(empty_tuple.shape()), 0);
}
1810 
// A default-constructed Literal has the nil shape. Move-assigning a matrix
// into it replaces both its shape and its values.
TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
  Literal target;
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), target.shape()));

  Literal source = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  target = std::move(source);

  EXPECT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), target.shape()));
  float expected = 1.0f;
  for (int64_t row = 0; row < 2; ++row) {
    for (int64_t col = 0; col < 2; ++col) {
      EXPECT_EQ(target.Get<float>({row, col}), expected);
      expected += 1.0f;
    }
  }
}
1825 
// A copy of a LiteralSlice reads the same values as the original slice.
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const LiteralSlice original_view(matrix);
  LiteralSlice copied_view(original_view);

  const float expected[2][2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  for (int64_t row = 0; row < 2; ++row) {
    for (int64_t col = 0; col < 2; ++col) {
      EXPECT_EQ(copied_view.Get<float>({row, col}), expected[row][col]);
    }
  }
}
1836 
// Get/Set with a shape_index read and overwrite individual elements of tuple
// members in place.
TEST_F(LiteralUtilTest, GetSetTuple) {
  Literal elements[] = {
      LiteralUtil::CreateR0<float>(42.0),
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
  };
  auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});

  // Scalar member at tuple index 0.
  const float scalar_before =
      tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0});
  EXPECT_EQ(scalar_before, 42.0);
  tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);

  // Element (1, 0) of the matrix member at tuple index 1.
  const float matrix_before =
      tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1});
  EXPECT_EQ(matrix_before, 3.0);
  tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
            -4.0);
}
1852 
// Literals produced by CreateFromShape are zero-initialized. This covers a
// scalar, a vector, and every leaf of a tuple (F64, PRED, U64, C64, C128).
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
  Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
  EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
  EXPECT_TRUE(scalar_f32.IsAll(0));

  Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
  for (int64_t i = 0; i < 3; ++i) {
    EXPECT_EQ(vector_s32.Get<int32_t>({i}), 0);
  }
  EXPECT_TRUE(vector_s32.IsAll(0));

  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
       ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
       ShapeUtil::MakeShape(C128, {})});
  Literal tuple = Literal::CreateFromShape(tuple_shape);

  EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
  for (int64_t i = 0; i < 2; ++i) {
    EXPECT_EQ(tuple.Get<bool>({i}, {1}), false);
    EXPECT_EQ(tuple.Get<uint64_t>({i, 0}, {2}), 0);
  }
  EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
  EXPECT_EQ(tuple.Get<complex128>({}, {4}), complex128(0.0, 0.0));
}
1878 
// Serializing a Literal to a LiteralProto and parsing it back must give an
// equal Literal. This is checked for each element type built here, and for
// tuples, nested tuples and the empty tuple.
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
  auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
  auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
  auto vector_int8 = LiteralUtil::CreateR1<int8_t>({-128, 0, 2, 4, 7, 56, 127});
  auto vector_uint8 = LiteralUtil::CreateR1<uint8_t>({128, 0, 2, 56, 127, 255});
  auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_c128 =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
  auto vector_half =
      LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
  auto matrix_pred =
      LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
  auto tuple = LiteralUtil::MakeTuple(
      {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
  Literal nil_literal(ShapeUtil::MakeNil());
  auto nested_tuple =
      LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});

  auto to_from_proto = [](const Literal& literal) -> Literal {
    return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
  };

  EXPECT_EQ(one_f32, to_from_proto(one_f32));
  EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
  EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
  EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
  EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
  EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
  // F16 was previously covered only as a member of `tuple`. Round-tripping it
  // on its own makes a failure point directly at the F16 path.
  EXPECT_EQ(vector_half, to_from_proto(vector_half));
  EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
  EXPECT_EQ(tuple, to_from_proto(tuple));
  EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
  EXPECT_EQ(nil_literal, to_from_proto(nil_literal));

  EXPECT_NE(one_f32, two_f32);
  EXPECT_NE(one_f32, to_from_proto(two_f32));
}
1918 
// By default, a proto that has a shape but no element values is rejected
// because its element count does not match the shape.
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1928 
// With prohibit_empty_literal=false, a proto that has a shape but no values is
// accepted.
TEST_F(LiteralUtilTest, ValidProtoNoValues) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();

  auto result =
      Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false);
  EXPECT_TRUE(result.ok());
}
1938 
// A proto produced by ToProto whose values have been cleared is still
// accepted when prohibit_empty_literal is false.
TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) {
  LiteralProto proto =
      LiteralUtil::CreateR1<bool>({true, false, true}).ToProto();
  EXPECT_EQ(proto.preds_size(), 3);

  proto.clear_preds();
  EXPECT_EQ(proto.preds_size(), 0);

  auto result =
      Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false);
  EXPECT_TRUE(result.ok());
}
1952 
// A proto that carries values but no shape is rejected.
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
  LiteralProto proto;
  for (bool value : {false, true, false}) {
    proto.add_preds(value);
  }

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("LiteralProto has no shape"));
}
1963 
// An F32 shape whose values were put in the PRED container (`preds`) is
// rejected: the F32 container holds no elements.
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
  for (bool value : {false, true, false}) {
    proto.add_preds(value);
  }

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1976 
// An F32[42,2] shape needs 84 values; a proto supplying only three is
// rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
  for (float value : {1.0f, 2.0f, 3.0f}) {
    proto.add_f32s(value);
  }

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected 84 elements in LiteralProto"));
}
1989 
// An S32[2] shape with three supplied values is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
  for (int32_t value : {42, -10, 100}) {
    proto.add_s32s(value);
  }

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected 2 elements in LiteralProto"));
}
2002 
// An array shape whose layout has been cleared is rejected, even when the
// value count is correct.
TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
  LiteralProto proto;
  *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
  proto.mutable_shape()->clear_layout();
  for (bool value : {true, false, true, false}) {
    proto.add_preds(value);
  }

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("LiteralProto has no layout"));
}
2016 
// The proto has too few tuple elements: its tuple shape declares two, but
// only one sub-literal is supplied.
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
  LiteralProto proto;
  *proto.mutable_shape() =
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
          .ToProto();
  const Shape declared_shape(proto.shape());

  // Only element 0 (PRED[2]) is provided.
  LiteralProto* pred_element = proto.add_tuple_literals();
  *pred_element->mutable_shape() =
      ShapeUtil::GetTupleElementShape(declared_shape, 0).ToProto();
  pred_element->add_preds(false);
  pred_element->add_preds(true);

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected 2 tuple elements"));
}
2034 
// The proto has too many tuple elements: its tuple shape declares two, but
// three sub-literals are supplied.
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
  LiteralProto proto;
  *proto.mutable_shape() =
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
          .ToProto();
  const Shape declared_shape(proto.shape());

  // Element 0: PRED[2], as declared.
  LiteralProto* pred_element = proto.add_tuple_literals();
  *pred_element->mutable_shape() =
      ShapeUtil::GetTupleElementShape(declared_shape, 0).ToProto();
  pred_element->add_preds(false);
  pred_element->add_preds(true);

  // Element 1: F32 scalar, as declared.
  LiteralProto* f32_element = proto.add_tuple_literals();
  *f32_element->mutable_shape() =
      ShapeUtil::GetTupleElementShape(declared_shape, 1).ToProto();
  f32_element->add_f32s(42.0);

  // Element 2: an extra F32 scalar that the tuple shape does not declare.
  LiteralProto* extra_element = proto.add_tuple_literals();
  *extra_element->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
  extra_element->add_f32s(123.0);

  auto result = Literal::CreateFromProto(proto);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected 2 tuple elements"));
}
2059 
// Mapping the vector onto result dimension 0 makes element i fill row i of
// the 2x2 result.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
  Literal source = LiteralUtil::CreateR1<int64_t>({1, 2});
  const Shape result_shape = ShapeUtil::MakeShape(S64, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          source.Broadcast(result_shape, /*dimensions=*/{0}));

  Literal expected = LiteralUtil::CreateR2<int64_t>({{1, 1}, {2, 2}});
  EXPECT_EQ(result, expected);
}
2069 
// Mapping the vector onto result dimension 1 makes every row of the 2x2
// result a copy of the vector.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
  Literal source = LiteralUtil::CreateR1<int64_t>({1, 2});
  const Shape result_shape = ShapeUtil::MakeShape(S64, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          source.Broadcast(result_shape, /*dimensions=*/{1}));

  Literal expected = LiteralUtil::CreateR2<int64_t>({{1, 2}, {1, 2}});
  EXPECT_EQ(result, expected);
}
2079 
// Broadcasting a scalar with no mapped dimensions fills every element of the
// result with that scalar.
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
  Literal source = LiteralUtil::CreateR0<int32_t>(9);
  const Shape result_shape = ShapeUtil::MakeShape(S32, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          source.Broadcast(result_shape, /*dimensions=*/{}));

  Literal expected = LiteralUtil::CreateR2<int32_t>({{9, 9}, {9, 9}});
  EXPECT_EQ(result, expected);
}
2089 
// Dimension 0 of the source vector has dynamic size 1 and is mapped onto
// result dimension 1, so the result reports dynamic size 1 on that dimension.
TEST_F(LiteralUtilTest, DynamicBroadcast) {
  Literal source = LiteralUtil::CreateR1<int64_t>({1, 2});
  source.SetDynamicSize(0, 1);

  const Shape result_shape = ShapeUtil::MakeShape(S64, {2, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                          source.Broadcast(result_shape, /*dimensions=*/{1}));

  EXPECT_EQ(result, LiteralUtil::CreateR2<int64_t>({{1}, {1}}));
  EXPECT_EQ(result.GetDynamicSize(1), 1);
}
2100 
// GetAsComplex128 returns the scalar as complex128 for complex128, double and
// complex64 literals, and an empty optional for an int64 literal.
TEST_F(LiteralUtilTest, GetAsComplex128) {
  // Each optional result is compared to the expected value directly instead
  // of being dereferenced. `*opt` on an empty optional is undefined behavior,
  // whereas `opt == value` is simply false, so an unexpected empty result
  // fails the assertion instead of crashing.
  complex128 value = {1, 0};
  Literal c1 = LiteralUtil::CreateR0<complex128>(value);
  EXPECT_EQ(c1.GetAsComplex128({}), value);
  Literal c2 = LiteralUtil::CreateR0<double>(1);
  EXPECT_EQ(c2.GetAsComplex128({}), value);
  complex64 float_value = {1, 0};
  Literal c4 = LiteralUtil::CreateR0<complex64>(float_value);
  EXPECT_EQ(c4.GetAsComplex128({}), value);
  complex128 other_value = {1, 2};
  Literal c5 = LiteralUtil::CreateR0<complex128>(other_value);
  EXPECT_EQ(c5.GetAsComplex128({}), other_value);
  Literal c6 = LiteralUtil::CreateR0<int64_t>(1);
  EXPECT_FALSE(c6.GetAsComplex128({}).has_value());
}
2116 
// Slicing a PRED vector over its full extent gives an equal literal.
TEST_F(LiteralUtilTest, SliceOnBool) {
  Literal flags = LiteralUtil::CreateR1<bool>({true, true, false});
  Literal full_range = flags.Slice({0}, {3});
  EXPECT_EQ(flags, full_range);
}
2121 
// IsEqualAt compares one element of a literal against an int, double or
// complex value. Here it matches across int, double and complex128 literals
// whenever the values are numerically equal.
TEST_F(LiteralUtilTest, IsEqualAt) {
  const double ten_as_double = 10.0;
  const int ten_as_int = 10;
  const complex128 ten_as_complex = {10, 0};

  // Integer literal.
  Literal int_literal = LiteralUtil::CreateR0<int>(10);
  EXPECT_TRUE(int_literal.IsEqualAt({}, ten_as_double));
  EXPECT_TRUE(int_literal.IsEqualAt({}, ten_as_int));
  EXPECT_TRUE(int_literal.IsEqualAt({}, ten_as_complex));

  // Double literal.
  Literal double_literal = LiteralUtil::CreateR0<double>(10);
  EXPECT_TRUE(double_literal.IsEqualAt({}, ten_as_double));
  EXPECT_TRUE(double_literal.IsEqualAt({}, ten_as_int));
  EXPECT_TRUE(double_literal.IsEqualAt({}, ten_as_complex));

  // Complex literal with zero imaginary part.
  Literal real_complex = LiteralUtil::CreateR0<complex128>(ten_as_complex);
  EXPECT_TRUE(real_complex.IsEqualAt({}, ten_as_double));
  EXPECT_TRUE(real_complex.IsEqualAt({}, ten_as_int));
  EXPECT_TRUE(real_complex.IsEqualAt({}, ten_as_complex));
  EXPECT_FALSE(
      real_complex.IsEqualAt({}, std::numeric_limits<double>::infinity()));

  // Complex literal with non-zero imaginary part, compared with both complex
  // precisions.
  const complex128 wide_complex = {10, 3};
  const complex64 narrow_complex = {10, 3};
  Literal true_complex = LiteralUtil::CreateR0<complex128>(wide_complex);
  EXPECT_TRUE(true_complex.IsEqualAt({}, wide_complex));
  EXPECT_TRUE(true_complex.IsEqualAt({}, narrow_complex));
}
2145 
// A literal created with unknown leaf arrays reports that it is not known.
TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {4, 4});
  Literal unknown = Literal::CreateFromShapeWithUnknownLeafArrays(matrix_shape);
  EXPECT_FALSE(unknown.IsKnown());
}
2151 
// A tuple that contains an unknown leaf array, here nested one tuple level
// down, is itself not known.
TEST_F(LiteralUtilTest, CreatePartiallyKnownTuple) {
  Literal unknown_matrix = Literal::CreateFromShapeWithUnknownLeafArrays(
      ShapeUtil::MakeShape(F32, {4, 4}));
  Literal ten = LiteralUtil::CreateR0<int>(10);
  Literal inner = LiteralUtil::MakeTuple({&unknown_matrix, &ten});
  Literal hundred = LiteralUtil::CreateR0<int>(100);
  Literal outer = LiteralUtil::MakeTuple({&hundred, &inner});
  EXPECT_FALSE(outer.IsKnown());
}
2161 
// Copying the partially-known sub-tuple at index {1} into a zero-initialized
// literal of the same shape makes the destination not known as well.
TEST_F(LiteralUtilTest, CopyFromPartiallyKnownTuple) {
  Literal unknown_matrix = Literal::CreateFromShapeWithUnknownLeafArrays(
      ShapeUtil::MakeShape(F32, {4, 4}));
  Literal ten = LiteralUtil::CreateR0<int>(10);
  Literal inner = LiteralUtil::MakeTuple({&unknown_matrix, &ten});
  Literal hundred = LiteralUtil::CreateR0<int>(100);
  Literal outer = LiteralUtil::MakeTuple({&hundred, &inner});

  Literal destination = Literal::CreateFromShape(outer.shape());
  TF_ASSERT_OK(destination.CopyFrom(outer, /*dest_shape_index=*/{1},
                                    /*src_shape_index=*/{1}));
  EXPECT_FALSE(destination.IsKnown());
}
2174 
// The unknown part of the tuple is itself a tuple of unknown arrays. After
// copying the sub-tuple into a fresh literal and then copying its two members
// out separately, the unknown member stays not known and the known scalar
// member stays known.
TEST_F(LiteralUtilTest, CopyFromPartiallyKnownTupleUnknownTupleElement) {
  Literal unknown_pair = Literal::CreateFromShapeWithUnknownLeafArrays(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4, 4}),
                                 ShapeUtil::MakeShape(F32, {4, 4})}));
  Literal ten = LiteralUtil::CreateR0<int>(10);
  Literal inner = LiteralUtil::MakeTuple({&unknown_pair, &ten});
  Literal hundred = LiteralUtil::CreateR0<int>(100);
  Literal outer = LiteralUtil::MakeTuple({&hundred, &inner});

  Literal staging = Literal::CreateFromShape(outer.shape());
  Literal unknown_pair_copy = Literal::CreateFromShape(unknown_pair.shape());
  Literal ten_copy = Literal::CreateFromShape(ten.shape());

  // Copy the whole inner tuple, then copy each of its members back out.
  TF_ASSERT_OK(staging.CopyFrom(outer, /*dest_shape_index=*/{1},
                                /*src_shape_index=*/{1}));
  TF_ASSERT_OK(unknown_pair_copy.CopyFrom(staging, /*dest_shape_index=*/{},
                                          /*src_shape_index=*/{1, 0}));
  TF_ASSERT_OK(ten_copy.CopyFrom(staging, /*dest_shape_index=*/{},
                                 /*src_shape_index=*/{1, 1}));

  EXPECT_FALSE(staging.IsKnown());
  EXPECT_FALSE(unknown_pair_copy.IsKnown());
  EXPECT_TRUE(ten_copy.IsKnown());
}
2196 
2197 }  // namespace
2198 }  // namespace xla
2199