1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
18
19 #include <initializer_list>
20 #include <memory>
21 #include <optional>
22 #include <random>
23 #include <string>
24
25 #include "absl/base/attributes.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/error_spec.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/platform/test.h"
39
40 namespace xla {
41
42 // Utility class for making expectations/assertions related to XLA literals.
43 class LiteralTestUtil {
44 public:
45 // Asserts that the given shapes have the same rank, dimension sizes, and
46 // primitive types.
47 [[nodiscard]] static ::testing::AssertionResult EqualShapes(
48 const Shape& expected, const Shape& actual);
49
50 // Asserts that the provided shapes are equal as defined in AssertEqualShapes
51 // and that they have the same layout.
52 [[nodiscard]] static ::testing::AssertionResult EqualShapesAndLayouts(
53 const Shape& expected, const Shape& actual);
54
55 [[nodiscard]] static ::testing::AssertionResult Equal(
56 const LiteralSlice& expected, const LiteralSlice& actual);
57
58 // Asserts the given literal are (bitwise) equal to given expected values.
59 template <typename NativeT>
60 static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
61
62 template <typename NativeT>
63 static void ExpectR1Equal(absl::Span<const NativeT> expected,
64 const LiteralSlice& actual);
65 template <typename NativeT>
66 static void ExpectR2Equal(
67 std::initializer_list<std::initializer_list<NativeT>> expected,
68 const LiteralSlice& actual);
69
70 template <typename NativeT>
71 static void ExpectR3Equal(
72 std::initializer_list<
73 std::initializer_list<std::initializer_list<NativeT>>>
74 expected,
75 const LiteralSlice& actual);
76
77 // Asserts the given literal are (bitwise) equal to given array.
78 template <typename NativeT>
79 static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
80 const LiteralSlice& actual);
81 template <typename NativeT>
82 static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
83 const LiteralSlice& actual);
84 template <typename NativeT>
85 static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
86 const LiteralSlice& actual);
87
88 // Decorates literal_comparison::Near() with an AssertionResult return type.
89 //
90 // See comment on literal_comparison::Near().
91 [[nodiscard]] static ::testing::AssertionResult Near(
92 const LiteralSlice& expected, const LiteralSlice& actual,
93 const ErrorSpec& error_spec,
94 std::optional<bool> detailed_message = std::nullopt);
95
96 // Asserts the given literal are within the given error bound of the given
97 // expected values. Only supported for floating point values.
98 template <typename NativeT>
99 static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
100 const ErrorSpec& error);
101
102 template <typename NativeT>
103 static void ExpectR1Near(absl::Span<const NativeT> expected,
104 const LiteralSlice& actual, const ErrorSpec& error);
105
106 template <typename NativeT>
107 static void ExpectR2Near(
108 std::initializer_list<std::initializer_list<NativeT>> expected,
109 const LiteralSlice& actual, const ErrorSpec& error);
110
111 template <typename NativeT>
112 static void ExpectR3Near(
113 std::initializer_list<
114 std::initializer_list<std::initializer_list<NativeT>>>
115 expected,
116 const LiteralSlice& actual, const ErrorSpec& error);
117
118 template <typename NativeT>
119 static void ExpectR4Near(
120 std::initializer_list<std::initializer_list<
121 std::initializer_list<std::initializer_list<NativeT>>>>
122 expected,
123 const LiteralSlice& actual, const ErrorSpec& error);
124
125 // Asserts the given literal are within the given error bound to the given
126 // array. Only supported for floating point values.
127 template <typename NativeT>
128 static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
129 const LiteralSlice& actual,
130 const ErrorSpec& error);
131
132 template <typename NativeT>
133 static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
134 const LiteralSlice& actual,
135 const ErrorSpec& error);
136
137 template <typename NativeT>
138 static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
139 const LiteralSlice& actual,
140 const ErrorSpec& error);
141
142 // If the error spec is given, returns whether the expected and the actual are
143 // within the error bound; otherwise, returns whether they are equal. Tuples
144 // will be compared recursively.
145 [[nodiscard]] static ::testing::AssertionResult NearOrEqual(
146 const LiteralSlice& expected, const LiteralSlice& actual,
147 const std::optional<ErrorSpec>& error);
148
149 private:
150 LiteralTestUtil(const LiteralTestUtil&) = delete;
151 LiteralTestUtil& operator=(const LiteralTestUtil&) = delete;
152 };
153
154 template <typename NativeT>
ExpectR0Equal(NativeT expected,const LiteralSlice & actual)155 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
156 const LiteralSlice& actual) {
157 EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
158 }
159
160 template <typename NativeT>
ExpectR1Equal(absl::Span<const NativeT> expected,const LiteralSlice & actual)161 /* static */ void LiteralTestUtil::ExpectR1Equal(
162 absl::Span<const NativeT> expected, const LiteralSlice& actual) {
163 EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
164 }
165
166 template <typename NativeT>
ExpectR2Equal(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual)167 /* static */ void LiteralTestUtil::ExpectR2Equal(
168 std::initializer_list<std::initializer_list<NativeT>> expected,
169 const LiteralSlice& actual) {
170 EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
171 }
172
173 template <typename NativeT>
ExpectR3Equal(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual)174 /* static */ void LiteralTestUtil::ExpectR3Equal(
175 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
176 expected,
177 const LiteralSlice& actual) {
178 EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
179 }
180
181 template <typename NativeT>
ExpectR2EqualArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual)182 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
183 const Array2D<NativeT>& expected, const LiteralSlice& actual) {
184 EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
185 }
186
187 template <typename NativeT>
ExpectR3EqualArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual)188 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
189 const Array3D<NativeT>& expected, const LiteralSlice& actual) {
190 EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
191 }
192
193 template <typename NativeT>
ExpectR4EqualArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual)194 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
195 const Array4D<NativeT>& expected, const LiteralSlice& actual) {
196 EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
197 }
198
199 template <typename NativeT>
ExpectR0Near(NativeT expected,const LiteralSlice & actual,const ErrorSpec & error)200 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
201 const LiteralSlice& actual,
202 const ErrorSpec& error) {
203 EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
204 }
205
206 template <typename NativeT>
ExpectR1Near(absl::Span<const NativeT> expected,const LiteralSlice & actual,const ErrorSpec & error)207 /* static */ void LiteralTestUtil::ExpectR1Near(
208 absl::Span<const NativeT> expected, const LiteralSlice& actual,
209 const ErrorSpec& error) {
210 EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
211 }
212
213 template <typename NativeT>
ExpectR2Near(std::initializer_list<std::initializer_list<NativeT>> expected,const LiteralSlice & actual,const ErrorSpec & error)214 /* static */ void LiteralTestUtil::ExpectR2Near(
215 std::initializer_list<std::initializer_list<NativeT>> expected,
216 const LiteralSlice& actual, const ErrorSpec& error) {
217 EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
218 }
219
220 template <typename NativeT>
ExpectR3Near(std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>> expected,const LiteralSlice & actual,const ErrorSpec & error)221 /* static */ void LiteralTestUtil::ExpectR3Near(
222 std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
223 expected,
224 const LiteralSlice& actual, const ErrorSpec& error) {
225 EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
226 }
227
228 template <typename NativeT>
ExpectR4Near(std::initializer_list<std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>> expected,const LiteralSlice & actual,const ErrorSpec & error)229 /* static */ void LiteralTestUtil::ExpectR4Near(
230 std::initializer_list<std::initializer_list<
231 std::initializer_list<std::initializer_list<NativeT>>>>
232 expected,
233 const LiteralSlice& actual, const ErrorSpec& error) {
234 EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
235 }
236
237 template <typename NativeT>
ExpectR2NearArray2D(const Array2D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)238 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
239 const Array2D<NativeT>& expected, const LiteralSlice& actual,
240 const ErrorSpec& error) {
241 EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
242 }
243
244 template <typename NativeT>
ExpectR3NearArray3D(const Array3D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)245 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
246 const Array3D<NativeT>& expected, const LiteralSlice& actual,
247 const ErrorSpec& error) {
248 EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
249 }
250
251 template <typename NativeT>
ExpectR4NearArray4D(const Array4D<NativeT> & expected,const LiteralSlice & actual,const ErrorSpec & error)252 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
253 const Array4D<NativeT>& expected, const LiteralSlice& actual,
254 const ErrorSpec& error) {
255 EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
256 }
257
258 } // namespace xla
259
260 #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
261