xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/half_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/test/test_assert.h>
5 #include <cmath>
6 #include <iostream>
7 #include <limits>
8 #include <sstream>
9 #include <type_traits>
10 
11 using namespace at;
12 
// Exercises the basic arithmetic operators on Half using small exactly
// representable values (0, 1, 2), including unary negation and a comparison
// of a Half result against a plain int literal.
TEST(TestHalf, Arithmetic) {
  const Half h_zero = 0;
  const Half h_one = 1;

  // Additive / multiplicative identities and self-division.
  ASSERT_EQ(h_zero + h_one, h_one);
  ASSERT_EQ(h_zero + h_zero, h_zero);
  ASSERT_EQ(h_zero * h_one, h_zero);
  ASSERT_EQ(h_one * h_one, h_one);
  ASSERT_EQ(h_one / h_one, h_one);

  // Subtraction, including a result that must equal the negation of one.
  ASSERT_EQ(h_one - h_one, h_zero);
  ASSERT_EQ(h_one - h_zero, h_one);
  ASSERT_EQ(h_zero - h_one, -h_one);

  // 1 + 1 checked against both a Half and an int literal.
  ASSERT_EQ(h_one + h_one, Half(2));
  ASSERT_EQ(h_one + h_one, 2);
}
27 
// Relational operators between Half values, and between Half and int
// literals in both operand orders, plus equality of zero with negated zero.
TEST(TestHalf, Comparisions) {
  const Half h_zero = 0;
  const Half h_one = 1;

  // Ordering, with Half on either side of the comparison.
  ASSERT_LT(h_zero, h_one);
  ASSERT_LT(h_zero, 1);
  ASSERT_GT(1, h_zero);
  ASSERT_GE(0, h_zero);

  // (In)equality against literals and against itself.
  ASSERT_NE(0, h_one);
  ASSERT_EQ(h_zero, 0);
  ASSERT_EQ(h_zero, h_zero);

  // Negated zero must still compare equal to zero.
  ASSERT_EQ(h_zero, -h_zero);
}
40 
// Explicit conversions out of Half. 1.5 is exactly representable in Half, so:
//  - integral targets see the fractional part dropped (1.5 -> 1),
//  - float/double targets recover 1.5 exactly,
//  - bool is true for a non-zero value and false for 0.0.
// Named casts are used instead of C-style casts so each conversion is
// explicit about its kind and greppable.
TEST(TestHalf, Cast) {
  const Half value = 1.5f;
  ASSERT_EQ(static_cast<int>(value), 1);
  ASSERT_EQ(static_cast<short>(value), 1);
  ASSERT_EQ(static_cast<long long>(value), 1LL);
  ASSERT_EQ(static_cast<float>(value), 1.5f);
  ASSERT_EQ(static_cast<double>(value), 1.5);
  ASSERT_TRUE(static_cast<bool>(value));
  ASSERT_FALSE(static_cast<bool>(Half(0.0f)));
}
51 
// Half must be constructible from each built-in integer width (signed and
// unsigned) and from double, producing the same value as the float ctor.
TEST(TestHalf, Construction) {
  const Half three = Half(3.0f);

  // Integer sources: short, int, long long and their unsigned variants.
  ASSERT_EQ(Half(static_cast<short>(3)), three);
  ASSERT_EQ(Half(static_cast<unsigned short>(3)), three);
  ASSERT_EQ(Half(3), three);
  ASSERT_EQ(Half(3U), three);
  ASSERT_EQ(Half(3LL), three);
  ASSERT_EQ(Half(3ULL), three);

  // Double source with a fractional part.
  ASSERT_EQ(Half(3.5), Half(3.5f));
}
61 
// Returns the text produced by streaming `h` with `<<`, i.e. exactly what
// Half's stream insertion operator writes.
static std::string to_string(const Half& h) {
  std::ostringstream out;
  out << h;
  return out.str();
}
67 
// Stream formatting of Half. It covers a value with a fractional part and a
// negative integral value, which must print as "-100" with no decimal point
// or exponent.
TEST(TestHalf, Half2String) {
  const std::string fractional = to_string(Half(3.5f));
  ASSERT_EQ(fractional, "3.5");

  const std::string negative = to_string(Half(-100.0f));
  ASSERT_EQ(negative, "-100");
}
72 
// Value checks on numeric_limits<Half>. They cover:
//  - the finite range [-65504, 65504];
//  - positive normal and subnormal minima;
//  - halving the smallest subnormal giving zero;
//  - an infinity equal to float's;
//  - NaNs that never compare equal to themselves.
TEST(TestHalf, HalfNumericLimits) {
  using HalfLimits = std::numeric_limits<Half>;
  using FloatLimits = std::numeric_limits<float>;

  // Finite range.
  ASSERT_EQ(HalfLimits::lowest(), -65504.0f);
  ASSERT_EQ(HalfLimits::max(), 65504.0f);

  // Smallest positive normal lies strictly inside (0, 1).
  ASSERT_GT(HalfLimits::min(), 0);
  ASSERT_LT(HalfLimits::min(), 1);

  // Smallest subnormal is positive, and half of it is zero.
  ASSERT_GT(HalfLimits::denorm_min(), 0);
  ASSERT_EQ(HalfLimits::denorm_min() / 2, 0);

  ASSERT_EQ(HalfLimits::infinity(), FloatLimits::infinity());

  // NaN is unequal to itself, for both NaN flavours.
  ASSERT_NE(HalfLimits::quiet_NaN(), HalfLimits::quiet_NaN());
  ASSERT_NE(HalfLimits::signaling_NaN(), HalfLimits::signaling_NaN());
}
85 
86 // Check the declared type of members of numeric_limits<Half> matches
87 // the declared type of that member on numeric_limits<float>
88 
89 #define ASSERT_SAME_TYPE(name)                         \
90   static_assert(                                       \
91       std::is_same_v<                                  \
92           decltype(std::numeric_limits<Half>::name),   \
93           decltype(std::numeric_limits<float>::name)>, \
94       "decltype(" #name ") differs")
95 
96 ASSERT_SAME_TYPE(is_specialized);
97 ASSERT_SAME_TYPE(is_signed);
98 ASSERT_SAME_TYPE(is_integer);
99 ASSERT_SAME_TYPE(is_exact);
100 ASSERT_SAME_TYPE(has_infinity);
101 ASSERT_SAME_TYPE(has_quiet_NaN);
102 ASSERT_SAME_TYPE(has_signaling_NaN);
103 ASSERT_SAME_TYPE(has_denorm);
104 ASSERT_SAME_TYPE(has_denorm_loss);
105 ASSERT_SAME_TYPE(round_style);
106 ASSERT_SAME_TYPE(is_iec559);
107 ASSERT_SAME_TYPE(is_bounded);
108 ASSERT_SAME_TYPE(is_modulo);
109 ASSERT_SAME_TYPE(digits);
110 ASSERT_SAME_TYPE(digits10);
111 ASSERT_SAME_TYPE(max_digits10);
112 ASSERT_SAME_TYPE(radix);
113 ASSERT_SAME_TYPE(min_exponent);
114 ASSERT_SAME_TYPE(min_exponent10);
115 ASSERT_SAME_TYPE(max_exponent);
116 ASSERT_SAME_TYPE(max_exponent10);
117 ASSERT_SAME_TYPE(traps);
118 ASSERT_SAME_TYPE(tinyness_before);
119 
// Each std math function applied to a Half argument must agree with the same
// function applied to the equivalent float argument, to within `threshold`.
//
// The checks use gtest assertions rather than `assert`, for two reasons:
//  - `assert` compiles away under NDEBUG, which would turn this whole test
//    into a silent no-op in release builds;
//  - an `assert` failure aborts the entire test binary instead of being
//    reported as a test failure.
TEST(TestHalf, CommonMath) {
  constexpr float threshold = 0.00001f;

  // Gamma / exponential / logarithmic family.
  ASSERT_NEAR(std::lgamma(Half(10.0)), std::lgamma(10.0f), threshold);
  ASSERT_NEAR(std::exp(Half(1.0)), std::exp(1.0f), threshold);
  ASSERT_NEAR(std::log(Half(1.0)), std::log(1.0f), threshold);
  ASSERT_NEAR(std::log10(Half(1000.0)), std::log10(1000.0f), threshold);
  ASSERT_NEAR(std::log1p(Half(0.0)), std::log1p(0.0f), threshold);
  ASSERT_NEAR(std::log2(Half(1000.0)), std::log2(1000.0f), threshold);
  ASSERT_NEAR(std::expm1(Half(1.0)), std::expm1(1.0f), threshold);

  // Trigonometric, root and rounding functions.
  ASSERT_NEAR(std::cos(Half(0.0)), std::cos(0.0f), threshold);
  ASSERT_NEAR(std::sin(Half(0.0)), std::sin(0.0f), threshold);
  ASSERT_NEAR(std::sqrt(Half(100.0)), std::sqrt(100.0f), threshold);
  ASSERT_NEAR(std::ceil(Half(2.4)), std::ceil(2.4f), threshold);
  ASSERT_NEAR(std::floor(Half(2.7)), std::floor(2.7f), threshold);
  ASSERT_NEAR(std::trunc(Half(2.7)), std::trunc(2.7f), threshold);

  // Inverse and hyperbolic trigonometric functions.
  ASSERT_NEAR(std::acos(Half(-1.0)), std::acos(-1.0f), threshold);
  ASSERT_NEAR(std::cosh(Half(1.0)), std::cosh(1.0f), threshold);
  ASSERT_NEAR(std::acosh(Half(1.0)), std::acosh(1.0f), threshold);
  ASSERT_NEAR(std::asin(Half(1.0)), std::asin(1.0f), threshold);
  ASSERT_NEAR(std::sinh(Half(1.0)), std::sinh(1.0f), threshold);
  ASSERT_NEAR(std::asinh(Half(1.0)), std::asinh(1.0f), threshold);
  ASSERT_NEAR(std::tan(Half(0.0)), std::tan(0.0f), threshold);
  ASSERT_NEAR(std::atan(Half(1.0)), std::atan(1.0f), threshold);
  ASSERT_NEAR(std::tanh(Half(1.0)), std::tanh(1.0f), threshold);

  // Error functions, magnitude and rounding.
  ASSERT_NEAR(std::erf(Half(10.0)), std::erf(10.0f), threshold);
  ASSERT_NEAR(std::erfc(Half(10.0)), std::erfc(10.0f), threshold);
  ASSERT_NEAR(std::abs(Half(-3.0)), std::abs(-3.0f), threshold);
  ASSERT_NEAR(std::round(Half(2.3)), std::round(2.3f), threshold);

  // Two-argument functions.
  ASSERT_NEAR(
      std::pow(Half(2.0), Half(10.0)), std::pow(2.0f, 10.0f), threshold);
  ASSERT_NEAR(
      std::atan2(Half(7.0), Half(0.0)), std::atan2(7.0f, 0.0f), threshold);

  // Classification predicates return bool, so compare them exactly.
#ifdef __APPLE__
  // @TODO: can macos do implicit conversion of Half?
  ASSERT_EQ(std::isnan(static_cast<float>(Half(0.0))), std::isnan(0.0f));
  ASSERT_EQ(std::isinf(static_cast<float>(Half(0.0))), std::isinf(0.0f));
#else
  ASSERT_EQ(std::isnan(Half(0.0)), std::isnan(0.0f));
  ASSERT_EQ(std::isinf(Half(0.0)), std::isinf(0.0f));
#endif
}
169 
// c10::complex<Half> must hand back, unchanged, the real and imaginary parts
// it was constructed from.
TEST(TestHalf, ComplexHalf) {
  const Half re = 3.0f;
  const Half im = -10.0f;
  const c10::complex<Half> z(re, im);
  ASSERT_EQ(z.real(), re);
  ASSERT_EQ(z.imag(), im);
}
176 }
177