/// Multiply-then-add in one operation: evaluates `(self * a) + b`.
///
/// For the `f32` and `f64` implementations, which delegate to the crate's
/// `Float::mul_add`, this is meant to be a *fused* multiply-add. The
/// intermediate product is not rounded separately, so the expression incurs
/// only one rounding error. It is therefore generally more accurate than
/// computing `self * a` and then adding `b`. On targets with a dedicated `fma`
/// CPU instruction it can also be faster than the unfused sequence.
///
/// The integer implementations evaluate `(self * a) + b` with the ordinary
/// `*` and `+` operators. They are exact, and overflow is handled the same way
/// as for those operators.
///
/// The operand types `A` and `B` default to `Self`. Implementors may pick
/// other types.
///
/// # Example
///
/// ```
/// use std::f32;
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // 100.0
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
///
/// assert!(abs_difference <= 100.0 * f32::EPSILON);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// Type of the value produced by `mul_add`.
    type Output;

    /// Returns `(self * a) + b`. The operation is fused where the
    /// implementation supports it.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
30 
/// In-place fused multiply-add: replaces `*self` with `(*self * a) + b`.
///
/// This is the assigning counterpart of `MulAdd`. The operand types `A` and
/// `B` default to `Self`, but implementors may choose other types.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Overwrites `*self` with the value of `(*self * a) + b`.
    fn mul_add_assign(&mut self, a: A, b: B);
}
36 
// `MulAdd` for `f32`. It is compiled only when the `std` or `libm` feature is
// enabled, and it forwards to the crate's `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f32, f32> for f32 {
    type Output = Self;

    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self::Output {
        crate::Float::mul_add(self, a, b)
    }
}
46 
// `MulAdd` for `f64`. It is compiled only when the `std` or `libm` feature is
// enabled, and it forwards to the crate's `Float::mul_add`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f64, f64> for f64 {
    type Output = Self;

    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self::Output {
        crate::Float::mul_add(self, a, b)
    }
}
56 
57 macro_rules! mul_add_impl {
58     ($trait_name:ident for $($t:ty)*) => {$(
59         impl $trait_name for $t {
60             type Output = Self;
61 
62             #[inline]
63             fn mul_add(self, a: Self, b: Self) -> Self::Output {
64                 (self * a) + b
65             }
66         }
67     )*}
68 }
69 
70 mul_add_impl!(MulAdd for isize i8 i16 i32 i64 i128);
71 mul_add_impl!(MulAdd for usize u8 u16 u32 u64 u128);
72 
// `MulAddAssign` for `f32`. It is compiled only when the `std` or `libm`
// feature is enabled. It computes the result with the crate's
// `Float::mul_add` and stores it back into `*self`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f32, f32> for f32 {
    #[inline]
    fn mul_add_assign(&mut self, a: Self, b: Self) {
        let current: f32 = *self;
        *self = crate::Float::mul_add(current, a, b);
    }
}
80 
// `MulAddAssign` for `f64`. It is compiled only when the `std` or `libm`
// feature is enabled. It computes the result with the crate's
// `Float::mul_add` and stores it back into `*self`.
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f64, f64> for f64 {
    #[inline]
    fn mul_add_assign(&mut self, a: Self, b: Self) {
        let current: f64 = *self;
        *self = crate::Float::mul_add(current, a, b);
    }
}
88 
89 macro_rules! mul_add_assign_impl {
90     ($trait_name:ident for $($t:ty)*) => {$(
91         impl $trait_name for $t {
92             #[inline]
93             fn mul_add_assign(&mut self, a: Self, b: Self) {
94                 *self = (*self * a) + b
95             }
96         }
97     )*}
98 }
99 
100 mul_add_assign_impl!(MulAddAssign for isize i8 i16 i32 i64 i128);
101 mul_add_assign_impl!(MulAddAssign for usize u8 u16 u32 u64 u128);
102 
#[cfg(test)]
mod tests {
    use super::*;

    /// Every integer `MulAdd` impl (including the 128-bit ones) must agree
    /// with the plain `(m * x) + b` expression.
    #[test]
    fn mul_add_integer() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        let m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;

                        assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
                    }
                )+
            };
        }

        test_mul_add!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// Every integer `MulAddAssign` impl must overwrite its receiver with
    /// `(m * x) + b`.
    #[test]
    fn mul_add_assign_integer() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        let mut m: $t = 2;
                        let x: $t = 3;
                        let b: $t = 4;
                        // Computed before `m` is mutated.
                        let expected: $t = m * x + b;

                        MulAddAssign::mul_add_assign(&mut m, x, b);

                        assert_eq!(m, expected);
                    }
                )+
            };
        }

        test_mul_add_assign!(usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128);
    }

    /// The float `MulAdd` impls must stay within a few ULPs of the unfused
    /// expression. The result is about 46.4, hence the `46.4 * EPSILON` bound.
    /// Gated on `std` because `.abs()` on floats needs it.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_float() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add!(f32 f64);
    }

    /// The float `MulAddAssign` impls must store a value within the same bound
    /// as `mul_add_float`. Gated on `std` for `.abs()`.
    #[test]
    #[cfg(feature = "std")]
    fn mul_add_assign_float() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let mut m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;
                        // Unfused reference value, computed before `m` is mutated.
                        let unfused = m*x + b;

                        MulAddAssign::mul_add_assign(&mut m, x, b);

                        let abs_difference = (m - unfused).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add_assign!(f32 f64);
    }
}
150