xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"

#include <cmath>
#include <complex>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/device_memory.h"
26 
27 namespace xla {
28 namespace gpu {
29 namespace {
30 
31 class BufferComparatorTest : public testing::Test {
32  protected:
BufferComparatorTest()33   BufferComparatorTest()
34       : platform_(
35             se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie()),
36         stream_exec_(platform_->ExecutorForDevice(0).ValueOrDie()) {}
37 
38   // Take floats only for convenience. Still uses ElementType internally.
39   template <typename ElementType>
CompareEqualBuffers(const std::vector<ElementType> & lhs,const std::vector<ElementType> & rhs)40   bool CompareEqualBuffers(const std::vector<ElementType>& lhs,
41                            const std::vector<ElementType>& rhs) {
42     se::Stream stream(stream_exec_);
43     stream.Init();
44 
45     se::ScopedDeviceMemory<ElementType> lhs_buffer =
46         stream_exec_->AllocateOwnedArray<ElementType>(lhs.size());
47     se::ScopedDeviceMemory<ElementType> rhs_buffer =
48         stream_exec_->AllocateOwnedArray<ElementType>(rhs.size());
49 
50     stream.ThenMemcpy(lhs_buffer.ptr(), lhs.data(), lhs_buffer->size());
51     stream.ThenMemcpy(rhs_buffer.ptr(), rhs.data(), rhs_buffer->size());
52     TF_CHECK_OK(stream.BlockHostUntilDone());
53 
54     BufferComparator comparator(
55         ShapeUtil::MakeShape(
56             primitive_util::NativeToPrimitiveType<ElementType>(),
57             {static_cast<int64_t>(lhs_buffer->ElementCount())}),
58         HloModuleConfig());
59     return comparator.CompareEqual(&stream, *lhs_buffer, *rhs_buffer).value();
60   }
61 
62   // Take floats only for convenience. Still uses ElementType internally.
63   template <typename ElementType>
CompareEqualFloatBuffers(const std::vector<float> & lhs_float,const std::vector<float> & rhs_float)64   bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
65                                 const std::vector<float>& rhs_float) {
66     std::vector<ElementType> lhs(lhs_float.begin(), lhs_float.end());
67     std::vector<ElementType> rhs(rhs_float.begin(), rhs_float.end());
68     return CompareEqualBuffers(lhs, rhs);
69   }
70 
71   template <typename ElementType>
CompareEqualComplex(const std::vector<std::complex<ElementType>> & lhs,const std::vector<std::complex<ElementType>> & rhs)72   bool CompareEqualComplex(const std::vector<std::complex<ElementType>>& lhs,
73                            const std::vector<std::complex<ElementType>>& rhs) {
74     return CompareEqualBuffers<std::complex<ElementType>>(lhs, rhs);
75   }
76 
77   se::Platform* platform_;
78   se::StreamExecutor* stream_exec_;
79 };
80 
// A large difference in either the real or the imaginary part of any element
// must make complex buffers compare unequal. Small relative differences
// (2 vs 2.2, 3 vs 3.3) are expected to still compare equal.
TEST_F(BufferComparatorTest, TestComplex) {
  // complex<float>
  EXPECT_TRUE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {2, 3}}));
  EXPECT_TRUE(CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}},
                                         {{0.1, 0.2}, {2.2, 3.3}}));
  // Real part differs.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {6, 3}}));
  // Imaginary part differs.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {2, 7}}));
  // Both parts differ.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {6, 7}}));
  // Mismatch in the first element only.
  EXPECT_FALSE(
      CompareEqualComplex<float>({{0.1, 0.2}, {2, 3}}, {{0.1, 6}, {2, 3}}));

  // complex<double>: same cases.
  EXPECT_TRUE(
      CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {2, 3}}));
  EXPECT_TRUE(CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}},
                                          {{0.1, 0.2}, {2.2, 3.3}}));
  EXPECT_FALSE(
      CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {6, 3}}));
  EXPECT_FALSE(
      CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {2, 7}}));
  EXPECT_FALSE(
      CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}}, {{0.1, 0.2}, {6, 7}}));
  EXPECT_FALSE(
      CompareEqualComplex<double>({{0.1, 0.2}, {2, 3}}, {{0.1, 6}, {2, 3}}));
}
102 
// NaN must compare equal to NaN (whatever its bit pattern) and unequal to any
// number.
//
// NOTE: Eigen::half has only 10 mantissa bits, so the low-order payload of
// std::nanf("1234") is likely dropped when the host converts it to half,
// leaving it bit-identical to std::nanf(""). The sign bit is expected to
// survive that conversion, so the `-std::nanf("")` cases keep the "different
// bit patterns" check meaningful for every element type.
TEST_F(BufferComparatorTest, TestNaNs) {
  EXPECT_TRUE(
      CompareEqualFloatBuffers<Eigen::half>({std::nanf("")}, {std::nanf("")}));
  // NaN values with different bit patterns should compare equal.
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({std::nanf("")},
                                                    {std::nanf("1234")}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({std::nanf("")},
                                                    {-std::nanf("")}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({std::nanf("")}, {1.}));

  EXPECT_TRUE(
      CompareEqualFloatBuffers<float>({std::nanf("")}, {std::nanf("")}));
  // NaN values with different bit patterns should compare equal.
  EXPECT_TRUE(
      CompareEqualFloatBuffers<float>({std::nanf("")}, {std::nanf("1234")}));
  EXPECT_TRUE(
      CompareEqualFloatBuffers<float>({std::nanf("")}, {-std::nanf("")}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({std::nanf("")}, {1.}));

  EXPECT_TRUE(
      CompareEqualFloatBuffers<double>({std::nanf("")}, {std::nanf("")}));
  // NaN values with different bit patterns should compare equal.
  EXPECT_TRUE(
      CompareEqualFloatBuffers<double>({std::nanf("")}, {std::nanf("1234")}));
  EXPECT_TRUE(
      CompareEqualFloatBuffers<double>({std::nanf("")}, {-std::nanf("")}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({std::nanf("")}, {1.}));
}
125 
// Infinity handling. For Eigen::half the test expects an infinity to match the
// same-signed largest finite half value; for float and double, infinity only
// matches itself. In every type, infinity never matches NaN or ordinary finite
// numbers.
TEST_F(BufferComparatorTest, TestInfs) {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::nanf("");
  // Largest finite value representable in IEEE binary16 (Eigen::half).
  const float half_max = 65504;

  // Eigen::half
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({inf}, {nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({inf}, {inf}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({inf}, {half_max}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({-inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({inf}, {-half_max}));
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({-inf}, {half_max}));
  for (const float finite : {20.0f, -20.0f}) {
    EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({inf}, {finite}));
    EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({-inf}, {finite}));
  }

  // float: 65504 is just another finite number here.
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({inf}, {nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({inf}, {inf}));
  for (const float finite : {half_max, -half_max, 20.0f, -20.0f}) {
    EXPECT_FALSE(CompareEqualFloatBuffers<float>({inf}, {finite}));
    EXPECT_FALSE(CompareEqualFloatBuffers<float>({-inf}, {finite}));
  }

  // double: same expectations as float.
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({inf}, {nan}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({inf}, {inf}));
  for (const float finite : {half_max, -half_max, 20.0f, -20.0f}) {
    EXPECT_FALSE(CompareEqualFloatBuffers<double>({inf}, {finite}));
    EXPECT_FALSE(CompareEqualFloatBuffers<double>({-inf}, {finite}));
  }
}
161 
// Single-element comparisons of ordinary numbers. Each group applies one value
// pair to every floating-point element type; int8_t has its own pairs because
// its values are integral.
TEST_F(BufferComparatorTest, TestNumbers) {
  // 20 vs 20.1: expected to compare equal.
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({20}, {20.1}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({20}, {20.1}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({20}, {20.1}));

  // 0 vs 1: expected to differ.
  EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({0}, {1}));
  EXPECT_FALSE(CompareEqualFloatBuffers<float>({0}, {1}));
  EXPECT_FALSE(CompareEqualFloatBuffers<double>({0}, {1}));

  // 0.9 vs 1: expected to compare equal.
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({0.9}, {1}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({0.9}, {1}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({0.9}, {1}));

  // 9 vs 10, in both argument orders: expected to compare equal.
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({9}, {10}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({9}, {10}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({9}, {10}));
  EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({10}, {9}));
  EXPECT_TRUE(CompareEqualFloatBuffers<float>({10}, {9}));
  EXPECT_TRUE(CompareEqualFloatBuffers<double>({10}, {9}));

  // int8_t
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({100}, {101}));
  EXPECT_FALSE(CompareEqualFloatBuffers<int8_t>({0}, {10}));
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({9}, {10}));
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({90}, {100}));
  EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({100}, {90}));
  // Opposite extremes of the int8_t range.
  EXPECT_FALSE(CompareEqualFloatBuffers<int8_t>({-128}, {127}));
}
188 
TEST_F(BufferComparatorTest,TestMultiple)189 TEST_F(BufferComparatorTest, TestMultiple) {
190   {
191     EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>(
192         {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1}));
193     std::vector<float> lhs(200);
194     std::vector<float> rhs(200);
195     for (int i = 0; i < 200; i++) {
196       EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>(lhs, rhs))
197           << "should be the same at index " << i;
198       lhs[i] = 3;
199       rhs[i] = 5;
200       EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>(lhs, rhs))
201           << "should be the different at index " << i;
202       lhs[i] = 0;
203       rhs[i] = 0;
204     }
205   }
206 
207   {
208     EXPECT_TRUE(CompareEqualFloatBuffers<float>(
209         {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1}));
210     std::vector<float> lhs(200);
211     std::vector<float> rhs(200);
212     for (int i = 0; i < 200; i++) {
213       EXPECT_TRUE(CompareEqualFloatBuffers<float>(lhs, rhs))
214           << "should be the same at index " << i;
215       lhs[i] = 3;
216       rhs[i] = 5;
217       EXPECT_FALSE(CompareEqualFloatBuffers<float>(lhs, rhs))
218           << "should be the different at index " << i;
219       lhs[i] = 0;
220       rhs[i] = 0;
221     }
222   }
223 
224   {
225     EXPECT_TRUE(CompareEqualFloatBuffers<double>(
226         {20, 30, 40, 50, 60}, {20.1, 30.1, 40.1, 50.1, 60.1}));
227     std::vector<float> lhs(200);
228     std::vector<float> rhs(200);
229     for (int i = 0; i < 200; i++) {
230       EXPECT_TRUE(CompareEqualFloatBuffers<double>(lhs, rhs))
231           << "should be the same at index " << i;
232       lhs[i] = 3;
233       rhs[i] = 5;
234       EXPECT_FALSE(CompareEqualFloatBuffers<double>(lhs, rhs))
235           << "should be the different at index " << i;
236       lhs[i] = 0;
237       rhs[i] = 0;
238     }
239   }
240 
241   {
242     EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>({20, 30, 40, 50, 60},
243                                                  {21, 31, 41, 51, 61}));
244     std::vector<float> lhs(200);
245     std::vector<float> rhs(200);
246     for (int i = 0; i < 200; i++) {
247       EXPECT_TRUE(CompareEqualFloatBuffers<int8_t>(lhs, rhs))
248           << "should be the same at index " << i;
249       lhs[i] = 3;
250       rhs[i] = 5;
251       EXPECT_FALSE(CompareEqualFloatBuffers<int8_t>(lhs, rhs))
252           << "should be the different at index " << i;
253       lhs[i] = 0;
254       rhs[i] = 0;
255     }
256   }
257 }
258 
259 }  // namespace
260 }  // namespace gpu
261 }  // namespace xla
262