// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include <cstring>
#include <memory>
#include <sstream>
#include <math.h>

#include "main.h"

#include <Eigen/src/Core/arch/Default/BFloat16.h>

// Asserts that the raw 16-bit pattern of the bfloat16 `value` equals
// `expected_bits`.
#define VERIFY_BFLOAT16_BITS_EQUAL(value, expected_bits)          \
  VERIFY_IS_EQUAL((numext::bit_cast<numext::uint16_t>(value)),    \
                  (static_cast<numext::uint16_t>(expected_bits)))

// Forward-declaring Eigen::bfloat16 must stay legal, even after the full
// definition has already been included above.
namespace Eigen {
struct bfloat16;
}  // namespace Eigen

using Eigen::bfloat16;

// Assembles an IEEE-754 binary32 value from its bit fields and returns it as a
// float. The 23-bit significand is split into the 7 bits that land in the
// upper 16 bits of the float, i.e. the part a bfloat16 keeps
// (`high_mantissa`), and the 16 low bits that a bfloat16 drops
// (`low_mantissa`). This makes it easy to build inputs that probe the
// float -> bfloat16 conversion.
//
// Field layout (each argument is masked to its width so an out-of-range
// value cannot carry into a neighbouring field):
//   sign          :  1 bit  -> bit 31
//   exponent      :  8 bits -> bits 30..23
//   high_mantissa :  7 bits -> bits 22..16
//   low_mantissa  : 16 bits -> bits 15..0
float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
                    uint32_t low_mantissa) {
  const uint32_t src = ((sign & 0x1u) << 31) |
                       ((exponent & 0xffu) << 23) |
                       ((high_mantissa & 0x7fu) << 16) |
                       (low_mantissa & 0xffffu);
  float dest;
  // memcpy is the portable, strict-aliasing-safe way to reinterpret the bits.
  std::memcpy(&dest, &src, sizeof(dest));
  return dest;
}

35*bf2c3715SXin Li template<typename T>
test_roundtrip()36*bf2c3715SXin Li void test_roundtrip() {
37*bf2c3715SXin Li // Representable T round trip via bfloat16
38*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(-std::numeric_limits<T>::infinity()))), -std::numeric_limits<T>::infinity());
39*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(std::numeric_limits<T>::infinity()))), std::numeric_limits<T>::infinity());
40*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-1.0)))), T(-1.0));
41*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-0.5)))), T(-0.5));
42*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(-0.0)))), T(-0.0));
43*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(1.0)))), T(1.0));
44*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(0.5)))), T(0.5));
45*bf2c3715SXin Li VERIFY_IS_EQUAL((internal::cast<bfloat16,T>(internal::cast<T,bfloat16>(T(0.0)))), T(0.0));
46*bf2c3715SXin Li }
47*bf2c3715SXin Li
test_conversion()48*bf2c3715SXin Li void test_conversion()
49*bf2c3715SXin Li {
50*bf2c3715SXin Li using Eigen::bfloat16_impl::__bfloat16_raw;
51*bf2c3715SXin Li
52*bf2c3715SXin Li // Round-trip casts
53*bf2c3715SXin Li VERIFY_IS_EQUAL(
54*bf2c3715SXin Li numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(1.0f))),
55*bf2c3715SXin Li bfloat16(1.0f));
56*bf2c3715SXin Li VERIFY_IS_EQUAL(
57*bf2c3715SXin Li numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(0.5f))),
58*bf2c3715SXin Li bfloat16(0.5f));
59*bf2c3715SXin Li VERIFY_IS_EQUAL(
60*bf2c3715SXin Li numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(-0.33333f))),
61*bf2c3715SXin Li bfloat16(-0.33333f));
62*bf2c3715SXin Li VERIFY_IS_EQUAL(
63*bf2c3715SXin Li numext::bit_cast<bfloat16>(numext::bit_cast<numext::uint16_t>(bfloat16(0.0f))),
64*bf2c3715SXin Li bfloat16(0.0f));
65*bf2c3715SXin Li
66*bf2c3715SXin Li // Conversion from float.
67*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(1.0f), 0x3f80);
68*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.5f), 0x3f00);
69*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.33333f), 0x3eab);
70*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(3.38e38f), 0x7f7e);
71*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(3.40e38f), 0x7f80); // Becomes infinity.
72*bf2c3715SXin Li
73*bf2c3715SXin Li // Verify round-to-nearest-even behavior.
74*bf2c3715SXin Li float val1 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c00)));
75*bf2c3715SXin Li float val2 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c01)));
76*bf2c3715SXin Li float val3 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c02)));
77*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.5f * (val1 + val2)), 0x3c00);
78*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.5f * (val2 + val3)), 0x3c02);
79*bf2c3715SXin Li
80*bf2c3715SXin Li // Conversion from int.
81*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-1), 0xbf80);
82*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0), 0x0000);
83*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(1), 0x3f80);
84*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(2), 0x4000);
85*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(3), 0x4040);
86*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(12), 0x4140);
87*bf2c3715SXin Li
88*bf2c3715SXin Li // Conversion from bool.
89*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(false), 0x0000);
90*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(true), 0x3f80);
91*bf2c3715SXin Li
92*bf2c3715SXin Li // Conversion to bool
93*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(3)), true);
94*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.33333f)), true);
95*bf2c3715SXin Li VERIFY_IS_EQUAL(bfloat16(-0.0), false);
96*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.0)), false);
97*bf2c3715SXin Li
98*bf2c3715SXin Li // Explicit conversion to float.
99*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x0000))), 0.0f);
100*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x3f80))), 1.0f);
101*bf2c3715SXin Li
102*bf2c3715SXin Li // Implicit conversion to float
103*bf2c3715SXin Li VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x0000)), 0.0f);
104*bf2c3715SXin Li VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x3f80)), 1.0f);
105*bf2c3715SXin Li
106*bf2c3715SXin Li // Zero representations
107*bf2c3715SXin Li VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f));
108*bf2c3715SXin Li VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f));
109*bf2c3715SXin Li VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(-0.0f));
110*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.0f), 0x0000);
111*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-0.0f), 0x8000);
112*bf2c3715SXin Li
113*bf2c3715SXin Li // Default is zero
114*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
115*bf2c3715SXin Li
116*bf2c3715SXin Li // Representable floats round trip via bfloat16
117*bf2c3715SXin Li test_roundtrip<float>();
118*bf2c3715SXin Li test_roundtrip<double>();
119*bf2c3715SXin Li test_roundtrip<std::complex<float> >();
120*bf2c3715SXin Li test_roundtrip<std::complex<double> >();
121*bf2c3715SXin Li
122*bf2c3715SXin Li // Conversion
123*bf2c3715SXin Li Array<float,1,100> a;
124*bf2c3715SXin Li for (int i = 0; i < 100; i++) a(i) = i + 1.25;
125*bf2c3715SXin Li Array<bfloat16,1,100> b = a.cast<bfloat16>();
126*bf2c3715SXin Li Array<float,1,100> c = b.cast<float>();
127*bf2c3715SXin Li for (int i = 0; i < 100; ++i) {
128*bf2c3715SXin Li VERIFY_LE(numext::abs(c(i) - a(i)), a(i) / 128);
129*bf2c3715SXin Li }
130*bf2c3715SXin Li
131*bf2c3715SXin Li // Epsilon
132*bf2c3715SXin Li VERIFY_LE(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() + bfloat16(1.0f)));
133*bf2c3715SXin Li VERIFY_IS_EQUAL(1.0f, static_cast<float>((std::numeric_limits<bfloat16>::epsilon)() / bfloat16(2.0f) + bfloat16(1.0f)));
134*bf2c3715SXin Li
135*bf2c3715SXin Li // Negate
136*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(3.0f)), -3.0f);
137*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4.5f)), 4.5f);
138*bf2c3715SXin Li
139*bf2c3715SXin Li
140*bf2c3715SXin Li #if !EIGEN_COMP_MSVC
141*bf2c3715SXin Li // Visual Studio errors out on divisions by 0
142*bf2c3715SXin Li VERIFY((numext::isnan)(static_cast<float>(bfloat16(0.0 / 0.0))));
143*bf2c3715SXin Li VERIFY((numext::isinf)(static_cast<float>(bfloat16(1.0 / 0.0))));
144*bf2c3715SXin Li VERIFY((numext::isinf)(static_cast<float>(bfloat16(-1.0 / 0.0))));
145*bf2c3715SXin Li
146*bf2c3715SXin Li // Visual Studio errors out on divisions by 0
147*bf2c3715SXin Li VERIFY((numext::isnan)(bfloat16(0.0 / 0.0)));
148*bf2c3715SXin Li VERIFY((numext::isinf)(bfloat16(1.0 / 0.0)));
149*bf2c3715SXin Li VERIFY((numext::isinf)(bfloat16(-1.0 / 0.0)));
150*bf2c3715SXin Li #endif
151*bf2c3715SXin Li
152*bf2c3715SXin Li // NaNs and infinities.
153*bf2c3715SXin Li VERIFY(!(numext::isinf)(static_cast<float>(bfloat16(3.38e38f)))); // Largest finite number.
154*bf2c3715SXin Li VERIFY(!(numext::isnan)(static_cast<float>(bfloat16(0.0f))));
155*bf2c3715SXin Li VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0xff80)))));
156*bf2c3715SXin Li VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0xffc0)))));
157*bf2c3715SXin Li VERIFY((numext::isinf)(static_cast<float>(bfloat16(__bfloat16_raw(0x7f80)))));
158*bf2c3715SXin Li VERIFY((numext::isnan)(static_cast<float>(bfloat16(__bfloat16_raw(0x7fc0)))));
159*bf2c3715SXin Li
160*bf2c3715SXin Li // Exactly same checks as above, just directly on the bfloat16 representation.
161*bf2c3715SXin Li VERIFY(!(numext::isinf)(bfloat16(__bfloat16_raw(0x7bff))));
162*bf2c3715SXin Li VERIFY(!(numext::isnan)(bfloat16(__bfloat16_raw(0x0000))));
163*bf2c3715SXin Li VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0xff80))));
164*bf2c3715SXin Li VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0xffc0))));
165*bf2c3715SXin Li VERIFY((numext::isinf)(bfloat16(__bfloat16_raw(0x7f80))));
166*bf2c3715SXin Li VERIFY((numext::isnan)(bfloat16(__bfloat16_raw(0x7fc0))));
167*bf2c3715SXin Li
168*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)), 0x7fc0);
169*bf2c3715SXin Li VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)), 0xffc0);
170*bf2c3715SXin Li }
171*bf2c3715SXin Li
test_numtraits()172*bf2c3715SXin Li void test_numtraits()
173*bf2c3715SXin Li {
174*bf2c3715SXin Li std::cout << "epsilon = " << NumTraits<bfloat16>::epsilon() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::epsilon()) << ")" << std::endl;
175*bf2c3715SXin Li std::cout << "highest = " << NumTraits<bfloat16>::highest() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::highest()) << ")" << std::endl;
176*bf2c3715SXin Li std::cout << "lowest = " << NumTraits<bfloat16>::lowest() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::lowest()) << ")" << std::endl;
177*bf2c3715SXin Li std::cout << "min = " << (std::numeric_limits<bfloat16>::min)() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>((std::numeric_limits<bfloat16>::min)()) << ")" << std::endl;
178*bf2c3715SXin Li std::cout << "denorm min = " << (std::numeric_limits<bfloat16>::denorm_min)() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>((std::numeric_limits<bfloat16>::denorm_min)()) << ")" << std::endl;
179*bf2c3715SXin Li std::cout << "infinity = " << NumTraits<bfloat16>::infinity() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::infinity()) << ")" << std::endl;
180*bf2c3715SXin Li std::cout << "quiet nan = " << NumTraits<bfloat16>::quiet_NaN() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(NumTraits<bfloat16>::quiet_NaN()) << ")" << std::endl;
181*bf2c3715SXin Li std::cout << "signaling nan = " << std::numeric_limits<bfloat16>::signaling_NaN() << " (0x" << std::hex << numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::signaling_NaN()) << ")" << std::endl;
182*bf2c3715SXin Li
183*bf2c3715SXin Li VERIFY(NumTraits<bfloat16>::IsSigned);
184*bf2c3715SXin Li
185*bf2c3715SXin Li VERIFY_IS_EQUAL(
186*bf2c3715SXin Li numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::infinity()),
187*bf2c3715SXin Li numext::bit_cast<numext::uint16_t>(bfloat16(std::numeric_limits<float>::infinity())) );
188*bf2c3715SXin Li // There is no guarantee that casting a 32-bit NaN to bfloat16 has a precise
189*bf2c3715SXin Li // bit pattern. We test that it is in fact a NaN, then test the signaling
190*bf2c3715SXin Li // bit (msb of significand is 1 for quiet, 0 for signaling).
191*bf2c3715SXin Li const numext::uint16_t BFLOAT16_QUIET_BIT = 0x0040;
192*bf2c3715SXin Li VERIFY(
193*bf2c3715SXin Li (numext::isnan)(std::numeric_limits<bfloat16>::quiet_NaN())
194*bf2c3715SXin Li && (numext::isnan)(bfloat16(std::numeric_limits<float>::quiet_NaN()))
195*bf2c3715SXin Li && ((numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::quiet_NaN()) & BFLOAT16_QUIET_BIT) > 0)
196*bf2c3715SXin Li && ((numext::bit_cast<numext::uint16_t>(bfloat16(std::numeric_limits<float>::quiet_NaN())) & BFLOAT16_QUIET_BIT) > 0) );
197*bf2c3715SXin Li // After a cast to bfloat16, a signaling NaN may become non-signaling. Thus,
198*bf2c3715SXin Li // we check that both are NaN, and that only the `numeric_limits` version is
199*bf2c3715SXin Li // signaling.
200*bf2c3715SXin Li VERIFY(
201*bf2c3715SXin Li (numext::isnan)(std::numeric_limits<bfloat16>::signaling_NaN())
202*bf2c3715SXin Li && (numext::isnan)(bfloat16(std::numeric_limits<float>::signaling_NaN()))
203*bf2c3715SXin Li && ((numext::bit_cast<numext::uint16_t>(std::numeric_limits<bfloat16>::signaling_NaN()) & BFLOAT16_QUIET_BIT) == 0) );
204*bf2c3715SXin Li
205*bf2c3715SXin Li VERIFY( (std::numeric_limits<bfloat16>::min)() > bfloat16(0.f) );
206*bf2c3715SXin Li VERIFY( (std::numeric_limits<bfloat16>::denorm_min)() > bfloat16(0.f) );
207*bf2c3715SXin Li VERIFY_IS_EQUAL( (std::numeric_limits<bfloat16>::denorm_min)()/bfloat16(2), bfloat16(0.f) );
208*bf2c3715SXin Li }
209*bf2c3715SXin Li
test_arithmetic()210*bf2c3715SXin Li void test_arithmetic()
211*bf2c3715SXin Li {
212*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(2)), 4);
213*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2) + bfloat16(-2)), 0);
214*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(bfloat16(0.33333f) + bfloat16(0.66667f)), 1.0f);
215*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(bfloat16(2.0f) * bfloat16(-5.5f)), -11.0f);
216*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(bfloat16(1.0f) / bfloat16(3.0f)), 0.3339f);
217*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(4096.0f)), -4096.0f);
218*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(-bfloat16(-4096.0f)), 4096.0f);
219*bf2c3715SXin Li }
220*bf2c3715SXin Li
test_comparison()221*bf2c3715SXin Li void test_comparison()
222*bf2c3715SXin Li {
223*bf2c3715SXin Li VERIFY(bfloat16(1.0f) > bfloat16(0.5f));
224*bf2c3715SXin Li VERIFY(bfloat16(0.5f) < bfloat16(1.0f));
225*bf2c3715SXin Li VERIFY(!(bfloat16(1.0f) < bfloat16(0.5f)));
226*bf2c3715SXin Li VERIFY(!(bfloat16(0.5f) > bfloat16(1.0f)));
227*bf2c3715SXin Li
228*bf2c3715SXin Li VERIFY(!(bfloat16(4.0f) > bfloat16(4.0f)));
229*bf2c3715SXin Li VERIFY(!(bfloat16(4.0f) < bfloat16(4.0f)));
230*bf2c3715SXin Li
231*bf2c3715SXin Li VERIFY(!(bfloat16(0.0f) < bfloat16(-0.0f)));
232*bf2c3715SXin Li VERIFY(!(bfloat16(-0.0f) < bfloat16(0.0f)));
233*bf2c3715SXin Li VERIFY(!(bfloat16(0.0f) > bfloat16(-0.0f)));
234*bf2c3715SXin Li VERIFY(!(bfloat16(-0.0f) > bfloat16(0.0f)));
235*bf2c3715SXin Li
236*bf2c3715SXin Li VERIFY(bfloat16(0.2f) > bfloat16(-1.0f));
237*bf2c3715SXin Li VERIFY(bfloat16(-1.0f) < bfloat16(0.2f));
238*bf2c3715SXin Li VERIFY(bfloat16(-16.0f) < bfloat16(-15.0f));
239*bf2c3715SXin Li
240*bf2c3715SXin Li VERIFY(bfloat16(1.0f) == bfloat16(1.0f));
241*bf2c3715SXin Li VERIFY(bfloat16(1.0f) != bfloat16(2.0f));
242*bf2c3715SXin Li
243*bf2c3715SXin Li // Comparisons with NaNs and infinities.
244*bf2c3715SXin Li #if !EIGEN_COMP_MSVC
245*bf2c3715SXin Li // Visual Studio errors out on divisions by 0
246*bf2c3715SXin Li VERIFY(!(bfloat16(0.0 / 0.0) == bfloat16(0.0 / 0.0)));
247*bf2c3715SXin Li VERIFY(bfloat16(0.0 / 0.0) != bfloat16(0.0 / 0.0));
248*bf2c3715SXin Li
249*bf2c3715SXin Li VERIFY(!(bfloat16(1.0) == bfloat16(0.0 / 0.0)));
250*bf2c3715SXin Li VERIFY(!(bfloat16(1.0) < bfloat16(0.0 / 0.0)));
251*bf2c3715SXin Li VERIFY(!(bfloat16(1.0) > bfloat16(0.0 / 0.0)));
252*bf2c3715SXin Li VERIFY(bfloat16(1.0) != bfloat16(0.0 / 0.0));
253*bf2c3715SXin Li
254*bf2c3715SXin Li VERIFY(bfloat16(1.0) < bfloat16(1.0 / 0.0));
255*bf2c3715SXin Li VERIFY(bfloat16(1.0) > bfloat16(-1.0 / 0.0));
256*bf2c3715SXin Li #endif
257*bf2c3715SXin Li }
258*bf2c3715SXin Li
test_basic_functions()259*bf2c3715SXin Li void test_basic_functions()
260*bf2c3715SXin Li {
261*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(3.5f))), 3.5f);
262*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(3.5f))), 3.5f);
263*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::abs(bfloat16(-3.5f))), 3.5f);
264*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(abs(bfloat16(-3.5f))), 3.5f);
265*bf2c3715SXin Li
266*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(3.5f))), 3.0f);
267*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(3.5f))), 3.0f);
268*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::floor(bfloat16(-3.5f))), -4.0f);
269*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(floor(bfloat16(-3.5f))), -4.0f);
270*bf2c3715SXin Li
271*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(3.5f))), 4.0f);
272*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(3.5f))), 4.0f);
273*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::ceil(bfloat16(-3.5f))), -3.0f);
274*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(ceil(bfloat16(-3.5f))), -3.0f);
275*bf2c3715SXin Li
276*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(0.0f))), 0.0f);
277*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(0.0f))), 0.0f);
278*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::sqrt(bfloat16(4.0f))), 2.0f);
279*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(sqrt(bfloat16(4.0f))), 2.0f);
280*bf2c3715SXin Li
281*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
282*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(0.0f), bfloat16(1.0f))), 0.0f);
283*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
284*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(pow(bfloat16(2.0f), bfloat16(2.0f))), 4.0f);
285*bf2c3715SXin Li
286*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::exp(bfloat16(0.0f))), 1.0f);
287*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(exp(bfloat16(0.0f))), 1.0f);
288*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
289*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(exp(bfloat16(EIGEN_PI))), 20.f + static_cast<float>(EIGEN_PI));
290*bf2c3715SXin Li
291*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::expm1(bfloat16(0.0f))), 0.0f);
292*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(expm1(bfloat16(0.0f))), 0.0f);
293*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::expm1(bfloat16(2.0f))), 6.375f);
294*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(expm1(bfloat16(2.0f))), 6.375f);
295*bf2c3715SXin Li
296*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::log(bfloat16(1.0f))), 0.0f);
297*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(log(bfloat16(1.0f))), 0.0f);
298*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::log(bfloat16(10.0f))), 2.296875f);
299*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(log(bfloat16(10.0f))), 2.296875f);
300*bf2c3715SXin Li
301*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(numext::log1p(bfloat16(0.0f))), 0.0f);
302*bf2c3715SXin Li VERIFY_IS_EQUAL(static_cast<float>(log1p(bfloat16(0.0f))), 0.0f);
303*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(numext::log1p(bfloat16(10.0f))), 2.390625f);
304*bf2c3715SXin Li VERIFY_IS_APPROX(static_cast<float>(log1p(bfloat16(10.0f))), 2.390625f);
305*bf2c3715SXin Li }
306*bf2c3715SXin Li
test_trigonometric_functions()307*bf2c3715SXin Li void test_trigonometric_functions()
308*bf2c3715SXin Li {
309*bf2c3715SXin Li VERIFY_IS_APPROX(numext::cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
310*bf2c3715SXin Li VERIFY_IS_APPROX(cos(bfloat16(0.0f)), bfloat16(cosf(0.0f)));
311*bf2c3715SXin Li VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI)), bfloat16(cosf(EIGEN_PI)));
312*bf2c3715SXin Li // VERIFY_IS_APPROX(numext::cos(bfloat16(EIGEN_PI/2)), bfloat16(cosf(EIGEN_PI/2)));
313*bf2c3715SXin Li // VERIFY_IS_APPROX(numext::cos(bfloat16(3*EIGEN_PI/2)), bfloat16(cosf(3*EIGEN_PI/2)));
314*bf2c3715SXin Li VERIFY_IS_APPROX(numext::cos(bfloat16(3.5f)), bfloat16(cosf(3.5f)));
315*bf2c3715SXin Li
316*bf2c3715SXin Li VERIFY_IS_APPROX(numext::sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
317*bf2c3715SXin Li VERIFY_IS_APPROX(sin(bfloat16(0.0f)), bfloat16(sinf(0.0f)));
318*bf2c3715SXin Li // VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI)), bfloat16(sinf(EIGEN_PI)));
319*bf2c3715SXin Li VERIFY_IS_APPROX(numext::sin(bfloat16(EIGEN_PI/2)), bfloat16(sinf(EIGEN_PI/2)));
320*bf2c3715SXin Li VERIFY_IS_APPROX(numext::sin(bfloat16(3*EIGEN_PI/2)), bfloat16(sinf(3*EIGEN_PI/2)));
321*bf2c3715SXin Li VERIFY_IS_APPROX(numext::sin(bfloat16(3.5f)), bfloat16(sinf(3.5f)));
322*bf2c3715SXin Li
323*bf2c3715SXin Li VERIFY_IS_APPROX(numext::tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
324*bf2c3715SXin Li VERIFY_IS_APPROX(tan(bfloat16(0.0f)), bfloat16(tanf(0.0f)));
325*bf2c3715SXin Li // VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI)), bfloat16(tanf(EIGEN_PI)));
326*bf2c3715SXin Li // VERIFY_IS_APPROX(numext::tan(bfloat16(EIGEN_PI/2)), bfloat16(tanf(EIGEN_PI/2)));
327*bf2c3715SXin Li // VERIFY_IS_APPROX(numext::tan(bfloat16(3*EIGEN_PI/2)), bfloat16(tanf(3*EIGEN_PI/2)));
328*bf2c3715SXin Li VERIFY_IS_APPROX(numext::tan(bfloat16(3.5f)), bfloat16(tanf(3.5f)));
329*bf2c3715SXin Li }
330*bf2c3715SXin Li
test_array()331*bf2c3715SXin Li void test_array()
332*bf2c3715SXin Li {
333*bf2c3715SXin Li typedef Array<bfloat16,1,Dynamic> ArrayXh;
334*bf2c3715SXin Li Index size = internal::random<Index>(1,10);
335*bf2c3715SXin Li Index i = internal::random<Index>(0,size-1);
336*bf2c3715SXin Li ArrayXh a1 = ArrayXh::Random(size), a2 = ArrayXh::Random(size);
337*bf2c3715SXin Li VERIFY_IS_APPROX( a1+a1, bfloat16(2)*a1 );
338*bf2c3715SXin Li VERIFY( (a1.abs() >= bfloat16(0)).all() );
339*bf2c3715SXin Li VERIFY_IS_APPROX( (a1*a1).sqrt(), a1.abs() );
340*bf2c3715SXin Li
341*bf2c3715SXin Li VERIFY( ((a1.min)(a2) <= (a1.max)(a2)).all() );
342*bf2c3715SXin Li a1(i) = bfloat16(-10.);
343*bf2c3715SXin Li VERIFY_IS_EQUAL( a1.minCoeff(), bfloat16(-10.) );
344*bf2c3715SXin Li a1(i) = bfloat16(10.);
345*bf2c3715SXin Li VERIFY_IS_EQUAL( a1.maxCoeff(), bfloat16(10.) );
346*bf2c3715SXin Li
347*bf2c3715SXin Li std::stringstream ss;
348*bf2c3715SXin Li ss << a1;
349*bf2c3715SXin Li }
350*bf2c3715SXin Li
test_product()351*bf2c3715SXin Li void test_product()
352*bf2c3715SXin Li {
353*bf2c3715SXin Li typedef Matrix<bfloat16,Dynamic,Dynamic> MatrixXh;
354*bf2c3715SXin Li Index rows = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
355*bf2c3715SXin Li Index cols = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
356*bf2c3715SXin Li Index depth = internal::random<Index>(1,EIGEN_TEST_MAX_SIZE);
357*bf2c3715SXin Li MatrixXh Ah = MatrixXh::Random(rows,depth);
358*bf2c3715SXin Li MatrixXh Bh = MatrixXh::Random(depth,cols);
359*bf2c3715SXin Li MatrixXh Ch = MatrixXh::Random(rows,cols);
360*bf2c3715SXin Li MatrixXf Af = Ah.cast<float>();
361*bf2c3715SXin Li MatrixXf Bf = Bh.cast<float>();
362*bf2c3715SXin Li MatrixXf Cf = Ch.cast<float>();
363*bf2c3715SXin Li VERIFY_IS_APPROX(Ch.noalias()+=Ah*Bh, (Cf.noalias()+=Af*Bf).cast<bfloat16>());
364*bf2c3715SXin Li }
365*bf2c3715SXin Li
// Test entry point: the numeric-traits checks run once; every other subtest
// is repeated g_repeat times.
EIGEN_DECLARE_TEST(bfloat16_float)
{
  CALL_SUBTEST(test_numtraits());

  for (int repetition = 0; repetition < g_repeat; ++repetition) {
    CALL_SUBTEST(test_conversion());
    CALL_SUBTEST(test_arithmetic());
    CALL_SUBTEST(test_comparison());
    CALL_SUBTEST(test_basic_functions());
    CALL_SUBTEST(test_trigonometric_functions());
    CALL_SUBTEST(test_array());
    CALL_SUBTEST(test_product());
  }
}