1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests that our utility functions for dealing with literals are correctly
17 // implemented.
18
#include "tensorflow/compiler/xla/tests/literal_test_util.h"

#include <limits>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/test.h"
29
30 namespace xla {
31 namespace {
32
// A two-element tuple of s32 scalars (42, 64) must compare equal to itself.
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
  const Literal first = LiteralUtil::CreateR0<int32_t>(42);
  const Literal second = LiteralUtil::CreateR0<int32_t>(64);
  const Literal tuple = LiteralUtil::MakeTupleFromSlices({first, second});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
40
// A tuple of two complex64 scalars, (42+64i, 64+42i), must compare equal to
// itself.
TEST(LiteralTestUtilTest, ComparesEqualComplex64TuplesEqual) {
  const Literal first = LiteralUtil::CreateR0<complex64>({42.0, 64.0});
  const Literal second = LiteralUtil::CreateR0<complex64>({64.0, 42.0});
  const Literal tuple = LiteralUtil::MakeTupleFromSlices({first, second});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
48
// A tuple of two complex128 scalars, (42+64i, 64+42i), must compare equal to
// itself.
TEST(LiteralTestUtilTest, ComparesEqualComplex128TuplesEqual) {
  const Literal first = LiteralUtil::CreateR0<complex128>({42.0, 64.0});
  const Literal second = LiteralUtil::CreateR0<complex128>({64.0, 42.0});
  const Literal tuple = LiteralUtil::MakeTupleFromSlices({first, second});
  EXPECT_TRUE(LiteralTestUtil::Equal(tuple, tuple));
}
56
// Two-element complex64 tuples that differ either in element order or in a
// single real/imaginary component must not compare equal.
TEST(LiteralTestUtilTest, ComparesUnequalComplex64TuplesUnequal) {
  // Builds the tuple (first, second) of complex64 scalars.
  auto tuple_of = [](complex64 first, complex64 second) {
    return LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<complex64>(first),
        LiteralUtil::CreateR0<complex64>(second),
    });
  };

  const Literal baseline = tuple_of({42.0, 64.0}, {64.0, 42.0});
  // Both elements swapped.
  const Literal swapped = tuple_of({64.0, 42.0}, {42.0, 64.0});
  // Real part of the first element perturbed.
  const Literal first_real_off = tuple_of({42.42, 64.0}, {64.0, 42.0});
  // Imaginary part of the second element perturbed.
  const Literal second_imag_off = tuple_of({42.0, 64.0}, {64.0, 42.42});

  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, swapped));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, first_real_off));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, second_imag_off));
  EXPECT_FALSE(LiteralTestUtil::Equal(first_real_off, second_imag_off));
}
79
// Two-element complex128 tuples that differ either in element order or in a
// single real/imaginary component must not compare equal.
TEST(LiteralTestUtilTest, ComparesUnequalComplex128TuplesUnequal) {
  // Builds the tuple (first, second) of complex128 scalars.
  auto tuple_of = [](complex128 first, complex128 second) {
    return LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<complex128>(first),
        LiteralUtil::CreateR0<complex128>(second),
    });
  };

  const Literal baseline = tuple_of({42.0, 64.0}, {64.0, 42.0});
  // Both elements swapped.
  const Literal swapped = tuple_of({64.0, 42.0}, {42.0, 64.0});
  // Real part of the first element perturbed.
  const Literal first_real_off = tuple_of({42.42, 64.0}, {64.0, 42.0});
  // Imaginary part of the second element perturbed.
  const Literal second_imag_off = tuple_of({42.0, 64.0}, {64.0, 42.42});

  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, swapped));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, first_real_off));
  EXPECT_FALSE(LiteralTestUtil::Equal(baseline, second_imag_off));
  EXPECT_FALSE(LiteralTestUtil::Equal(first_real_off, second_imag_off));
}
102
// Two s32 tuples holding the same scalars in swapped order must not compare
// equal.
//
// The comparison result is fed to CHECK, so a mismatch terminates the process
// with the streamed message. A CHECK failure cannot be recovered from inside
// the test, so the test asserts that the process dies with that message.
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
  auto check_swapped_tuples_equal = [] {
    const Literal forty_two_first = LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<int32_t>(42),
        LiteralUtil::CreateR0<int32_t>(64),
    });
    const Literal sixty_four_first = LiteralUtil::MakeTupleFromSlices({
        LiteralUtil::CreateR0<int32_t>(64),
        LiteralUtil::CreateR0<int32_t>(42),
    });
    CHECK(LiteralTestUtil::Equal(forty_two_first, sixty_four_first))
        << "LHS and RHS are unequal";
  };
  ASSERT_DEATH(check_swapped_tuples_equal(), "LHS and RHS are unequal");
}
120
// A failing LiteralTestUtil::Near comparison must leave three "tempfile-*.pb"
// binary LiteralProto files behind: the expected value, the actual value, and
// the mismatch mask.
TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
  // Runs inside the death-test child. 2 and 4 are farther apart than the 0.001
  // tolerance, so the CHECK fails and the child dies.
  auto dummy_lambda = [] {
    auto two = LiteralUtil::CreateR0<float>(2);
    auto four = LiteralUtil::CreateR0<float>(4);
    ErrorSpec error(0.001);
    CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
  };

  tensorflow::Env* env = tensorflow::Env::Default();

  // Look in the undeclared-outputs directory when the runner provides one,
  // otherwise in the test tmpdir. Presumably this mirrors where
  // LiteralTestUtil writes its dumps; confirm against literal_test_util.cc.
  std::string outdir;
  if (!tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) {
    outdir = tensorflow::testing::TmpDir();
  }
  std::string pattern = tensorflow::io::JoinPath(outdir, "tempfile-*.pb");

  // Delete leftovers from earlier runs, so that afterwards only files produced
  // by the failing comparison below match the pattern.
  std::vector<std::string> files;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &files));
  for (const auto& f : files) {
    TF_CHECK_OK(env->DeleteFile(f)) << f;
  }

  ASSERT_DEATH(dummy_lambda(), "two is not near four");

  // Files written by the dying child stay on disk. Each one must parse as a
  // LiteralProto.
  std::vector<std::string> results;
  TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));

  LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]";
  // Compare against an unsigned literal. size() is size_t, and comparing it
  // with a plain int inside EXPECT_EQ triggers -Wsign-compare.
  EXPECT_EQ(results.size(), 3u);

  // Count each kind of dump, so that three files of the same kind cannot
  // satisfy the size check above.
  int num_expected = 0;
  int num_actual = 0;
  int num_mismatches = 0;
  for (const std::string& result : results) {
    LiteralProto literal_proto;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(env, result, &literal_proto));
    Literal literal = Literal::CreateFromProto(literal_proto).value();
    if (result.find("expected") != std::string::npos) {
      ++num_expected;
      EXPECT_EQ("f32[] 2", literal.ToString());
    } else if (result.find("actual") != std::string::npos) {
      ++num_actual;
      EXPECT_EQ("f32[] 4", literal.ToString());
    } else if (result.find("mismatches") != std::string::npos) {
      ++num_mismatches;
      EXPECT_EQ("pred[] true", literal.ToString());
    } else {
      FAIL() << "unknown file in temporary directory: " << result;
    }
  }
  EXPECT_EQ(num_expected, 1);
  EXPECT_EQ(num_actual, 1);
  EXPECT_EQ(num_mismatches, 1);
}
167
// When two literals differ, the failure message returned by Equal must print
// both the expected and the actual literal.
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
  const Literal expected = LiteralUtil::CreateR1<int32_t>({1, 2, 3});
  const Literal actual = LiteralUtil::CreateR1<int32_t>({4, 5, 6});
  const ::testing::AssertionResult comparison =
      LiteralTestUtil::Equal(expected, actual);
  const std::string message = comparison.message();
  EXPECT_THAT(message,
              ::testing::HasSubstr("Expected literal:\ns32[3] {1, 2, 3}"));
  EXPECT_THAT(message,
              ::testing::HasSubstr("Actual literal:\ns32[3] {4, 5, 6}"));
}
177
// Two f32 R1 literals built from the same nine values must be near each other
// under ErrorSpec{0.0001}.
TEST(LiteralTestUtilTest, NearComparatorR1) {
  const std::vector<float> values = {0.0, 0.1, 0.2, 0.3, 0.4,
                                     0.5, 0.6, 0.7, 0.8};
  const Literal lhs = LiteralUtil::CreateR1<float>(values);
  const Literal rhs = LiteralUtil::CreateR1<float>(values);
  EXPECT_TRUE(LiteralTestUtil::Near(lhs, rhs, ErrorSpec{0.0001}));
}
185
// complex64 R1 literals: identical values are near. Shifting only the real
// part, or only the imaginary part, of the last element by 0.1 is not near
// under ErrorSpec{0.0001}.
TEST(LiteralTestUtilTest, NearComparatorR1Complex64) {
  const std::vector<complex64> reference = {
      {0.0, 1.0}, {0.1, 1.1}, {0.2, 1.2}, {0.3, 1.3}, {0.4, 1.4},
      {0.5, 1.5}, {0.6, 1.6}, {0.7, 1.7}, {0.8, 1.8}};

  // Identical to `reference` except the last element's real part is 0.9.
  std::vector<complex64> real_off = reference;
  real_off.back() = complex64(0.9, 1.8);

  // Identical to `reference` except the last element's imaginary part is 1.9.
  std::vector<complex64> imag_off = reference;
  imag_off.back() = complex64(0.8, 1.9);

  const Literal a = LiteralUtil::CreateR1<complex64>(reference);
  const Literal b = LiteralUtil::CreateR1<complex64>(reference);
  const Literal c = LiteralUtil::CreateR1<complex64>(real_off);
  const Literal d = LiteralUtil::CreateR1<complex64>(imag_off);

  const ErrorSpec tolerance{0.0001};
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, c, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, d, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(c, d, tolerance));
}
228
// complex128 R1 literals: identical values are near. Shifting only the real
// part, or only the imaginary part, of the last element by 0.1 is not near
// under ErrorSpec{0.0001}.
TEST(LiteralTestUtilTest, NearComparatorR1Complex128) {
  const std::vector<complex128> reference = {
      {0.0, 1.0}, {0.1, 1.1}, {0.2, 1.2}, {0.3, 1.3}, {0.4, 1.4},
      {0.5, 1.5}, {0.6, 1.6}, {0.7, 1.7}, {0.8, 1.8}};

  // Identical to `reference` except the last element's real part is 0.9.
  std::vector<complex128> real_off = reference;
  real_off.back() = complex128(0.9, 1.8);

  // Identical to `reference` except the last element's imaginary part is 1.9.
  std::vector<complex128> imag_off = reference;
  imag_off.back() = complex128(0.8, 1.9);

  const Literal a = LiteralUtil::CreateR1<complex128>(reference);
  const Literal b = LiteralUtil::CreateR1<complex128>(reference);
  const Literal c = LiteralUtil::CreateR1<complex128>(real_off);
  const Literal d = LiteralUtil::CreateR1<complex128>(imag_off);

  const ErrorSpec tolerance{0.0001};
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, c, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(a, d, tolerance));
  EXPECT_FALSE(LiteralTestUtil::Near(c, d, tolerance));
}
271
// A NaN at the same position in both f32 literals must still let them compare
// near under ErrorSpec{0.0001}.
//
// Uses std::numeric_limits (from <limits>) rather than the C NAN macro. NAN
// comes from <cmath>, which this file does not include.
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
  constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
  auto a = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, kNaN, 0.5, 0.6, 0.7, 0.8});
  auto b = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, kNaN, 0.5, 0.6, 0.7, 0.8});
  EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
}
279
// f32 R1 literals with different element counts (9 vs 8) must never be near,
// whichever one is passed as the expected value.
//
// The suite is named LiteralTestUtilTest, like every other test in this file.
// It was previously LiteralTestUtil, which put this single test in its own
// suite for filters and reports.
TEST(LiteralTestUtilTest, NearComparatorDifferentLengths) {
  auto a = LiteralUtil::CreateR1<float>(
      {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
  auto b =
      LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
  EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
  EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
}
288
// An f64 scalar equal to twice the largest finite float (a value float cannot
// hold) must still compare near to itself.
TEST(LiteralTestUtilTest, ExpectNearDoubleOutsideFloatValueRange) {
  const double beyond_float_range = 2.0 * std::numeric_limits<float>::max();
  const Literal value = LiteralUtil::CreateR0<double>(beyond_float_range);
  const ErrorSpec tolerance(0.001);
  EXPECT_TRUE(LiteralTestUtil::Near(value, value, tolerance));
}
296
297 } // namespace
298 } // namespace xla
299