1 #[cfg(feature = "bytemuck")] 2 use bytemuck::{Pod, Zeroable}; 3 use core::{ 4 cmp::Ordering, 5 iter::{Product, Sum}, 6 num::FpCategory, 7 ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign}, 8 }; 9 #[cfg(not(target_arch = "spirv"))] 10 use core::{ 11 fmt::{ 12 Binary, Debug, Display, Error, Formatter, LowerExp, LowerHex, Octal, UpperExp, UpperHex, 13 }, 14 num::ParseFloatError, 15 str::FromStr, 16 }; 17 #[cfg(feature = "serde")] 18 use serde::{Deserialize, Serialize}; 19 #[cfg(feature = "zerocopy")] 20 use zerocopy::{AsBytes, FromBytes}; 21 22 pub(crate) mod convert; 23 24 /// A 16-bit floating point type implementing the [`bfloat16`] format. 25 /// 26 /// The [`bfloat16`] floating point format is a truncated 16-bit version of the IEEE 754 standard 27 /// `binary32`, a.k.a [`f32`]. [`bf16`] has approximately the same dynamic range as [`f32`] by 28 /// having a lower precision than [`f16`][crate::f16]. While [`f16`][crate::f16] has a precision of 29 /// 11 bits, [`bf16`] has a precision of only 8 bits. 30 /// 31 /// Like [`f16`][crate::f16], [`bf16`] does not offer arithmetic operations as it is intended for 32 /// compact storage rather than calculations. Operations should be performed with [`f32`] or 33 /// higher-precision types and converted to/from [`bf16`] as necessary. 34 /// 35 /// [`bfloat16`]: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format 36 #[allow(non_camel_case_types)] 37 #[derive(Clone, Copy, Default)] 38 #[repr(transparent)] 39 #[cfg_attr(feature = "serde", derive(Serialize))] 40 #[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))] 41 #[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))] 42 pub struct bf16(u16); 43 44 impl bf16 { 45 /// Constructs a [`bf16`] value from the raw bits. 46 #[inline] 47 #[must_use] from_bits(bits: u16) -> bf1648 pub const fn from_bits(bits: u16) -> bf16 { 49 bf16(bits) 50 } 51 52 /// Constructs a [`bf16`] value from a 32-bit floating point value. 
53 /// 54 /// If the 32-bit value is too large to fit, ±∞ will result. NaN values are preserved. 55 /// Subnormal values that are too tiny to be represented will result in ±0. All other values 56 /// are truncated and rounded to the nearest representable value. 57 #[inline] 58 #[must_use] from_f32(value: f32) -> bf1659 pub fn from_f32(value: f32) -> bf16 { 60 Self::from_f32_const(value) 61 } 62 63 /// Constructs a [`bf16`] value from a 32-bit floating point value. 64 /// 65 /// This function is identical to [`from_f32`][Self::from_f32] except it never uses hardware 66 /// intrinsics, which allows it to be `const`. [`from_f32`][Self::from_f32] should be preferred 67 /// in any non-`const` context. 68 /// 69 /// If the 32-bit value is too large to fit, ±∞ will result. NaN values are preserved. 70 /// Subnormal values that are too tiny to be represented will result in ±0. All other values 71 /// are truncated and rounded to the nearest representable value. 72 #[inline] 73 #[must_use] from_f32_const(value: f32) -> bf1674 pub const fn from_f32_const(value: f32) -> bf16 { 75 bf16(convert::f32_to_bf16(value)) 76 } 77 78 /// Constructs a [`bf16`] value from a 64-bit floating point value. 79 /// 80 /// If the 64-bit value is to large to fit, ±∞ will result. NaN values are preserved. 81 /// 64-bit subnormal values are too tiny to be represented and result in ±0. Exponents that 82 /// underflow the minimum exponent will result in subnormals or ±0. All other values are 83 /// truncated and rounded to the nearest representable value. 84 #[inline] 85 #[must_use] from_f64(value: f64) -> bf1686 pub fn from_f64(value: f64) -> bf16 { 87 Self::from_f64_const(value) 88 } 89 90 /// Constructs a [`bf16`] value from a 64-bit floating point value. 91 /// 92 /// This function is identical to [`from_f64`][Self::from_f64] except it never uses hardware 93 /// intrinsics, which allows it to be `const`. [`from_f64`][Self::from_f64] should be preferred 94 /// in any non-`const` context. 
95 /// 96 /// If the 64-bit value is to large to fit, ±∞ will result. NaN values are preserved. 97 /// 64-bit subnormal values are too tiny to be represented and result in ±0. Exponents that 98 /// underflow the minimum exponent will result in subnormals or ±0. All other values are 99 /// truncated and rounded to the nearest representable value. 100 #[inline] 101 #[must_use] from_f64_const(value: f64) -> bf16102 pub const fn from_f64_const(value: f64) -> bf16 { 103 bf16(convert::f64_to_bf16(value)) 104 } 105 106 /// Converts a [`bf16`] into the underlying bit representation. 107 #[inline] 108 #[must_use] to_bits(self) -> u16109 pub const fn to_bits(self) -> u16 { 110 self.0 111 } 112 113 /// Returns the memory representation of the underlying bit representation as a byte array in 114 /// little-endian byte order. 115 /// 116 /// # Examples 117 /// 118 /// ```rust 119 /// # use half::prelude::*; 120 /// let bytes = bf16::from_f32(12.5).to_le_bytes(); 121 /// assert_eq!(bytes, [0x48, 0x41]); 122 /// ``` 123 #[inline] 124 #[must_use] to_le_bytes(self) -> [u8; 2]125 pub const fn to_le_bytes(self) -> [u8; 2] { 126 self.0.to_le_bytes() 127 } 128 129 /// Returns the memory representation of the underlying bit representation as a byte array in 130 /// big-endian (network) byte order. 131 /// 132 /// # Examples 133 /// 134 /// ```rust 135 /// # use half::prelude::*; 136 /// let bytes = bf16::from_f32(12.5).to_be_bytes(); 137 /// assert_eq!(bytes, [0x41, 0x48]); 138 /// ``` 139 #[inline] 140 #[must_use] to_be_bytes(self) -> [u8; 2]141 pub const fn to_be_bytes(self) -> [u8; 2] { 142 self.0.to_be_bytes() 143 } 144 145 /// Returns the memory representation of the underlying bit representation as a byte array in 146 /// native byte order. 147 /// 148 /// As the target platform's native endianness is used, portable code should use 149 /// [`to_be_bytes`][bf16::to_be_bytes] or [`to_le_bytes`][bf16::to_le_bytes], as appropriate, 150 /// instead. 
151 /// 152 /// # Examples 153 /// 154 /// ```rust 155 /// # use half::prelude::*; 156 /// let bytes = bf16::from_f32(12.5).to_ne_bytes(); 157 /// assert_eq!(bytes, if cfg!(target_endian = "big") { 158 /// [0x41, 0x48] 159 /// } else { 160 /// [0x48, 0x41] 161 /// }); 162 /// ``` 163 #[inline] 164 #[must_use] to_ne_bytes(self) -> [u8; 2]165 pub const fn to_ne_bytes(self) -> [u8; 2] { 166 self.0.to_ne_bytes() 167 } 168 169 /// Creates a floating point value from its representation as a byte array in little endian. 170 /// 171 /// # Examples 172 /// 173 /// ```rust 174 /// # use half::prelude::*; 175 /// let value = bf16::from_le_bytes([0x48, 0x41]); 176 /// assert_eq!(value, bf16::from_f32(12.5)); 177 /// ``` 178 #[inline] 179 #[must_use] from_le_bytes(bytes: [u8; 2]) -> bf16180 pub const fn from_le_bytes(bytes: [u8; 2]) -> bf16 { 181 bf16::from_bits(u16::from_le_bytes(bytes)) 182 } 183 184 /// Creates a floating point value from its representation as a byte array in big endian. 185 /// 186 /// # Examples 187 /// 188 /// ```rust 189 /// # use half::prelude::*; 190 /// let value = bf16::from_be_bytes([0x41, 0x48]); 191 /// assert_eq!(value, bf16::from_f32(12.5)); 192 /// ``` 193 #[inline] 194 #[must_use] from_be_bytes(bytes: [u8; 2]) -> bf16195 pub const fn from_be_bytes(bytes: [u8; 2]) -> bf16 { 196 bf16::from_bits(u16::from_be_bytes(bytes)) 197 } 198 199 /// Creates a floating point value from its representation as a byte array in native endian. 200 /// 201 /// As the target platform's native endianness is used, portable code likely wants to use 202 /// [`from_be_bytes`][bf16::from_be_bytes] or [`from_le_bytes`][bf16::from_le_bytes], as 203 /// appropriate instead. 
204 /// 205 /// # Examples 206 /// 207 /// ```rust 208 /// # use half::prelude::*; 209 /// let value = bf16::from_ne_bytes(if cfg!(target_endian = "big") { 210 /// [0x41, 0x48] 211 /// } else { 212 /// [0x48, 0x41] 213 /// }); 214 /// assert_eq!(value, bf16::from_f32(12.5)); 215 /// ``` 216 #[inline] 217 #[must_use] from_ne_bytes(bytes: [u8; 2]) -> bf16218 pub const fn from_ne_bytes(bytes: [u8; 2]) -> bf16 { 219 bf16::from_bits(u16::from_ne_bytes(bytes)) 220 } 221 222 /// Converts a [`bf16`] value into an [`f32`] value. 223 /// 224 /// This conversion is lossless as all values can be represented exactly in [`f32`]. 225 #[inline] 226 #[must_use] to_f32(self) -> f32227 pub fn to_f32(self) -> f32 { 228 self.to_f32_const() 229 } 230 231 /// Converts a [`bf16`] value into an [`f32`] value. 232 /// 233 /// This function is identical to [`to_f32`][Self::to_f32] except it never uses hardware 234 /// intrinsics, which allows it to be `const`. [`to_f32`][Self::to_f32] should be preferred 235 /// in any non-`const` context. 236 /// 237 /// This conversion is lossless as all values can be represented exactly in [`f32`]. 238 #[inline] 239 #[must_use] to_f32_const(self) -> f32240 pub const fn to_f32_const(self) -> f32 { 241 convert::bf16_to_f32(self.0) 242 } 243 244 /// Converts a [`bf16`] value into an [`f64`] value. 245 /// 246 /// This conversion is lossless as all values can be represented exactly in [`f64`]. 247 #[inline] 248 #[must_use] to_f64(self) -> f64249 pub fn to_f64(self) -> f64 { 250 self.to_f64_const() 251 } 252 253 /// Converts a [`bf16`] value into an [`f64`] value. 254 /// 255 /// This function is identical to [`to_f64`][Self::to_f64] except it never uses hardware 256 /// intrinsics, which allows it to be `const`. [`to_f64`][Self::to_f64] should be preferred 257 /// in any non-`const` context. 258 /// 259 /// This conversion is lossless as all values can be represented exactly in [`f64`]. 
260 #[inline] 261 #[must_use] to_f64_const(self) -> f64262 pub const fn to_f64_const(self) -> f64 { 263 convert::bf16_to_f64(self.0) 264 } 265 266 /// Returns `true` if this value is NaN and `false` otherwise. 267 /// 268 /// # Examples 269 /// 270 /// ```rust 271 /// # use half::prelude::*; 272 /// 273 /// let nan = bf16::NAN; 274 /// let f = bf16::from_f32(7.0_f32); 275 /// 276 /// assert!(nan.is_nan()); 277 /// assert!(!f.is_nan()); 278 /// ``` 279 #[inline] 280 #[must_use] is_nan(self) -> bool281 pub const fn is_nan(self) -> bool { 282 self.0 & 0x7FFFu16 > 0x7F80u16 283 } 284 285 /// Returns `true` if this value is ±∞ and `false` otherwise. 286 /// 287 /// # Examples 288 /// 289 /// ```rust 290 /// # use half::prelude::*; 291 /// 292 /// let f = bf16::from_f32(7.0f32); 293 /// let inf = bf16::INFINITY; 294 /// let neg_inf = bf16::NEG_INFINITY; 295 /// let nan = bf16::NAN; 296 /// 297 /// assert!(!f.is_infinite()); 298 /// assert!(!nan.is_infinite()); 299 /// 300 /// assert!(inf.is_infinite()); 301 /// assert!(neg_inf.is_infinite()); 302 /// ``` 303 #[inline] 304 #[must_use] is_infinite(self) -> bool305 pub const fn is_infinite(self) -> bool { 306 self.0 & 0x7FFFu16 == 0x7F80u16 307 } 308 309 /// Returns `true` if this number is neither infinite nor NaN. 310 /// 311 /// # Examples 312 /// 313 /// ```rust 314 /// # use half::prelude::*; 315 /// 316 /// let f = bf16::from_f32(7.0f32); 317 /// let inf = bf16::INFINITY; 318 /// let neg_inf = bf16::NEG_INFINITY; 319 /// let nan = bf16::NAN; 320 /// 321 /// assert!(f.is_finite()); 322 /// 323 /// assert!(!nan.is_finite()); 324 /// assert!(!inf.is_finite()); 325 /// assert!(!neg_inf.is_finite()); 326 /// ``` 327 #[inline] 328 #[must_use] is_finite(self) -> bool329 pub const fn is_finite(self) -> bool { 330 self.0 & 0x7F80u16 != 0x7F80u16 331 } 332 333 /// Returns `true` if the number is neither zero, infinite, subnormal, or NaN. 
334 /// 335 /// # Examples 336 /// 337 /// ```rust 338 /// # use half::prelude::*; 339 /// 340 /// let min = bf16::MIN_POSITIVE; 341 /// let max = bf16::MAX; 342 /// let lower_than_min = bf16::from_f32(1.0e-39_f32); 343 /// let zero = bf16::from_f32(0.0_f32); 344 /// 345 /// assert!(min.is_normal()); 346 /// assert!(max.is_normal()); 347 /// 348 /// assert!(!zero.is_normal()); 349 /// assert!(!bf16::NAN.is_normal()); 350 /// assert!(!bf16::INFINITY.is_normal()); 351 /// // Values between 0 and `min` are subnormal. 352 /// assert!(!lower_than_min.is_normal()); 353 /// ``` 354 #[inline] 355 #[must_use] is_normal(self) -> bool356 pub const fn is_normal(self) -> bool { 357 let exp = self.0 & 0x7F80u16; 358 exp != 0x7F80u16 && exp != 0 359 } 360 361 /// Returns the floating point category of the number. 362 /// 363 /// If only one property is going to be tested, it is generally faster to use the specific 364 /// predicate instead. 365 /// 366 /// # Examples 367 /// 368 /// ```rust 369 /// use std::num::FpCategory; 370 /// # use half::prelude::*; 371 /// 372 /// let num = bf16::from_f32(12.4_f32); 373 /// let inf = bf16::INFINITY; 374 /// 375 /// assert_eq!(num.classify(), FpCategory::Normal); 376 /// assert_eq!(inf.classify(), FpCategory::Infinite); 377 /// ``` 378 #[must_use] classify(self) -> FpCategory379 pub const fn classify(self) -> FpCategory { 380 let exp = self.0 & 0x7F80u16; 381 let man = self.0 & 0x007Fu16; 382 match (exp, man) { 383 (0, 0) => FpCategory::Zero, 384 (0, _) => FpCategory::Subnormal, 385 (0x7F80u16, 0) => FpCategory::Infinite, 386 (0x7F80u16, _) => FpCategory::Nan, 387 _ => FpCategory::Normal, 388 } 389 } 390 391 /// Returns a number that represents the sign of `self`. 
392 /// 393 /// * 1.0 if the number is positive, +0.0 or [`INFINITY`][bf16::INFINITY] 394 /// * −1.0 if the number is negative, −0.0` or [`NEG_INFINITY`][bf16::NEG_INFINITY] 395 /// * [`NAN`][bf16::NAN] if the number is NaN 396 /// 397 /// # Examples 398 /// 399 /// ```rust 400 /// # use half::prelude::*; 401 /// 402 /// let f = bf16::from_f32(3.5_f32); 403 /// 404 /// assert_eq!(f.signum(), bf16::from_f32(1.0)); 405 /// assert_eq!(bf16::NEG_INFINITY.signum(), bf16::from_f32(-1.0)); 406 /// 407 /// assert!(bf16::NAN.signum().is_nan()); 408 /// ``` 409 #[must_use] signum(self) -> bf16410 pub const fn signum(self) -> bf16 { 411 if self.is_nan() { 412 self 413 } else if self.0 & 0x8000u16 != 0 { 414 Self::NEG_ONE 415 } else { 416 Self::ONE 417 } 418 } 419 420 /// Returns `true` if and only if `self` has a positive sign, including +0.0, NaNs with a 421 /// positive sign bit and +∞. 422 /// 423 /// # Examples 424 /// 425 /// ```rust 426 /// # use half::prelude::*; 427 /// 428 /// let nan = bf16::NAN; 429 /// let f = bf16::from_f32(7.0_f32); 430 /// let g = bf16::from_f32(-7.0_f32); 431 /// 432 /// assert!(f.is_sign_positive()); 433 /// assert!(!g.is_sign_positive()); 434 /// // NaN can be either positive or negative 435 /// assert!(nan.is_sign_positive() != nan.is_sign_negative()); 436 /// ``` 437 #[inline] 438 #[must_use] is_sign_positive(self) -> bool439 pub const fn is_sign_positive(self) -> bool { 440 self.0 & 0x8000u16 == 0 441 } 442 443 /// Returns `true` if and only if `self` has a negative sign, including −0.0, NaNs with a 444 /// negative sign bit and −∞. 
445 /// 446 /// # Examples 447 /// 448 /// ```rust 449 /// # use half::prelude::*; 450 /// 451 /// let nan = bf16::NAN; 452 /// let f = bf16::from_f32(7.0f32); 453 /// let g = bf16::from_f32(-7.0f32); 454 /// 455 /// assert!(!f.is_sign_negative()); 456 /// assert!(g.is_sign_negative()); 457 /// // NaN can be either positive or negative 458 /// assert!(nan.is_sign_positive() != nan.is_sign_negative()); 459 /// ``` 460 #[inline] 461 #[must_use] is_sign_negative(self) -> bool462 pub const fn is_sign_negative(self) -> bool { 463 self.0 & 0x8000u16 != 0 464 } 465 466 /// Returns a number composed of the magnitude of `self` and the sign of `sign`. 467 /// 468 /// Equal to `self` if the sign of `self` and `sign` are the same, otherwise equal to `-self`. 469 /// If `self` is NaN, then NaN with the sign of `sign` is returned. 470 /// 471 /// # Examples 472 /// 473 /// ``` 474 /// # use half::prelude::*; 475 /// let f = bf16::from_f32(3.5); 476 /// 477 /// assert_eq!(f.copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5)); 478 /// assert_eq!(f.copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5)); 479 /// assert_eq!((-f).copysign(bf16::from_f32(0.42)), bf16::from_f32(3.5)); 480 /// assert_eq!((-f).copysign(bf16::from_f32(-0.42)), bf16::from_f32(-3.5)); 481 /// 482 /// assert!(bf16::NAN.copysign(bf16::from_f32(1.0)).is_nan()); 483 /// ``` 484 #[inline] 485 #[must_use] copysign(self, sign: bf16) -> bf16486 pub const fn copysign(self, sign: bf16) -> bf16 { 487 bf16((sign.0 & 0x8000u16) | (self.0 & 0x7FFFu16)) 488 } 489 490 /// Returns the maximum of the two numbers. 491 /// 492 /// If one of the arguments is NaN, then the other argument is returned. 
493 /// 494 /// # Examples 495 /// 496 /// ``` 497 /// # use half::prelude::*; 498 /// let x = bf16::from_f32(1.0); 499 /// let y = bf16::from_f32(2.0); 500 /// 501 /// assert_eq!(x.max(y), y); 502 /// ``` 503 #[inline] 504 #[must_use] max(self, other: bf16) -> bf16505 pub fn max(self, other: bf16) -> bf16 { 506 if other > self && !other.is_nan() { 507 other 508 } else { 509 self 510 } 511 } 512 513 /// Returns the minimum of the two numbers. 514 /// 515 /// If one of the arguments is NaN, then the other argument is returned. 516 /// 517 /// # Examples 518 /// 519 /// ``` 520 /// # use half::prelude::*; 521 /// let x = bf16::from_f32(1.0); 522 /// let y = bf16::from_f32(2.0); 523 /// 524 /// assert_eq!(x.min(y), x); 525 /// ``` 526 #[inline] 527 #[must_use] min(self, other: bf16) -> bf16528 pub fn min(self, other: bf16) -> bf16 { 529 if other < self && !other.is_nan() { 530 other 531 } else { 532 self 533 } 534 } 535 536 /// Restrict a value to a certain interval unless it is NaN. 537 /// 538 /// Returns `max` if `self` is greater than `max`, and `min` if `self` is less than `min`. 539 /// Otherwise this returns `self`. 540 /// 541 /// Note that this function returns NaN if the initial value was NaN as well. 542 /// 543 /// # Panics 544 /// Panics if `min > max`, `min` is NaN, or `max` is NaN. 
545 /// 546 /// # Examples 547 /// 548 /// ``` 549 /// # use half::prelude::*; 550 /// assert!(bf16::from_f32(-3.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(-2.0)); 551 /// assert!(bf16::from_f32(0.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(0.0)); 552 /// assert!(bf16::from_f32(2.0).clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)) == bf16::from_f32(1.0)); 553 /// assert!(bf16::NAN.clamp(bf16::from_f32(-2.0), bf16::from_f32(1.0)).is_nan()); 554 /// ``` 555 #[inline] 556 #[must_use] clamp(self, min: bf16, max: bf16) -> bf16557 pub fn clamp(self, min: bf16, max: bf16) -> bf16 { 558 assert!(min <= max); 559 let mut x = self; 560 if x < min { 561 x = min; 562 } 563 if x > max { 564 x = max; 565 } 566 x 567 } 568 569 /// Returns the ordering between `self` and `other`. 570 /// 571 /// Unlike the standard partial comparison between floating point numbers, 572 /// this comparison always produces an ordering in accordance to 573 /// the `totalOrder` predicate as defined in the IEEE 754 (2008 revision) 574 /// floating point standard. The values are ordered in the following sequence: 575 /// 576 /// - negative quiet NaN 577 /// - negative signaling NaN 578 /// - negative infinity 579 /// - negative numbers 580 /// - negative subnormal numbers 581 /// - negative zero 582 /// - positive zero 583 /// - positive subnormal numbers 584 /// - positive numbers 585 /// - positive infinity 586 /// - positive signaling NaN 587 /// - positive quiet NaN. 588 /// 589 /// The ordering established by this function does not always agree with the 590 /// [`PartialOrd`] and [`PartialEq`] implementations of `bf16`. For example, 591 /// they consider negative and positive zero equal, while `total_cmp` 592 /// doesn't. 593 /// 594 /// The interpretation of the signaling NaN bit follows the definition in 595 /// the IEEE 754 standard, which may not match the interpretation by some of 596 /// the older, non-conformant (e.g. 
MIPS) hardware implementations. 597 /// 598 /// # Examples 599 /// ``` 600 /// # use half::bf16; 601 /// let mut v: Vec<bf16> = vec![]; 602 /// v.push(bf16::ONE); 603 /// v.push(bf16::INFINITY); 604 /// v.push(bf16::NEG_INFINITY); 605 /// v.push(bf16::NAN); 606 /// v.push(bf16::MAX_SUBNORMAL); 607 /// v.push(-bf16::MAX_SUBNORMAL); 608 /// v.push(bf16::ZERO); 609 /// v.push(bf16::NEG_ZERO); 610 /// v.push(bf16::NEG_ONE); 611 /// v.push(bf16::MIN_POSITIVE); 612 /// 613 /// v.sort_by(|a, b| a.total_cmp(&b)); 614 /// 615 /// assert!(v 616 /// .into_iter() 617 /// .zip( 618 /// [ 619 /// bf16::NEG_INFINITY, 620 /// bf16::NEG_ONE, 621 /// -bf16::MAX_SUBNORMAL, 622 /// bf16::NEG_ZERO, 623 /// bf16::ZERO, 624 /// bf16::MAX_SUBNORMAL, 625 /// bf16::MIN_POSITIVE, 626 /// bf16::ONE, 627 /// bf16::INFINITY, 628 /// bf16::NAN 629 /// ] 630 /// .iter() 631 /// ) 632 /// .all(|(a, b)| a.to_bits() == b.to_bits())); 633 /// ``` 634 // Implementation based on: https://doc.rust-lang.org/std/primitive.f32.html#method.total_cmp 635 #[inline] 636 #[must_use] total_cmp(&self, other: &Self) -> Ordering637 pub fn total_cmp(&self, other: &Self) -> Ordering { 638 let mut left = self.to_bits() as i16; 639 let mut right = other.to_bits() as i16; 640 left ^= (((left >> 15) as u16) >> 1) as i16; 641 right ^= (((right >> 15) as u16) >> 1) as i16; 642 left.cmp(&right) 643 } 644 645 /// Alternate serialize adapter for serializing as a float. 646 /// 647 /// By default, [`bf16`] serializes as a newtype of [`u16`]. This is an alternate serialize 648 /// implementation that serializes as an [`f32`] value. It is designed for use with 649 /// `serialize_with` serde attributes. Deserialization from `f32` values is already supported by 650 /// the default deserialize implementation. 
651 /// 652 /// # Examples 653 /// 654 /// A demonstration on how to use this adapater: 655 /// 656 /// ``` 657 /// use serde::{Serialize, Deserialize}; 658 /// use half::bf16; 659 /// 660 /// #[derive(Serialize, Deserialize)] 661 /// struct MyStruct { 662 /// #[serde(serialize_with = "bf16::serialize_as_f32")] 663 /// value: bf16 // Will be serialized as f32 instead of u16 664 /// } 665 /// ``` 666 #[cfg(feature = "serde")] serialize_as_f32<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error>667 pub fn serialize_as_f32<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { 668 serializer.serialize_f32(self.to_f32()) 669 } 670 671 /// Alternate serialize adapter for serializing as a string. 672 /// 673 /// By default, [`bf16`] serializes as a newtype of [`u16`]. This is an alternate serialize 674 /// implementation that serializes as a string value. It is designed for use with 675 /// `serialize_with` serde attributes. Deserialization from string values is already supported 676 /// by the default deserialize implementation. 
677 /// 678 /// # Examples 679 /// 680 /// A demonstration on how to use this adapater: 681 /// 682 /// ``` 683 /// use serde::{Serialize, Deserialize}; 684 /// use half::bf16; 685 /// 686 /// #[derive(Serialize, Deserialize)] 687 /// struct MyStruct { 688 /// #[serde(serialize_with = "bf16::serialize_as_string")] 689 /// value: bf16 // Will be serialized as a string instead of u16 690 /// } 691 /// ``` 692 #[cfg(feature = "serde")] serialize_as_string<S: serde::Serializer>( &self, serializer: S, ) -> Result<S::Ok, S::Error>693 pub fn serialize_as_string<S: serde::Serializer>( 694 &self, 695 serializer: S, 696 ) -> Result<S::Ok, S::Error> { 697 serializer.serialize_str(&self.to_string()) 698 } 699 700 /// Approximate number of [`bf16`] significant digits in base 10 701 pub const DIGITS: u32 = 2; 702 /// [`bf16`] 703 /// [machine epsilon](https://en.wikipedia.org/wiki/Machine_epsilon) value 704 /// 705 /// This is the difference between 1.0 and the next largest representable number. 706 pub const EPSILON: bf16 = bf16(0x3C00u16); 707 /// [`bf16`] positive Infinity (+∞) 708 pub const INFINITY: bf16 = bf16(0x7F80u16); 709 /// Number of [`bf16`] significant digits in base 2 710 pub const MANTISSA_DIGITS: u32 = 8; 711 /// Largest finite [`bf16`] value 712 pub const MAX: bf16 = bf16(0x7F7F); 713 /// Maximum possible [`bf16`] power of 10 exponent 714 pub const MAX_10_EXP: i32 = 38; 715 /// Maximum possible [`bf16`] power of 2 exponent 716 pub const MAX_EXP: i32 = 128; 717 /// Smallest finite [`bf16`] value 718 pub const MIN: bf16 = bf16(0xFF7F); 719 /// Minimum possible normal [`bf16`] power of 10 exponent 720 pub const MIN_10_EXP: i32 = -37; 721 /// One greater than the minimum possible normal [`bf16`] power of 2 exponent 722 pub const MIN_EXP: i32 = -125; 723 /// Smallest positive normal [`bf16`] value 724 pub const MIN_POSITIVE: bf16 = bf16(0x0080u16); 725 /// [`bf16`] Not a Number (NaN) 726 pub const NAN: bf16 = bf16(0x7FC0u16); 727 /// [`bf16`] negative infinity (-∞). 
728 pub const NEG_INFINITY: bf16 = bf16(0xFF80u16); 729 /// The radix or base of the internal representation of [`bf16`] 730 pub const RADIX: u32 = 2; 731 732 /// Minimum positive subnormal [`bf16`] value 733 pub const MIN_POSITIVE_SUBNORMAL: bf16 = bf16(0x0001u16); 734 /// Maximum subnormal [`bf16`] value 735 pub const MAX_SUBNORMAL: bf16 = bf16(0x007Fu16); 736 737 /// [`bf16`] 1 738 pub const ONE: bf16 = bf16(0x3F80u16); 739 /// [`bf16`] 0 740 pub const ZERO: bf16 = bf16(0x0000u16); 741 /// [`bf16`] -0 742 pub const NEG_ZERO: bf16 = bf16(0x8000u16); 743 /// [`bf16`] -1 744 pub const NEG_ONE: bf16 = bf16(0xBF80u16); 745 746 /// [`bf16`] Euler's number (ℯ) 747 pub const E: bf16 = bf16(0x402Eu16); 748 /// [`bf16`] Archimedes' constant (π) 749 pub const PI: bf16 = bf16(0x4049u16); 750 /// [`bf16`] 1/π 751 pub const FRAC_1_PI: bf16 = bf16(0x3EA3u16); 752 /// [`bf16`] 1/√2 753 pub const FRAC_1_SQRT_2: bf16 = bf16(0x3F35u16); 754 /// [`bf16`] 2/π 755 pub const FRAC_2_PI: bf16 = bf16(0x3F23u16); 756 /// [`bf16`] 2/√π 757 pub const FRAC_2_SQRT_PI: bf16 = bf16(0x3F90u16); 758 /// [`bf16`] π/2 759 pub const FRAC_PI_2: bf16 = bf16(0x3FC9u16); 760 /// [`bf16`] π/3 761 pub const FRAC_PI_3: bf16 = bf16(0x3F86u16); 762 /// [`bf16`] π/4 763 pub const FRAC_PI_4: bf16 = bf16(0x3F49u16); 764 /// [`bf16`] π/6 765 pub const FRAC_PI_6: bf16 = bf16(0x3F06u16); 766 /// [`bf16`] π/8 767 pub const FRAC_PI_8: bf16 = bf16(0x3EC9u16); 768 /// [`bf16`] 10 769 pub const LN_10: bf16 = bf16(0x4013u16); 770 /// [`bf16`] 2 771 pub const LN_2: bf16 = bf16(0x3F31u16); 772 /// [`bf16`] ₁₀ℯ 773 pub const LOG10_E: bf16 = bf16(0x3EDEu16); 774 /// [`bf16`] ₁₀2 775 pub const LOG10_2: bf16 = bf16(0x3E9Au16); 776 /// [`bf16`] ₂ℯ 777 pub const LOG2_E: bf16 = bf16(0x3FB9u16); 778 /// [`bf16`] ₂10 779 pub const LOG2_10: bf16 = bf16(0x4055u16); 780 /// [`bf16`] √2 781 pub const SQRT_2: bf16 = bf16(0x3FB5u16); 782 } 783 784 impl From<bf16> for f32 { 785 #[inline] from(x: bf16) -> f32786 fn from(x: bf16) -> f32 { 
787 x.to_f32() 788 } 789 } 790 791 impl From<bf16> for f64 { 792 #[inline] from(x: bf16) -> f64793 fn from(x: bf16) -> f64 { 794 x.to_f64() 795 } 796 } 797 798 impl From<i8> for bf16 { 799 #[inline] from(x: i8) -> bf16800 fn from(x: i8) -> bf16 { 801 // Convert to f32, then to bf16 802 bf16::from_f32(f32::from(x)) 803 } 804 } 805 806 impl From<u8> for bf16 { 807 #[inline] from(x: u8) -> bf16808 fn from(x: u8) -> bf16 { 809 // Convert to f32, then to f16 810 bf16::from_f32(f32::from(x)) 811 } 812 } 813 814 impl PartialEq for bf16 { eq(&self, other: &bf16) -> bool815 fn eq(&self, other: &bf16) -> bool { 816 if self.is_nan() || other.is_nan() { 817 false 818 } else { 819 (self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0) 820 } 821 } 822 } 823 824 impl PartialOrd for bf16 { partial_cmp(&self, other: &bf16) -> Option<Ordering>825 fn partial_cmp(&self, other: &bf16) -> Option<Ordering> { 826 if self.is_nan() || other.is_nan() { 827 None 828 } else { 829 let neg = self.0 & 0x8000u16 != 0; 830 let other_neg = other.0 & 0x8000u16 != 0; 831 match (neg, other_neg) { 832 (false, false) => Some(self.0.cmp(&other.0)), 833 (false, true) => { 834 if (self.0 | other.0) & 0x7FFFu16 == 0 { 835 Some(Ordering::Equal) 836 } else { 837 Some(Ordering::Greater) 838 } 839 } 840 (true, false) => { 841 if (self.0 | other.0) & 0x7FFFu16 == 0 { 842 Some(Ordering::Equal) 843 } else { 844 Some(Ordering::Less) 845 } 846 } 847 (true, true) => Some(other.0.cmp(&self.0)), 848 } 849 } 850 } 851 lt(&self, other: &bf16) -> bool852 fn lt(&self, other: &bf16) -> bool { 853 if self.is_nan() || other.is_nan() { 854 false 855 } else { 856 let neg = self.0 & 0x8000u16 != 0; 857 let other_neg = other.0 & 0x8000u16 != 0; 858 match (neg, other_neg) { 859 (false, false) => self.0 < other.0, 860 (false, true) => false, 861 (true, false) => (self.0 | other.0) & 0x7FFFu16 != 0, 862 (true, true) => self.0 > other.0, 863 } 864 } 865 } 866 le(&self, other: &bf16) -> bool867 fn le(&self, other: &bf16) -> bool 
{ 868 if self.is_nan() || other.is_nan() { 869 false 870 } else { 871 let neg = self.0 & 0x8000u16 != 0; 872 let other_neg = other.0 & 0x8000u16 != 0; 873 match (neg, other_neg) { 874 (false, false) => self.0 <= other.0, 875 (false, true) => (self.0 | other.0) & 0x7FFFu16 == 0, 876 (true, false) => true, 877 (true, true) => self.0 >= other.0, 878 } 879 } 880 } 881 gt(&self, other: &bf16) -> bool882 fn gt(&self, other: &bf16) -> bool { 883 if self.is_nan() || other.is_nan() { 884 false 885 } else { 886 let neg = self.0 & 0x8000u16 != 0; 887 let other_neg = other.0 & 0x8000u16 != 0; 888 match (neg, other_neg) { 889 (false, false) => self.0 > other.0, 890 (false, true) => (self.0 | other.0) & 0x7FFFu16 != 0, 891 (true, false) => false, 892 (true, true) => self.0 < other.0, 893 } 894 } 895 } 896 ge(&self, other: &bf16) -> bool897 fn ge(&self, other: &bf16) -> bool { 898 if self.is_nan() || other.is_nan() { 899 false 900 } else { 901 let neg = self.0 & 0x8000u16 != 0; 902 let other_neg = other.0 & 0x8000u16 != 0; 903 match (neg, other_neg) { 904 (false, false) => self.0 >= other.0, 905 (false, true) => true, 906 (true, false) => (self.0 | other.0) & 0x7FFFu16 == 0, 907 (true, true) => self.0 <= other.0, 908 } 909 } 910 } 911 } 912 913 #[cfg(not(target_arch = "spirv"))] 914 impl FromStr for bf16 { 915 type Err = ParseFloatError; from_str(src: &str) -> Result<bf16, ParseFloatError>916 fn from_str(src: &str) -> Result<bf16, ParseFloatError> { 917 f32::from_str(src).map(bf16::from_f32) 918 } 919 } 920 921 #[cfg(not(target_arch = "spirv"))] 922 impl Debug for bf16 { fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>923 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 924 write!(f, "{:?}", self.to_f32()) 925 } 926 } 927 928 #[cfg(not(target_arch = "spirv"))] 929 impl Display for bf16 { fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>930 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 931 write!(f, "{}", self.to_f32()) 932 } 933 } 934 935 
// The exponent formats show the numeric value (via `f32`), whereas the binary, octal and hex
// formats show the raw 16-bit pattern. Every impl delegates to the inner type using the caller's
// formatter so that flags such as width, precision and `#` are honored.

#[cfg(not(target_arch = "spirv"))]
impl LowerExp for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        LowerExp::fmt(&self.to_f32(), f)
    }
}

#[cfg(not(target_arch = "spirv"))]
impl UpperExp for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        UpperExp::fmt(&self.to_f32(), f)
    }
}

#[cfg(not(target_arch = "spirv"))]
impl Binary for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        Binary::fmt(&self.0, f)
    }
}

#[cfg(not(target_arch = "spirv"))]
impl Octal for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        Octal::fmt(&self.0, f)
    }
}

#[cfg(not(target_arch = "spirv"))]
impl LowerHex for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        LowerHex::fmt(&self.0, f)
    }
}

#[cfg(not(target_arch = "spirv"))]
impl UpperHex for bf16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        UpperHex::fmt(&self.0, f)
    }
}

impl Neg for bf16 {
    type Output = Self;

    /// Negates the value by flipping the sign bit; this also applies to zeros and NaNs.
    #[inline]
    fn neg(self) -> Self::Output {
        Self(self.0 ^ 0x8000)
    }
}

impl Neg for &bf16 {
    type Output = <bf16 as Neg>::Output;

    #[inline]
    fn neg(self) -> Self::Output {
        Neg::neg(*self)
    }
}

// Arithmetic is carried out in `f32`: both operands are widened, the operation is performed
// there, and the result is rounded back to `bf16`. The by-reference and assigning variants all
// forward to the by-value implementation.

impl Add for bf16 {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        Self::from_f32(Self::to_f32(self) + Self::to_f32(rhs))
    }
}

impl Add<&bf16> for bf16 {
    type Output = <bf16 as Add<bf16>>::Output;

    #[inline]
    fn add(self, rhs: &bf16) -> Self::Output {
        self.add(*rhs)
    }
}

impl Add<&bf16> for &bf16 {
    type Output = <bf16 as Add<bf16>>::Output;

    #[inline]
    fn add(self, rhs: &bf16) -> Self::Output {
        (*self).add(*rhs)
    }
}

impl Add<bf16> for &bf16 {
    type Output = <bf16 as Add<bf16>>::Output;

    #[inline]
    fn add(self, rhs: bf16) -> Self::Output {
        (*self).add(rhs)
    }
}

impl AddAssign for bf16 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        *self = (*self).add(rhs);
    }
}

impl AddAssign<&bf16> for bf16 {
    #[inline]
    fn add_assign(&mut self, rhs: &bf16) {
        *self = (*self).add(rhs);
    }
}

impl Sub for bf16 {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        Self::from_f32(Self::to_f32(self) - Self::to_f32(rhs))
    }
}

impl Sub<&bf16> for bf16 {
    type Output = <bf16 as Sub<bf16>>::Output;

    #[inline]
    fn sub(self, rhs: &bf16) -> Self::Output {
        self.sub(*rhs)
    }
}

impl Sub<&bf16> for &bf16 {
    type Output = <bf16 as Sub<bf16>>::Output;

    #[inline]
    fn sub(self, rhs: &bf16) -> Self::Output {
        (*self).sub(*rhs)
    }
}

impl Sub<bf16> for &bf16 {
    type Output = <bf16 as Sub<bf16>>::Output;

    #[inline]
    fn sub(self, rhs: bf16) -> Self::Output {
        (*self).sub(rhs)
    }
}

impl SubAssign for bf16 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        *self = (*self).sub(rhs);
    }
}

impl SubAssign<&bf16> for bf16 {
    #[inline]
    fn sub_assign(&mut self, rhs: &bf16) {
        *self = (*self).sub(rhs);
    }
}

impl Mul for bf16 {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        Self::from_f32(Self::to_f32(self) * Self::to_f32(rhs))
    }
}

impl Mul<&bf16> for bf16 {
    type Output = <bf16 as Mul<bf16>>::Output;

    #[inline]
    fn mul(self, rhs: &bf16) -> Self::Output {
        self.mul(*rhs)
    }
}

impl Mul<&bf16> for &bf16 {
    type Output = <bf16 as Mul<bf16>>::Output;

    #[inline]
    fn mul(self, rhs: &bf16) -> Self::Output {
        (*self).mul(*rhs)
    }
}

impl Mul<bf16> for &bf16 {
    type Output = <bf16 as Mul<bf16>>::Output;

    #[inline]
    fn mul(self, rhs: bf16) -> Self::Output {
        (*self).mul(rhs)
    }
}

impl MulAssign for bf16 {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        *self = (*self).mul(rhs);
    }
}

impl MulAssign<&bf16> for bf16 {
    #[inline]
    fn mul_assign(&mut self, rhs: &bf16) {
        *self = (*self).mul(rhs);
    }
}

impl Div for bf16 {
    type Output = Self;

    #[inline]
    fn div(self, rhs: Self) -> Self::Output {
        Self::from_f32(Self::to_f32(self) / Self::to_f32(rhs))
    }
}

impl Div<&bf16> for bf16 {
    type Output = <bf16 as Div<bf16>>::Output;

    #[inline]
    fn div(self, rhs: &bf16) -> Self::Output {
        self.div(*rhs)
    }
}

impl Div<&bf16> for &bf16 {
    type Output = <bf16 as Div<bf16>>::Output;

#[inline] div(self, rhs: &bf16) -> Self::Output1162 fn div(self, rhs: &bf16) -> Self::Output { 1163 (*self).div(*rhs) 1164 } 1165 } 1166 1167 impl Div<bf16> for &bf16 { 1168 type Output = <bf16 as Div<bf16>>::Output; 1169 1170 #[inline] div(self, rhs: bf16) -> Self::Output1171 fn div(self, rhs: bf16) -> Self::Output { 1172 (*self).div(rhs) 1173 } 1174 } 1175 1176 impl DivAssign for bf16 { 1177 #[inline] div_assign(&mut self, rhs: Self)1178 fn div_assign(&mut self, rhs: Self) { 1179 *self = (*self).div(rhs); 1180 } 1181 } 1182 1183 impl DivAssign<&bf16> for bf16 { 1184 #[inline] div_assign(&mut self, rhs: &bf16)1185 fn div_assign(&mut self, rhs: &bf16) { 1186 *self = (*self).div(rhs); 1187 } 1188 } 1189 1190 impl Rem for bf16 { 1191 type Output = Self; 1192 rem(self, rhs: Self) -> Self::Output1193 fn rem(self, rhs: Self) -> Self::Output { 1194 Self::from_f32(Self::to_f32(self) % Self::to_f32(rhs)) 1195 } 1196 } 1197 1198 impl Rem<&bf16> for bf16 { 1199 type Output = <bf16 as Rem<bf16>>::Output; 1200 1201 #[inline] rem(self, rhs: &bf16) -> Self::Output1202 fn rem(self, rhs: &bf16) -> Self::Output { 1203 self.rem(*rhs) 1204 } 1205 } 1206 1207 impl Rem<&bf16> for &bf16 { 1208 type Output = <bf16 as Rem<bf16>>::Output; 1209 1210 #[inline] rem(self, rhs: &bf16) -> Self::Output1211 fn rem(self, rhs: &bf16) -> Self::Output { 1212 (*self).rem(*rhs) 1213 } 1214 } 1215 1216 impl Rem<bf16> for &bf16 { 1217 type Output = <bf16 as Rem<bf16>>::Output; 1218 1219 #[inline] rem(self, rhs: bf16) -> Self::Output1220 fn rem(self, rhs: bf16) -> Self::Output { 1221 (*self).rem(rhs) 1222 } 1223 } 1224 1225 impl RemAssign for bf16 { 1226 #[inline] rem_assign(&mut self, rhs: Self)1227 fn rem_assign(&mut self, rhs: Self) { 1228 *self = (*self).rem(rhs); 1229 } 1230 } 1231 1232 impl RemAssign<&bf16> for bf16 { 1233 #[inline] rem_assign(&mut self, rhs: &bf16)1234 fn rem_assign(&mut self, rhs: &bf16) { 1235 *self = (*self).rem(rhs); 1236 } 1237 } 1238 1239 impl Product for bf16 { 1240 #[inline] 
product<I: Iterator<Item = Self>>(iter: I) -> Self1241 fn product<I: Iterator<Item = Self>>(iter: I) -> Self { 1242 bf16::from_f32(iter.map(|f| f.to_f32()).product()) 1243 } 1244 } 1245 1246 impl<'a> Product<&'a bf16> for bf16 { 1247 #[inline] product<I: Iterator<Item = &'a bf16>>(iter: I) -> Self1248 fn product<I: Iterator<Item = &'a bf16>>(iter: I) -> Self { 1249 bf16::from_f32(iter.map(|f| f.to_f32()).product()) 1250 } 1251 } 1252 1253 impl Sum for bf16 { 1254 #[inline] sum<I: Iterator<Item = Self>>(iter: I) -> Self1255 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self { 1256 bf16::from_f32(iter.map(|f| f.to_f32()).sum()) 1257 } 1258 } 1259 1260 impl<'a> Sum<&'a bf16> for bf16 { 1261 #[inline] sum<I: Iterator<Item = &'a bf16>>(iter: I) -> Self1262 fn sum<I: Iterator<Item = &'a bf16>>(iter: I) -> Self { 1263 bf16::from_f32(iter.map(|f| f.to_f32()).product()) 1264 } 1265 } 1266 1267 #[cfg(feature = "serde")] 1268 struct Visitor; 1269 1270 #[cfg(feature = "serde")] 1271 impl<'de> Deserialize<'de> for bf16 { deserialize<D>(deserializer: D) -> Result<bf16, D::Error> where D: serde::de::Deserializer<'de>,1272 fn deserialize<D>(deserializer: D) -> Result<bf16, D::Error> 1273 where 1274 D: serde::de::Deserializer<'de>, 1275 { 1276 deserializer.deserialize_newtype_struct("bf16", Visitor) 1277 } 1278 } 1279 1280 #[cfg(feature = "serde")] 1281 impl<'de> serde::de::Visitor<'de> for Visitor { 1282 type Value = bf16; 1283 expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result1284 fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result { 1285 write!(formatter, "tuple struct bf16") 1286 } 1287 visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error> where D: serde::Deserializer<'de>,1288 fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error> 1289 where 1290 D: serde::Deserializer<'de>, 1291 { 1292 Ok(bf16(<u16 as Deserialize>::deserialize(deserializer)?)) 1293 } 1294 
visit_str<E>(self, v: &str) -> Result<Self::Value, E> where E: serde::de::Error,1295 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> 1296 where 1297 E: serde::de::Error, 1298 { 1299 v.parse().map_err(|_| { 1300 serde::de::Error::invalid_value(serde::de::Unexpected::Str(v), &"a float string") 1301 }) 1302 } 1303 visit_f32<E>(self, v: f32) -> Result<Self::Value, E> where E: serde::de::Error,1304 fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E> 1305 where 1306 E: serde::de::Error, 1307 { 1308 Ok(bf16::from_f32(v)) 1309 } 1310 visit_f64<E>(self, v: f64) -> Result<Self::Value, E> where E: serde::de::Error,1311 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E> 1312 where 1313 E: serde::de::Error, 1314 { 1315 Ok(bf16::from_f64(v)) 1316 } 1317 } 1318 1319 #[allow( 1320 clippy::cognitive_complexity, 1321 clippy::float_cmp, 1322 clippy::neg_cmp_op_on_partial_ord 1323 )] 1324 #[cfg(test)] 1325 mod test { 1326 use super::*; 1327 use core::cmp::Ordering; 1328 #[cfg(feature = "num-traits")] 1329 use num_traits::{AsPrimitive, FromPrimitive, ToPrimitive}; 1330 use quickcheck_macros::quickcheck; 1331 1332 #[cfg(feature = "num-traits")] 1333 #[test] as_primitive()1334 fn as_primitive() { 1335 let two = bf16::from_f32(2.0); 1336 assert_eq!(<i32 as AsPrimitive<bf16>>::as_(2), two); 1337 assert_eq!(<bf16 as AsPrimitive<i32>>::as_(two), 2); 1338 1339 assert_eq!(<f32 as AsPrimitive<bf16>>::as_(2.0), two); 1340 assert_eq!(<bf16 as AsPrimitive<f32>>::as_(two), 2.0); 1341 1342 assert_eq!(<f64 as AsPrimitive<bf16>>::as_(2.0), two); 1343 assert_eq!(<bf16 as AsPrimitive<f64>>::as_(two), 2.0); 1344 } 1345 1346 #[cfg(feature = "num-traits")] 1347 #[test] to_primitive()1348 fn to_primitive() { 1349 let two = bf16::from_f32(2.0); 1350 assert_eq!(ToPrimitive::to_i32(&two).unwrap(), 2i32); 1351 assert_eq!(ToPrimitive::to_f32(&two).unwrap(), 2.0f32); 1352 assert_eq!(ToPrimitive::to_f64(&two).unwrap(), 2.0f64); 1353 } 1354 1355 #[cfg(feature = "num-traits")] 1356 #[test] 
from_primitive()1357 fn from_primitive() { 1358 let two = bf16::from_f32(2.0); 1359 assert_eq!(<bf16 as FromPrimitive>::from_i32(2).unwrap(), two); 1360 assert_eq!(<bf16 as FromPrimitive>::from_f32(2.0).unwrap(), two); 1361 assert_eq!(<bf16 as FromPrimitive>::from_f64(2.0).unwrap(), two); 1362 } 1363 1364 #[test] test_bf16_consts_from_f32()1365 fn test_bf16_consts_from_f32() { 1366 let one = bf16::from_f32(1.0); 1367 let zero = bf16::from_f32(0.0); 1368 let neg_zero = bf16::from_f32(-0.0); 1369 let neg_one = bf16::from_f32(-1.0); 1370 let inf = bf16::from_f32(core::f32::INFINITY); 1371 let neg_inf = bf16::from_f32(core::f32::NEG_INFINITY); 1372 let nan = bf16::from_f32(core::f32::NAN); 1373 1374 assert_eq!(bf16::ONE, one); 1375 assert_eq!(bf16::ZERO, zero); 1376 assert!(zero.is_sign_positive()); 1377 assert_eq!(bf16::NEG_ZERO, neg_zero); 1378 assert!(neg_zero.is_sign_negative()); 1379 assert_eq!(bf16::NEG_ONE, neg_one); 1380 assert!(neg_one.is_sign_negative()); 1381 assert_eq!(bf16::INFINITY, inf); 1382 assert_eq!(bf16::NEG_INFINITY, neg_inf); 1383 assert!(nan.is_nan()); 1384 assert!(bf16::NAN.is_nan()); 1385 1386 let e = bf16::from_f32(core::f32::consts::E); 1387 let pi = bf16::from_f32(core::f32::consts::PI); 1388 let frac_1_pi = bf16::from_f32(core::f32::consts::FRAC_1_PI); 1389 let frac_1_sqrt_2 = bf16::from_f32(core::f32::consts::FRAC_1_SQRT_2); 1390 let frac_2_pi = bf16::from_f32(core::f32::consts::FRAC_2_PI); 1391 let frac_2_sqrt_pi = bf16::from_f32(core::f32::consts::FRAC_2_SQRT_PI); 1392 let frac_pi_2 = bf16::from_f32(core::f32::consts::FRAC_PI_2); 1393 let frac_pi_3 = bf16::from_f32(core::f32::consts::FRAC_PI_3); 1394 let frac_pi_4 = bf16::from_f32(core::f32::consts::FRAC_PI_4); 1395 let frac_pi_6 = bf16::from_f32(core::f32::consts::FRAC_PI_6); 1396 let frac_pi_8 = bf16::from_f32(core::f32::consts::FRAC_PI_8); 1397 let ln_10 = bf16::from_f32(core::f32::consts::LN_10); 1398 let ln_2 = bf16::from_f32(core::f32::consts::LN_2); 1399 let log10_e = 
bf16::from_f32(core::f32::consts::LOG10_E); 1400 // core::f32::consts::LOG10_2 requires rustc 1.43.0 1401 let log10_2 = bf16::from_f32(2f32.log10()); 1402 let log2_e = bf16::from_f32(core::f32::consts::LOG2_E); 1403 // core::f32::consts::LOG2_10 requires rustc 1.43.0 1404 let log2_10 = bf16::from_f32(10f32.log2()); 1405 let sqrt_2 = bf16::from_f32(core::f32::consts::SQRT_2); 1406 1407 assert_eq!(bf16::E, e); 1408 assert_eq!(bf16::PI, pi); 1409 assert_eq!(bf16::FRAC_1_PI, frac_1_pi); 1410 assert_eq!(bf16::FRAC_1_SQRT_2, frac_1_sqrt_2); 1411 assert_eq!(bf16::FRAC_2_PI, frac_2_pi); 1412 assert_eq!(bf16::FRAC_2_SQRT_PI, frac_2_sqrt_pi); 1413 assert_eq!(bf16::FRAC_PI_2, frac_pi_2); 1414 assert_eq!(bf16::FRAC_PI_3, frac_pi_3); 1415 assert_eq!(bf16::FRAC_PI_4, frac_pi_4); 1416 assert_eq!(bf16::FRAC_PI_6, frac_pi_6); 1417 assert_eq!(bf16::FRAC_PI_8, frac_pi_8); 1418 assert_eq!(bf16::LN_10, ln_10); 1419 assert_eq!(bf16::LN_2, ln_2); 1420 assert_eq!(bf16::LOG10_E, log10_e); 1421 assert_eq!(bf16::LOG10_2, log10_2); 1422 assert_eq!(bf16::LOG2_E, log2_e); 1423 assert_eq!(bf16::LOG2_10, log2_10); 1424 assert_eq!(bf16::SQRT_2, sqrt_2); 1425 } 1426 1427 #[test] test_bf16_consts_from_f64()1428 fn test_bf16_consts_from_f64() { 1429 let one = bf16::from_f64(1.0); 1430 let zero = bf16::from_f64(0.0); 1431 let neg_zero = bf16::from_f64(-0.0); 1432 let inf = bf16::from_f64(core::f64::INFINITY); 1433 let neg_inf = bf16::from_f64(core::f64::NEG_INFINITY); 1434 let nan = bf16::from_f64(core::f64::NAN); 1435 1436 assert_eq!(bf16::ONE, one); 1437 assert_eq!(bf16::ZERO, zero); 1438 assert_eq!(bf16::NEG_ZERO, neg_zero); 1439 assert_eq!(bf16::INFINITY, inf); 1440 assert_eq!(bf16::NEG_INFINITY, neg_inf); 1441 assert!(nan.is_nan()); 1442 assert!(bf16::NAN.is_nan()); 1443 1444 let e = bf16::from_f64(core::f64::consts::E); 1445 let pi = bf16::from_f64(core::f64::consts::PI); 1446 let frac_1_pi = bf16::from_f64(core::f64::consts::FRAC_1_PI); 1447 let frac_1_sqrt_2 = 
bf16::from_f64(core::f64::consts::FRAC_1_SQRT_2); 1448 let frac_2_pi = bf16::from_f64(core::f64::consts::FRAC_2_PI); 1449 let frac_2_sqrt_pi = bf16::from_f64(core::f64::consts::FRAC_2_SQRT_PI); 1450 let frac_pi_2 = bf16::from_f64(core::f64::consts::FRAC_PI_2); 1451 let frac_pi_3 = bf16::from_f64(core::f64::consts::FRAC_PI_3); 1452 let frac_pi_4 = bf16::from_f64(core::f64::consts::FRAC_PI_4); 1453 let frac_pi_6 = bf16::from_f64(core::f64::consts::FRAC_PI_6); 1454 let frac_pi_8 = bf16::from_f64(core::f64::consts::FRAC_PI_8); 1455 let ln_10 = bf16::from_f64(core::f64::consts::LN_10); 1456 let ln_2 = bf16::from_f64(core::f64::consts::LN_2); 1457 let log10_e = bf16::from_f64(core::f64::consts::LOG10_E); 1458 // core::f64::consts::LOG10_2 requires rustc 1.43.0 1459 let log10_2 = bf16::from_f64(2f64.log10()); 1460 let log2_e = bf16::from_f64(core::f64::consts::LOG2_E); 1461 // core::f64::consts::LOG2_10 requires rustc 1.43.0 1462 let log2_10 = bf16::from_f64(10f64.log2()); 1463 let sqrt_2 = bf16::from_f64(core::f64::consts::SQRT_2); 1464 1465 assert_eq!(bf16::E, e); 1466 assert_eq!(bf16::PI, pi); 1467 assert_eq!(bf16::FRAC_1_PI, frac_1_pi); 1468 assert_eq!(bf16::FRAC_1_SQRT_2, frac_1_sqrt_2); 1469 assert_eq!(bf16::FRAC_2_PI, frac_2_pi); 1470 assert_eq!(bf16::FRAC_2_SQRT_PI, frac_2_sqrt_pi); 1471 assert_eq!(bf16::FRAC_PI_2, frac_pi_2); 1472 assert_eq!(bf16::FRAC_PI_3, frac_pi_3); 1473 assert_eq!(bf16::FRAC_PI_4, frac_pi_4); 1474 assert_eq!(bf16::FRAC_PI_6, frac_pi_6); 1475 assert_eq!(bf16::FRAC_PI_8, frac_pi_8); 1476 assert_eq!(bf16::LN_10, ln_10); 1477 assert_eq!(bf16::LN_2, ln_2); 1478 assert_eq!(bf16::LOG10_E, log10_e); 1479 assert_eq!(bf16::LOG10_2, log10_2); 1480 assert_eq!(bf16::LOG2_E, log2_e); 1481 assert_eq!(bf16::LOG2_10, log2_10); 1482 assert_eq!(bf16::SQRT_2, sqrt_2); 1483 } 1484 1485 #[test] test_nan_conversion_to_smaller()1486 fn test_nan_conversion_to_smaller() { 1487 let nan64 = f64::from_bits(0x7FF0_0000_0000_0001u64); 1488 let neg_nan64 = 
f64::from_bits(0xFFF0_0000_0000_0001u64); 1489 let nan32 = f32::from_bits(0x7F80_0001u32); 1490 let neg_nan32 = f32::from_bits(0xFF80_0001u32); 1491 let nan32_from_64 = nan64 as f32; 1492 let neg_nan32_from_64 = neg_nan64 as f32; 1493 let nan16_from_64 = bf16::from_f64(nan64); 1494 let neg_nan16_from_64 = bf16::from_f64(neg_nan64); 1495 let nan16_from_32 = bf16::from_f32(nan32); 1496 let neg_nan16_from_32 = bf16::from_f32(neg_nan32); 1497 1498 assert!(nan64.is_nan() && nan64.is_sign_positive()); 1499 assert!(neg_nan64.is_nan() && neg_nan64.is_sign_negative()); 1500 assert!(nan32.is_nan() && nan32.is_sign_positive()); 1501 assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative()); 1502 assert!(nan32_from_64.is_nan() && nan32_from_64.is_sign_positive()); 1503 assert!(neg_nan32_from_64.is_nan() && neg_nan32_from_64.is_sign_negative()); 1504 assert!(nan16_from_64.is_nan() && nan16_from_64.is_sign_positive()); 1505 assert!(neg_nan16_from_64.is_nan() && neg_nan16_from_64.is_sign_negative()); 1506 assert!(nan16_from_32.is_nan() && nan16_from_32.is_sign_positive()); 1507 assert!(neg_nan16_from_32.is_nan() && neg_nan16_from_32.is_sign_negative()); 1508 } 1509 1510 #[test] test_nan_conversion_to_larger()1511 fn test_nan_conversion_to_larger() { 1512 let nan16 = bf16::from_bits(0x7F81u16); 1513 let neg_nan16 = bf16::from_bits(0xFF81u16); 1514 let nan32 = f32::from_bits(0x7F80_0001u32); 1515 let neg_nan32 = f32::from_bits(0xFF80_0001u32); 1516 let nan32_from_16 = f32::from(nan16); 1517 let neg_nan32_from_16 = f32::from(neg_nan16); 1518 let nan64_from_16 = f64::from(nan16); 1519 let neg_nan64_from_16 = f64::from(neg_nan16); 1520 let nan64_from_32 = f64::from(nan32); 1521 let neg_nan64_from_32 = f64::from(neg_nan32); 1522 1523 assert!(nan16.is_nan() && nan16.is_sign_positive()); 1524 assert!(neg_nan16.is_nan() && neg_nan16.is_sign_negative()); 1525 assert!(nan32.is_nan() && nan32.is_sign_positive()); 1526 assert!(neg_nan32.is_nan() && neg_nan32.is_sign_negative()); 1527 
assert!(nan32_from_16.is_nan() && nan32_from_16.is_sign_positive()); 1528 assert!(neg_nan32_from_16.is_nan() && neg_nan32_from_16.is_sign_negative()); 1529 assert!(nan64_from_16.is_nan() && nan64_from_16.is_sign_positive()); 1530 assert!(neg_nan64_from_16.is_nan() && neg_nan64_from_16.is_sign_negative()); 1531 assert!(nan64_from_32.is_nan() && nan64_from_32.is_sign_positive()); 1532 assert!(neg_nan64_from_32.is_nan() && neg_nan64_from_32.is_sign_negative()); 1533 } 1534 1535 #[test] test_bf16_to_f32()1536 fn test_bf16_to_f32() { 1537 let f = bf16::from_f32(7.0); 1538 assert_eq!(f.to_f32(), 7.0f32); 1539 1540 // 7.1 is NOT exactly representable in 16-bit, it's rounded 1541 let f = bf16::from_f32(7.1); 1542 let diff = (f.to_f32() - 7.1f32).abs(); 1543 // diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1 1544 assert!(diff <= 4.0 * bf16::EPSILON.to_f32()); 1545 1546 let tiny32 = f32::from_bits(0x0001_0000u32); 1547 assert_eq!(bf16::from_bits(0x0001).to_f32(), tiny32); 1548 assert_eq!(bf16::from_bits(0x0005).to_f32(), 5.0 * tiny32); 1549 1550 assert_eq!(bf16::from_bits(0x0001), bf16::from_f32(tiny32)); 1551 assert_eq!(bf16::from_bits(0x0005), bf16::from_f32(5.0 * tiny32)); 1552 } 1553 1554 #[test] test_bf16_to_f64()1555 fn test_bf16_to_f64() { 1556 let f = bf16::from_f64(7.0); 1557 assert_eq!(f.to_f64(), 7.0f64); 1558 1559 // 7.1 is NOT exactly representable in 16-bit, it's rounded 1560 let f = bf16::from_f64(7.1); 1561 let diff = (f.to_f64() - 7.1f64).abs(); 1562 // diff must be <= 4 * EPSILON, as 7 has two more significant bits than 1 1563 assert!(diff <= 4.0 * bf16::EPSILON.to_f64()); 1564 1565 let tiny64 = 2.0f64.powi(-133); 1566 assert_eq!(bf16::from_bits(0x0001).to_f64(), tiny64); 1567 assert_eq!(bf16::from_bits(0x0005).to_f64(), 5.0 * tiny64); 1568 1569 assert_eq!(bf16::from_bits(0x0001), bf16::from_f64(tiny64)); 1570 assert_eq!(bf16::from_bits(0x0005), bf16::from_f64(5.0 * tiny64)); 1571 } 1572 1573 #[test] test_comparisons()1574 fn 
test_comparisons() { 1575 let zero = bf16::from_f64(0.0); 1576 let one = bf16::from_f64(1.0); 1577 let neg_zero = bf16::from_f64(-0.0); 1578 let neg_one = bf16::from_f64(-1.0); 1579 1580 assert_eq!(zero.partial_cmp(&neg_zero), Some(Ordering::Equal)); 1581 assert_eq!(neg_zero.partial_cmp(&zero), Some(Ordering::Equal)); 1582 assert!(zero == neg_zero); 1583 assert!(neg_zero == zero); 1584 assert!(!(zero != neg_zero)); 1585 assert!(!(neg_zero != zero)); 1586 assert!(!(zero < neg_zero)); 1587 assert!(!(neg_zero < zero)); 1588 assert!(zero <= neg_zero); 1589 assert!(neg_zero <= zero); 1590 assert!(!(zero > neg_zero)); 1591 assert!(!(neg_zero > zero)); 1592 assert!(zero >= neg_zero); 1593 assert!(neg_zero >= zero); 1594 1595 assert_eq!(one.partial_cmp(&neg_zero), Some(Ordering::Greater)); 1596 assert_eq!(neg_zero.partial_cmp(&one), Some(Ordering::Less)); 1597 assert!(!(one == neg_zero)); 1598 assert!(!(neg_zero == one)); 1599 assert!(one != neg_zero); 1600 assert!(neg_zero != one); 1601 assert!(!(one < neg_zero)); 1602 assert!(neg_zero < one); 1603 assert!(!(one <= neg_zero)); 1604 assert!(neg_zero <= one); 1605 assert!(one > neg_zero); 1606 assert!(!(neg_zero > one)); 1607 assert!(one >= neg_zero); 1608 assert!(!(neg_zero >= one)); 1609 1610 assert_eq!(one.partial_cmp(&neg_one), Some(Ordering::Greater)); 1611 assert_eq!(neg_one.partial_cmp(&one), Some(Ordering::Less)); 1612 assert!(!(one == neg_one)); 1613 assert!(!(neg_one == one)); 1614 assert!(one != neg_one); 1615 assert!(neg_one != one); 1616 assert!(!(one < neg_one)); 1617 assert!(neg_one < one); 1618 assert!(!(one <= neg_one)); 1619 assert!(neg_one <= one); 1620 assert!(one > neg_one); 1621 assert!(!(neg_one > one)); 1622 assert!(one >= neg_one); 1623 assert!(!(neg_one >= one)); 1624 } 1625 1626 #[test] 1627 #[allow(clippy::erasing_op, clippy::identity_op)] round_to_even_f32()1628 fn round_to_even_f32() { 1629 // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133 1630 let min_sub = bf16::from_bits(1); 
1631 let min_sub_f = (-133f32).exp2(); 1632 assert_eq!(bf16::from_f32(min_sub_f).to_bits(), min_sub.to_bits()); 1633 assert_eq!(f32::from(min_sub).to_bits(), min_sub_f.to_bits()); 1634 1635 // 0.0000000_011111 rounded to 0.0000000 (< tie, no rounding) 1636 // 0.0000000_100000 rounded to 0.0000000 (tie and even, remains at even) 1637 // 0.0000000_100001 rounded to 0.0000001 (> tie, rounds up) 1638 assert_eq!( 1639 bf16::from_f32(min_sub_f * 0.49).to_bits(), 1640 min_sub.to_bits() * 0 1641 ); 1642 assert_eq!( 1643 bf16::from_f32(min_sub_f * 0.50).to_bits(), 1644 min_sub.to_bits() * 0 1645 ); 1646 assert_eq!( 1647 bf16::from_f32(min_sub_f * 0.51).to_bits(), 1648 min_sub.to_bits() * 1 1649 ); 1650 1651 // 0.0000001_011111 rounded to 0.0000001 (< tie, no rounding) 1652 // 0.0000001_100000 rounded to 0.0000010 (tie and odd, rounds up to even) 1653 // 0.0000001_100001 rounded to 0.0000010 (> tie, rounds up) 1654 assert_eq!( 1655 bf16::from_f32(min_sub_f * 1.49).to_bits(), 1656 min_sub.to_bits() * 1 1657 ); 1658 assert_eq!( 1659 bf16::from_f32(min_sub_f * 1.50).to_bits(), 1660 min_sub.to_bits() * 2 1661 ); 1662 assert_eq!( 1663 bf16::from_f32(min_sub_f * 1.51).to_bits(), 1664 min_sub.to_bits() * 2 1665 ); 1666 1667 // 0.0000010_011111 rounded to 0.0000010 (< tie, no rounding) 1668 // 0.0000010_100000 rounded to 0.0000010 (tie and even, remains at even) 1669 // 0.0000010_100001 rounded to 0.0000011 (> tie, rounds up) 1670 assert_eq!( 1671 bf16::from_f32(min_sub_f * 2.49).to_bits(), 1672 min_sub.to_bits() * 2 1673 ); 1674 assert_eq!( 1675 bf16::from_f32(min_sub_f * 2.50).to_bits(), 1676 min_sub.to_bits() * 2 1677 ); 1678 assert_eq!( 1679 bf16::from_f32(min_sub_f * 2.51).to_bits(), 1680 min_sub.to_bits() * 3 1681 ); 1682 1683 assert_eq!( 1684 bf16::from_f32(250.49f32).to_bits(), 1685 bf16::from_f32(250.0).to_bits() 1686 ); 1687 assert_eq!( 1688 bf16::from_f32(250.50f32).to_bits(), 1689 bf16::from_f32(250.0).to_bits() 1690 ); 1691 assert_eq!( 1692 
bf16::from_f32(250.51f32).to_bits(), 1693 bf16::from_f32(251.0).to_bits() 1694 ); 1695 assert_eq!( 1696 bf16::from_f32(251.49f32).to_bits(), 1697 bf16::from_f32(251.0).to_bits() 1698 ); 1699 assert_eq!( 1700 bf16::from_f32(251.50f32).to_bits(), 1701 bf16::from_f32(252.0).to_bits() 1702 ); 1703 assert_eq!( 1704 bf16::from_f32(251.51f32).to_bits(), 1705 bf16::from_f32(252.0).to_bits() 1706 ); 1707 assert_eq!( 1708 bf16::from_f32(252.49f32).to_bits(), 1709 bf16::from_f32(252.0).to_bits() 1710 ); 1711 assert_eq!( 1712 bf16::from_f32(252.50f32).to_bits(), 1713 bf16::from_f32(252.0).to_bits() 1714 ); 1715 assert_eq!( 1716 bf16::from_f32(252.51f32).to_bits(), 1717 bf16::from_f32(253.0).to_bits() 1718 ); 1719 } 1720 1721 #[test] 1722 #[allow(clippy::erasing_op, clippy::identity_op)] round_to_even_f64()1723 fn round_to_even_f64() { 1724 // smallest positive subnormal = 0b0.0000_001 * 2^-126 = 2^-133 1725 let min_sub = bf16::from_bits(1); 1726 let min_sub_f = (-133f64).exp2(); 1727 assert_eq!(bf16::from_f64(min_sub_f).to_bits(), min_sub.to_bits()); 1728 assert_eq!(f64::from(min_sub).to_bits(), min_sub_f.to_bits()); 1729 1730 // 0.0000000_011111 rounded to 0.0000000 (< tie, no rounding) 1731 // 0.0000000_100000 rounded to 0.0000000 (tie and even, remains at even) 1732 // 0.0000000_100001 rounded to 0.0000001 (> tie, rounds up) 1733 assert_eq!( 1734 bf16::from_f64(min_sub_f * 0.49).to_bits(), 1735 min_sub.to_bits() * 0 1736 ); 1737 assert_eq!( 1738 bf16::from_f64(min_sub_f * 0.50).to_bits(), 1739 min_sub.to_bits() * 0 1740 ); 1741 assert_eq!( 1742 bf16::from_f64(min_sub_f * 0.51).to_bits(), 1743 min_sub.to_bits() * 1 1744 ); 1745 1746 // 0.0000001_011111 rounded to 0.0000001 (< tie, no rounding) 1747 // 0.0000001_100000 rounded to 0.0000010 (tie and odd, rounds up to even) 1748 // 0.0000001_100001 rounded to 0.0000010 (> tie, rounds up) 1749 assert_eq!( 1750 bf16::from_f64(min_sub_f * 1.49).to_bits(), 1751 min_sub.to_bits() * 1 1752 ); 1753 assert_eq!( 1754 
bf16::from_f64(min_sub_f * 1.50).to_bits(), 1755 min_sub.to_bits() * 2 1756 ); 1757 assert_eq!( 1758 bf16::from_f64(min_sub_f * 1.51).to_bits(), 1759 min_sub.to_bits() * 2 1760 ); 1761 1762 // 0.0000010_011111 rounded to 0.0000010 (< tie, no rounding) 1763 // 0.0000010_100000 rounded to 0.0000010 (tie and even, remains at even) 1764 // 0.0000010_100001 rounded to 0.0000011 (> tie, rounds up) 1765 assert_eq!( 1766 bf16::from_f64(min_sub_f * 2.49).to_bits(), 1767 min_sub.to_bits() * 2 1768 ); 1769 assert_eq!( 1770 bf16::from_f64(min_sub_f * 2.50).to_bits(), 1771 min_sub.to_bits() * 2 1772 ); 1773 assert_eq!( 1774 bf16::from_f64(min_sub_f * 2.51).to_bits(), 1775 min_sub.to_bits() * 3 1776 ); 1777 1778 assert_eq!( 1779 bf16::from_f64(250.49f64).to_bits(), 1780 bf16::from_f64(250.0).to_bits() 1781 ); 1782 assert_eq!( 1783 bf16::from_f64(250.50f64).to_bits(), 1784 bf16::from_f64(250.0).to_bits() 1785 ); 1786 assert_eq!( 1787 bf16::from_f64(250.51f64).to_bits(), 1788 bf16::from_f64(251.0).to_bits() 1789 ); 1790 assert_eq!( 1791 bf16::from_f64(251.49f64).to_bits(), 1792 bf16::from_f64(251.0).to_bits() 1793 ); 1794 assert_eq!( 1795 bf16::from_f64(251.50f64).to_bits(), 1796 bf16::from_f64(252.0).to_bits() 1797 ); 1798 assert_eq!( 1799 bf16::from_f64(251.51f64).to_bits(), 1800 bf16::from_f64(252.0).to_bits() 1801 ); 1802 assert_eq!( 1803 bf16::from_f64(252.49f64).to_bits(), 1804 bf16::from_f64(252.0).to_bits() 1805 ); 1806 assert_eq!( 1807 bf16::from_f64(252.50f64).to_bits(), 1808 bf16::from_f64(252.0).to_bits() 1809 ); 1810 assert_eq!( 1811 bf16::from_f64(252.51f64).to_bits(), 1812 bf16::from_f64(253.0).to_bits() 1813 ); 1814 } 1815 1816 impl quickcheck::Arbitrary for bf16 { arbitrary(g: &mut quickcheck::Gen) -> Self1817 fn arbitrary(g: &mut quickcheck::Gen) -> Self { 1818 bf16(u16::arbitrary(g)) 1819 } 1820 } 1821 1822 #[quickcheck] qc_roundtrip_bf16_f32_is_identity(f: bf16) -> bool1823 fn qc_roundtrip_bf16_f32_is_identity(f: bf16) -> bool { 1824 let roundtrip = 
bf16::from_f32(f.to_f32()); 1825 if f.is_nan() { 1826 roundtrip.is_nan() && f.is_sign_negative() == roundtrip.is_sign_negative() 1827 } else { 1828 f.0 == roundtrip.0 1829 } 1830 } 1831 1832 #[quickcheck] qc_roundtrip_bf16_f64_is_identity(f: bf16) -> bool1833 fn qc_roundtrip_bf16_f64_is_identity(f: bf16) -> bool { 1834 let roundtrip = bf16::from_f64(f.to_f64()); 1835 if f.is_nan() { 1836 roundtrip.is_nan() && f.is_sign_negative() == roundtrip.is_sign_negative() 1837 } else { 1838 f.0 == roundtrip.0 1839 } 1840 } 1841 } 1842