xref: /aosp_15_r20/external/tensorflow/tensorflow/core/platform/denormal_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Testing configuration of denormal state.
#include "tensorflow/core/platform/denormal.h"

#include <cstdint>
#include <cstring>
#include <limits>

#include "tensorflow/core/platform/test.h"
22 
23 namespace tensorflow {
24 namespace port {
25 
TEST(DenormalStateTest, ConstructorAndAccessorsWork) {
  // Each of the four flag combinations must be reported back unchanged by the
  // accessors after construction.
  struct FlagCase {
    bool ftz;
    bool daz;
  };
  const FlagCase cases[] = {
      {/*ftz=*/true, /*daz=*/true},
      {/*ftz=*/true, /*daz=*/false},
      {/*ftz=*/false, /*daz=*/true},
      {/*ftz=*/false, /*daz=*/false},
  };
  for (const FlagCase& flags : cases) {
    const DenormalState state(flags.ftz, flags.daz);
    EXPECT_EQ(state.flush_to_zero(), flags.ftz);
    EXPECT_EQ(state.denormals_are_zero(), flags.daz);
  }
}
36 
// Returns the bit pattern of a 32-bit float as an unsigned integer.
//
// The bytes are copied with std::memcpy rather than read through a
// reinterpret_cast (which would break strict aliasing). Callers then compare
// the result as an integer, which avoids floating-point comparisons that could
// themselves flush denormals.
uint32_t bits(float x) {
  static_assert(sizeof(float) == sizeof(uint32_t),
                "bits() requires float to be 32 bits wide");
  uint32_t out = 0;
  std::memcpy(&out, &x, sizeof(out));
  return out;
}
43 
CheckDenormalHandling(const DenormalState & state)44 void CheckDenormalHandling(const DenormalState& state) {
45   // Notes:
46   //  - In the following tests we need to compare binary representations because
47   //    floating-point comparisons can trigger denormal flushing on SSE/ARM.
48   //  - We also require the input value to be marked `volatile` to prevent the
49   //    compiler from optimizing away any floating-point operations that might
50   //    otherwise be expected to flush denormals.
51 
52   // The following is zero iff denormal outputs are flushed to zero.
53   volatile float denormal_output = std::numeric_limits<float>::min();
54   denormal_output *= 0.25f;
55   if (state.flush_to_zero()) {
56     EXPECT_EQ(bits(denormal_output), 0x0);
57   } else {
58     EXPECT_NE(bits(denormal_output), 0x0);
59   }
60 
61   // The following is zero iff denormal inputs are flushed to zero.
62   volatile float normal_output = std::numeric_limits<float>::denorm_min();
63   normal_output *= std::numeric_limits<float>::max();
64   if (state.denormals_are_zero()) {
65     EXPECT_EQ(bits(normal_output), 0x0);
66   } else {
67     EXPECT_NE(bits(normal_output), 0x0);
68   }
69 }
70 
TEST(DenormalTest, GetAndSetStateWorkWithCorrectFlushing) {
  // This test overwrites the floating-point control state. Put back whatever
  // was in effect when it started so that code running afterwards in the same
  // binary (e.g. under --gtest_shuffle or --gtest_repeat) sees an unmodified
  // environment.
  ScopedRestoreFlushDenormalState restore_original_state;

  const DenormalState states[] = {
      DenormalState(/*flush_to_zero=*/true, /*denormals_are_zero=*/true),
      DenormalState(/*flush_to_zero=*/true, /*denormals_are_zero=*/false),
      DenormalState(/*flush_to_zero=*/false, /*denormals_are_zero=*/true),
      DenormalState(/*flush_to_zero=*/false, /*denormals_are_zero=*/false)};

  for (const DenormalState& state : states) {
    // States the platform cannot set are skipped rather than failed.
    if (SetDenormalState(state)) {
      EXPECT_EQ(GetDenormalState(), state);
      CheckDenormalHandling(state);
    }
  }
}
85 
TEST(ScopedRestoreFlushDenormalStateTest, RestoresState) {
  const DenormalState flush_denormals(/*flush_to_zero=*/true,
                                      /*denormals_are_zero=*/true);
  const DenormalState dont_flush_denormals(/*flush_to_zero=*/false,
                                           /*denormals_are_zero=*/false);

  // Remember the state in effect before the test, including before the
  // capability probe below, so it can be reinstated at the end. This is done
  // by hand rather than with ScopedRestoreFlushDenormalState because that
  // class is what this test is checking. Only non-fatal EXPECT_* checks are
  // used, so the final restore always runs.
  const DenormalState original_state = GetDenormalState();

  // Only test if the platform supports setting the denormal state.
  const bool can_set_denormal_state = SetDenormalState(flush_denormals) &&
                                      SetDenormalState(dont_flush_denormals);
  if (can_set_denormal_state) {
    // Flush -> Don't Flush -> Flush.
    SetDenormalState(flush_denormals);
    {
      ScopedRestoreFlushDenormalState restore_state;
      SetDenormalState(dont_flush_denormals);
      EXPECT_EQ(GetDenormalState(), dont_flush_denormals);
    }
    EXPECT_EQ(GetDenormalState(), flush_denormals);

    // Don't Flush -> Flush -> Don't Flush.
    SetDenormalState(dont_flush_denormals);
    {
      ScopedRestoreFlushDenormalState restore_state;
      SetDenormalState(flush_denormals);
      EXPECT_EQ(GetDenormalState(), flush_denormals);
    }
    EXPECT_EQ(GetDenormalState(), dont_flush_denormals);
  }

  SetDenormalState(original_state);
}
115 
TEST(ScopedFlushDenormalTest, SetsFlushingAndRestoresState) {
  // The capability probe and the body below change the floating-point control
  // state. Reinstate the pre-test state on exit so other code in the binary is
  // unaffected.
  ScopedRestoreFlushDenormalState restore_original_state;

  const DenormalState flush_denormals(/*flush_to_zero=*/true,
                                      /*denormals_are_zero=*/true);
  const DenormalState dont_flush_denormals(/*flush_to_zero=*/false,
                                           /*denormals_are_zero=*/false);

  // Only test if the platform supports setting the denormal state.
  const bool can_set_denormal_state = SetDenormalState(flush_denormals) &&
                                      SetDenormalState(dont_flush_denormals);
  if (can_set_denormal_state) {
    SetDenormalState(dont_flush_denormals);
    {
      ScopedFlushDenormal scoped_flush_denormal;
      EXPECT_EQ(GetDenormalState(), flush_denormals);
    }
    EXPECT_EQ(GetDenormalState(), dont_flush_denormals);
  }
}
134 
TEST(ScopedDontFlushDenormalTest, SetsNoFlushingAndRestoresState) {
  // Without this guard the test would finish with flush-to-zero and
  // denormals-are-zero both enabled, leaking that mode into whatever runs next
  // in the binary. Reinstate the pre-test state on exit.
  ScopedRestoreFlushDenormalState restore_original_state;

  const DenormalState flush_denormals(/*flush_to_zero=*/true,
                                      /*denormals_are_zero=*/true);
  const DenormalState dont_flush_denormals(/*flush_to_zero=*/false,
                                           /*denormals_are_zero=*/false);

  // Only test if the platform supports setting the denormal state.
  const bool can_set_denormal_state = SetDenormalState(flush_denormals) &&
                                      SetDenormalState(dont_flush_denormals);
  if (can_set_denormal_state) {
    SetDenormalState(flush_denormals);
    {
      ScopedDontFlushDenormal scoped_dont_flush_denormal;
      EXPECT_EQ(GetDenormalState(), dont_flush_denormals);
    }
    EXPECT_EQ(GetDenormalState(), flush_denormals);
  }
}
153 
154 }  // namespace port
155 }  // namespace tensorflow
156