xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/literal_test_util_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests that our utility functions for dealing with literals are correctly
17 // implemented.
18 
#include "tensorflow/compiler/xla/tests/literal_test_util.h"

#include <cmath>
#include <limits>
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace xla {
31 namespace {
32 
// A two-element s32 tuple must compare equal to itself.
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
  Literal tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<int32_t>(42), LiteralUtil::CreateR0<int32_t>(64)});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
40 
// A tuple of two c64 scalars (with swapped real/imaginary parts) must compare
// equal to itself.
TEST(LiteralTestUtilTest, ComparesEqualComplex64TuplesEqual) {
  const complex64 first(42.0, 64.0);
  const complex64 second(64.0, 42.0);
  Literal tuple = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR0<complex64>(first),
      LiteralUtil::CreateR0<complex64>(second),
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
48 
// A tuple of two c128 scalars (with swapped real/imaginary parts) must compare
// equal to itself.
TEST(LiteralTestUtilTest, ComparesEqualComplex128TuplesEqual) {
  const complex128 first(42.0, 64.0);
  const complex128 second(64.0, 42.0);
  Literal tuple = LiteralUtil::MakeTupleFromSlices({
      LiteralUtil::CreateR0<complex128>(first),
      LiteralUtil::CreateR0<complex128>(second),
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
56 
// Two-element c64 tuples that differ by element order, or by a single
// component of one element, must compare unequal.
TEST(LiteralTestUtilTest, ComparesUnequalComplex64TuplesUnequal) {
  // Builds the tuple (lhs, rhs) of two c64 scalars.
  auto make_pair_tuple = [](complex64 lhs, complex64 rhs) {
    return LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<complex64>(lhs),
        LiteralUtil::CreateR0<complex64>(rhs),
    });
  };

  Literal baseline = make_pair_tuple({42.0, 64.0}, {64.0, 42.0});
  Literal swapped = make_pair_tuple({64.0, 42.0}, {42.0, 64.0});
  // Real part of the first element perturbed.
  Literal first_real_changed = make_pair_tuple({42.42, 64.0}, {64.0, 42.0});
  // Imaginary part of the second element perturbed.
  Literal second_imag_changed = make_pair_tuple({42.0, 64.0}, {64.0, 42.42});

  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, swapped));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, first_real_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, second_imag_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(first_real_changed, second_imag_changed));
}
79 
// Two-element c128 tuples that differ by element order, or by a single
// component of one element, must compare unequal.
TEST(LiteralTestUtilTest, ComparesUnequalComplex128TuplesUnequal) {
  // Builds the tuple (lhs, rhs) of two c128 scalars.
  auto make_pair_tuple = [](complex128 lhs, complex128 rhs) {
    return LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<complex128>(lhs),
        LiteralUtil::CreateR0<complex128>(rhs),
    });
  };

  Literal baseline = make_pair_tuple({42.0, 64.0}, {64.0, 42.0});
  Literal swapped = make_pair_tuple({64.0, 42.0}, {42.0, 64.0});
  // Real part of the first element perturbed.
  Literal first_real_changed = make_pair_tuple({42.42, 64.0}, {64.0, 42.0});
  // Imaginary part of the second element perturbed.
  Literal second_imag_changed = make_pair_tuple({42.0, 64.0}, {64.0, 42.42});

  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, swapped));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, first_real_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, second_imag_changed));
  EXPECT_FALSE(LiteralTestUtil::Equal(first_real_changed, second_imag_changed));
}
102 
// Two s32 tuples holding the same values in opposite order must not compare
// equal.
//
// The check is phrased as a CHECK inside a death test: a recorded gtest
// assertion failure cannot be undone, but a CHECK failure terminates the
// child process, which ASSERT_DEATH can observe and match on.
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
  auto check_swapped_tuples_equal = [] {
    Literal forward = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<int32_t>(42), LiteralUtil::CreateR0<int32_t>(64)});
    Literal reversed = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<int32_t>(64), LiteralUtil::CreateR0<int32_t>(42)});
    CHECK(LiteralTestUtil::Equal(forward, reversed))
        << "LHS and RHS are unequal";
  };
  ASSERT_DEATH(check_swapped_tuples_equal(), "LHS and RHS are unequal");
}
120 
// Checks that a failing LiteralTestUtil::Near leaves its diagnostic literals
// on disk as "tempfile-*.pb" protos. The directory searched is the test's
// undeclared-outputs directory, or TmpDir() when that is unavailable.
//
// Expected files, distinguished by name:
//   - "expected":   f32[] 2
//   - "actual":     f32[] 4
//   - "mismatches": pred[] true
TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
  // Runs inside the death test's child process; Near(2, 4) must fail.
  auto dummy_lambda = [] {
    auto two = LiteralUtil::CreateR0<float>(2);
    auto four = LiteralUtil::CreateR0<float>(4);
    ErrorSpec error(0.001);
    CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
  };

  tensorflow::Env* env = tensorflow::Env::Default();

  std::string outdir;
  if (!tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) {
    outdir = tensorflow::testing::TmpDir();
  }
  const std::string pattern = tensorflow::io::JoinPath(outdir, "tempfile-*.pb");

  // Remove dumps left over from earlier runs so that the count below only
  // reflects files written by this test.
  std::vector<std::string> stale_files;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &stale_files));
  for (const std::string& f : stale_files) {
    TF_CHECK_OK(env->DeleteFile(f)) << f;
  }

  ASSERT_DEATH(dummy_lambda(), "two is not near four");

  // Now check we wrote temporary files to the temporary directory that we can
  // read.
  std::vector<std::string> results;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));

  LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]";
  // Unsigned literal: avoids a signed/unsigned comparison against size_t.
  EXPECT_EQ(3u, results.size());
  for (const std::string& result : results) {
    LiteralProto literal_proto;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(env, result, &literal_proto));
    Literal literal = Literal::CreateFromProto(literal_proto).value();
    if (result.find("expected") != std::string::npos) {
      EXPECT_EQ("f32[] 2", literal.ToString());
    } else if (result.find("actual") != std::string::npos) {
      EXPECT_EQ("f32[] 4", literal.ToString());
    } else if (result.find("mismatches") != std::string::npos) {
      EXPECT_EQ("pred[] true", literal.ToString());
    } else {
      FAIL() << "unknown file in temporary directory: " << result;
    }
  }
}
167 
// A failed Equal comparison must report both literals, with their shapes and
// values, in the assertion message.
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
  auto expected = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
  auto actual = LiteralUtil::CreateR1<int32_t>({4, 5, 6});
  ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
  // Pin the outcome first, so that a wrongly-successful comparison is reported
  // directly rather than as a confusing substring mismatch.
  EXPECT_FALSE(result);
  EXPECT_THAT(result.message(),
              ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}"));
  EXPECT_THAT(result.message(),
              ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}"));
}
177 
// Two R1 f32 literals built from identical values must be near each other.
TEST(LiteralTestUtilTest, NearComparatorR1) {
  const std::vector<float> values = {0.0, 0.1, 0.2, 0.3, 0.4,
                                     0.5, 0.6, 0.7, 0.8};
  Literal lhs = LiteralUtil::CreateR1<float>(values);
  Literal rhs = LiteralUtil::CreateR1<float>(values);
  EXPECT_TRUE(LiteralTestUtil::Near(lhs, rhs, ErrorSpec{0.0001}));
}
185 
// Near on R1 c64 literals must accept identical inputs. It must reject inputs
// whose last element differs by 0.1 in either the real or the imaginary part,
// which is far above the 1e-4 tolerance used here.
TEST(LiteralTestUtilTest, NearComparatorR1Complex64) {
  const std::vector<complex64> reference = {
      {0.0, 1.0}, {0.1, 1.1}, {0.2, 1.2}, {0.3, 1.3}, {0.4, 1.4},
      {0.5, 1.5}, {0.6, 1.6}, {0.7, 1.7}, {0.8, 1.8}};

  // Same as `reference`, with the last element's real part changed.
  std::vector<complex64> real_perturbed = reference;
  real_perturbed.back() = complex64(0.9, 1.8);

  // Same as `reference`, with the last element's imaginary part changed.
  std::vector<complex64> imag_perturbed = reference;
  imag_perturbed.back() = complex64(0.8, 1.9);

  Literal a = LiteralUtil::CreateR1<complex64>(reference);
  Literal b = LiteralUtil::CreateR1<complex64>(reference);
  Literal c = LiteralUtil::CreateR1<complex64>(real_perturbed);
  Literal d = LiteralUtil::CreateR1<complex64>(imag_perturbed);

  const ErrorSpec tolerance{0.0001};
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, c, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, d, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(c, d, tolerance));
}
228 
// Near on R1 c128 literals must accept identical inputs. It must reject inputs
// whose last element differs by 0.1 in either the real or the imaginary part,
// which is far above the 1e-4 tolerance used here.
TEST(LiteralTestUtilTest, NearComparatorR1Complex128) {
  const std::vector<complex128> reference = {
      {0.0, 1.0}, {0.1, 1.1}, {0.2, 1.2}, {0.3, 1.3}, {0.4, 1.4},
      {0.5, 1.5}, {0.6, 1.6}, {0.7, 1.7}, {0.8, 1.8}};

  // Same as `reference`, with the last element's real part changed.
  std::vector<complex128> real_perturbed = reference;
  real_perturbed.back() = complex128(0.9, 1.8);

  // Same as `reference`, with the last element's imaginary part changed.
  std::vector<complex128> imag_perturbed = reference;
  imag_perturbed.back() = complex128(0.8, 1.9);

  Literal a = LiteralUtil::CreateR1<complex128>(reference);
  Literal b = LiteralUtil::CreateR1<complex128>(reference);
  Literal c = LiteralUtil::CreateR1<complex128>(real_perturbed);
  Literal d = LiteralUtil::CreateR1<complex128>(imag_perturbed);

  const ErrorSpec tolerance{0.0001};
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, c, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, d, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(c, d, tolerance));
}
271 
// A NaN that appears at the same position in both R1 f32 literals must not
// make them compare as "not near".
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
  const std::vector<float> values = {0.0, 0.1, 0.2, 0.3, NAN,
                                     0.5, 0.6, 0.7, 0.8};
  Literal lhs = LiteralUtil::CreateR1<float>(values);
  Literal rhs = LiteralUtil::CreateR1<float>(values);
  EXPECT_TRUE(LiteralTestUtil::Near(lhs, rhs, ErrorSpec{0.0001}));
}
279 
// Near must reject R1 literals of different lengths (9 vs. 8 elements),
// regardless of argument order.
//
// The suite name is LiteralTestUtilTest, like every other test in this file;
// it previously used `LiteralTestUtil`, which is the name of the class under
// test.
TEST(LiteralTestUtilTest, NearComparatorDifferentLengths) {
  auto a = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
  auto b =
      LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
  EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
  EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
288 
// An f64 scalar too large to be represented as f32 (twice the f32 maximum)
// must still compare near to itself.
TEST(LiteralTestUtilTest, ExpectNearDoubleOutsideFloatValueRange) {
  const double beyond_float_max = 2.0 * std::numeric_limits<float>::max();
  Literal literal = LiteralUtil::CreateR0<double>(beyond_float_max);
  EXPECT_TRUE(LiteralTestUtil::Near(literal, literal, ErrorSpec(0.001)));
}
296 
297 }  // namespace
298 }  // namespace xla
299