1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
17
18 #include <complex>
19 #include <limits>
20 #include <string>
21
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/core/platform/test.h"
25 #include "tensorflow/stream_executor/device_memory.h"
26
27 namespace xla {
28 namespace gpu {
29 namespace {
30
31 class BufferComparatorTest : public testing::Test {
32 protected:
BufferComparatorTest()33 BufferComparatorTest()
34 : platform_(
35 se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie()),
36 stream_exec_(platform_->ExecutorForDevice(0).ValueOrDie()) {}
37
38 // Take floats only for convenience. Still uses ElementType internally.
39 template <typename ElementType>
CompareEqualBuffers(const std::vector<ElementType> & lhs,const std::vector<ElementType> & rhs)40 bool CompareEqualBuffers(const std::vector<ElementType>& lhs,
41 const std::vector<ElementType>& rhs) {
42 se::Stream stream(stream_exec_);
43 stream.Init();
44
45 se::ScopedDeviceMemory<ElementType> lhs_buffer =
46 stream_exec_->AllocateOwnedArray<ElementType>(lhs.size());
47 se::ScopedDeviceMemory<ElementType> rhs_buffer =
48 stream_exec_->AllocateOwnedArray<ElementType>(rhs.size());
49
50 stream.ThenMemcpy(lhs_buffer.ptr(), lhs.data(), lhs_buffer->size());
51 stream.ThenMemcpy(rhs_buffer.ptr(), rhs.data(), rhs_buffer->size());
52 TF_CHECK_OK(stream.BlockHostUntilDone());
53
54 BufferComparator comparator(
55 ShapeUtil::MakeShape(
56 primitive_util::NativeToPrimitiveType<ElementType>(),
57 {static_cast<int64_t>(lhs_buffer->ElementCount())}),
58 HloModuleConfig());
59 return comparator.CompareEqual(&stream, *lhs_buffer, *rhs_buffer).value();
60 }
61
62 // Take floats only for convenience. Still uses ElementType internally.
63 template <typename ElementType>
CompareEqualFloatBuffers(const std::vector<float> & lhs_float,const std::vector<float> & rhs_float)64 bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
65 const std::vector<float>& rhs_float) {
66 std::vector<ElementType> lhs(lhs_float.begin(), lhs_float.end());
67 std::vector<ElementType> rhs(rhs_float.begin(), rhs_float.end());
68 return CompareEqualBuffers(lhs, rhs);
69 }
70
71 template <typename ElementType>
CompareEqualComplex(const std::vector<std::complex<ElementType>> & lhs,const std::vector<std::complex<ElementType>> & rhs)72 bool CompareEqualComplex(const std::vector<std::complex<ElementType>>& lhs,
73 const std::vector<std::complex<ElementType>>& rhs) {
74 return CompareEqualBuffers<std::complex<ElementType>>(lhs, rhs);
75 }
76
77 se::Platform* platform_;
78 se::StreamExecutor* stream_exec_;
79 };
80
TEST_F(BufferComparatorTest, TestComplex) {
  // complex<float>: identical buffers and a ~10% perturbation of the second
  // element are expected to compare equal; a large change to either the real
  // or the imaginary part of any element is expected to be detected.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {6, 7}}));
  EXPECT_TRUE(CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}},
                                         {{0.1, 0.2}, {2.2, 3.3}}));
  EXPECT_TRUE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {2, 3}}));

  // Only the real part of the second element differs.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {6, 3}}));

  // Only the imaginary part of the second element differs. This replaces a
  // verbatim duplicate of the first assertion and mirrors the double case.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {2, 7}}));

  // Only the imaginary part of the first element differs.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 6}, {2, 3}}));

  // complex<double>.
  EXPECT_TRUE(CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}},
                                          {{0.1, 0.2}, {2.2, 3.3}}));
  EXPECT_FALSE(
      CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {2, 7}}));
}
102
TEST_F(BufferComparatorTest, TestNaNs) {
  // Every NaN is expected to compare equal to every other NaN (even when they
  // were produced from different payload strings), and never equal to an
  // ordinary number. The same three cases run for each floating type.
  const float default_nan = std::nanf("");
  const float payload_nan = std::nanf("1234");
  const float one = 1.f;

  // Eigen::half.
  // NOTE(review): converting float -> Eigen::half may not preserve the NaN
  // payload, so the second case might not exercise distinct half bit
  // patterns -- confirm.
  EXPECT_TRUE(
      CompareEqualFloatBuffers<Eigen::half>({default_nan}, {default_nan}));
  EXPECT_TRUE(
      CompareEqualFloatBuffers<Eigen::half>({default_nan}, {payload_nan}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({default_nan}, {one}));

  // float.
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({default_nan}, {default_nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({default_nan}, {payload_nan}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({default_nan}, {one}));

  // double (inputs are still float NaNs, widened on the host).
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({default_nan}, {default_nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({default_nan}, {payload_nan}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({default_nan}, {one}));
}
125
TEST_F(BufferComparatorTest, TestInfs) {
  const float inf = std::numeric_limits<float>::infinity();
  const float quiet_nan = std::nanf("");
  // Largest finite Eigen::half value.
  const float half_max = 65504;
  const float twenty = 20;

  // Eigen::half: an infinity is expected to match itself and the same-signed
  // largest finite half, but not NaN, the opposite sign, or ordinary values.
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({inf}, {quiet_nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({inf}, {inf}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({inf}, {half_max}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({-inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({-inf}, {half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({inf}, {twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({inf}, {-twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({-inf}, {twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({-inf}, {-twenty}));

  // float: an infinity is expected to match only itself.
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({inf}, {quiet_nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({inf}, {inf}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({inf}, {half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({-inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({-inf}, {half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({inf}, {twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({inf}, {-twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({-inf}, {twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({-inf}, {-twenty}));

  // double: same expectations as float.
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({inf}, {quiet_nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({inf}, {inf}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({inf}, {half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({-inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({-inf}, {half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({inf}, {twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({inf}, {-twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({-inf}, {twenty}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({-inf}, {-twenty}));
}
161
TEST_F(BufferComparatorTest, TestNumbers) {
  // Each floating type runs the same cases: nearby values are expected to
  // compare equal in either order, while 0 vs 1 is expected to differ.

  // Eigen::half.
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({20.f}, {20.1f}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({0.f}, {1.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({0.9f}, {1.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({9.f}, {10.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({10.f}, {9.f}));

  // float.
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({20.f}, {20.1f}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({0.f}, {1.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({0.9f}, {1.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({9.f}, {10.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({10.f}, {9.f}));

  // double.
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({20.f}, {20.1f}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({0.f}, {1.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({0.9f}, {1.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({9.f}, {10.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({10.f}, {9.f}));

  // int8_t: inputs are converted from float, so every literal here is an
  // integer inside the int8_t range. The extremes -128 and 127 must differ.
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({100.f}, {101.f}));
  EXPECT_FALSE(CompareEqualFloatBuffers<int8_t>({0.f}, {10.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({9.f}, {10.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({90.f}, {100.f}));
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({100.f}, {90.f}));
  EXPECT_FALSE(CompareEqualFloatBuffers<int8_t>({-128.f}, {127.f}));
}
188
TEST_F(BufferComparatorTest, TestMultiple) {
  // Runs the multi-element checks for one element type:
  //  * `close_lhs` and `close_rhs` (element-wise nearby values) must compare
  //    equal, and
  //  * in an all-zero buffer, changing a single index to 3 (lhs) vs 5 (rhs)
  //    must be detected, whichever index is changed.
  // `type_tag` is used only for its type; `type_name` labels any failure via
  // SCOPED_TRACE, because the assertions below are shared by all types.
  auto check_type = [this](auto type_tag, const char* type_name,
                           const std::vector<float>& close_lhs,
                           const std::vector<float>& close_rhs) {
    using ElementType = decltype(type_tag);
    SCOPED_TRACE(type_name);

    EXPECT_TRUE(this->template CompareEqualFloatBuffers<ElementType>(
        close_lhs, close_rhs));

    constexpr int kNumElements = 200;
    std::vector<float> lhs(kNumElements);
    std::vector<float> rhs(kNumElements);
    for (int i = 0; i < kNumElements; ++i) {
      EXPECT_TRUE(
          this->template CompareEqualFloatBuffers<ElementType>(lhs, rhs))
          << "should be the same at index " << i;
      lhs[i] = 3;
      rhs[i] = 5;
      EXPECT_FALSE(
          this->template CompareEqualFloatBuffers<ElementType>(lhs, rhs))
          << "should be different at index " << i;
      // Restore the all-zero buffers before perturbing the next index.
      lhs[i] = 0;
      rhs[i] = 0;
    }
  };

  check_type(Eigen::half(), "Eigen::half", {20, 30, 40, 50, 60},
             {20.1, 30.1, 40.1, 50.1, 60.1});
  check_type(float(), "float", {20, 30, 40, 50, 60},
             {20.1, 30.1, 40.1, 50.1, 60.1});
  check_type(double(), "double", {20, 30, 40, 50, 60},
             {20.1, 30.1, 40.1, 50.1, 60.1});
  // int8_t cannot represent fractional offsets, so use +1 instead.
  check_type(int8_t(), "int8_t", {20, 30, 40, 50, 60}, {21, 31, 41, 51, 61});
}
258
259 } // namespace
260 } // namespace gpu
261 } // namespace xla
262