xref: /aosp_15_r20/bootable/libbootloader/gbl/libsafemath/src/lib.rs (revision 5225e6b173e52d2efc6bcf950c27374fd72adabc)
1 // Copyright (C) 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! # safemath library
16 //!
17 //! This library provides an API to safely work with unsigned integers. At a high level, all math
18 //! operations are checked by default rather than having to remember to call specific `checked_*`
19 //! functions, so that the burden is on the programmer if they want to perform unchecked math
20 //! rather than the other way around:
21 //!
22 //! ```
23 //! use safemath::SafeNum;
24 //!
25 //! let safe = SafeNum::from(0);
26 //! let result = safe - 1;
27 //! assert!(u32::try_from(result).is_err());
28 //!
29 //! let safe_chain = (SafeNum::from(BIG_NUMBER) * HUGE_NUMBER) / MAYBE_ZERO;
30 //! // If any operation would have caused an overflow or division by zero,
31 //! // the number is flagged and the lexical location is specified for logging.
32 //! if safe_chain.has_error() {
33 //!     eprintln!("safe_chain error = {:#?}", safe_chain);
34 //! }
35 //! ```
36 //!
//! In addition to checked-by-default arithmetic, the API exposed here supports
//! more natural usage than the `checked_*` functions by allowing chaining
39 //! of operations without having to check the result at each step.
40 //! This is similar to how floating-point `NaN` works - you can continue to use the
41 //! value, but continued operations will just propagate `NaN`.
42 //!
43 //! ## Supported Operations
44 //!
45 //! ### Arithmetic
46 //! The basic arithmetic operations are supported:
47 //! addition, subtraction, multiplication, division, and remainder.
48 //! The right hand side may be another SafeNum or any integer,
49 //! and the result is always another SafeNum.
50 //! If the operation would result in an overflow or division by zero,
//! or if converting the right-hand operand to a `u64` would cause an error,
52 //! the result is an error-tagged SafeNum that tracks the lexical origin of the error.
53 //!
54 //! ### Conversion from and to SafeNum
55 //! SafeNums support conversion to and from all integer types.
56 //! Conversion to SafeNum from signed integers and from usize and u128
57 //! can fail, generating an error value that is then propagated.
58 //! Conversion from SafeNum to all integers is only exposed via `try_from`
59 //! in order to force the user to handle potential resultant errors.
60 //!
61 //! E.g.
62 //! ```
63 //! fn call_func(_: u32, _: u32) {
64 //! }
65 //!
66 //! fn do_a_thing(a: SafeNum) -> Result<(), safemath::Error> {
67 //!     call_func(16, a.try_into()?);
68 //!     Ok(())
69 //! }
70 //! ```
71 //!
72 //! ### Comparison
73 //! SafeNums can be checked for equality against each other.
74 //! Valid numbers are equal to other numbers of the same magnitude.
75 //! Errored SafeNums are only equal to themselves.
76 //! Note that because errors propagate from their first introduction in an
77 //! arithmetic chain this can lead to surprising results.
78 //!
79 //! E.g.
80 //! ```
81 //! let overflow = SafeNum::MAX + 1;
82 //! let otherflow = SafeNum::MAX + 1;
83 //!
84 //! assert_ne!(overflow, otherflow);
85 //! assert_eq!(overflow + otherflow, overflow);
86 //! assert_eq!(otherflow + overflow, otherflow);
87 //! ```
88 //!
89 //! Inequality comparison operators are deliberately not provided.
90 //! By necessity they would have similar caveats to floating point comparisons,
91 //! which are easy to use incorrectly and unintuitive to use correctly.
92 //!
93 //! The required alternative is to convert to a real integer type before comparing,
94 //! forcing any errors upwards.
95 //!
96 //! E.g.
97 //! ```
98 //! impl From<safemath::Error> for &'static str {
99 //!     fn from(_: safemath::Error) -> Self {
100 //!         "checked arithmetic error"
101 //!     }
102 //! }
103 //!
104 //! fn my_op(a: SafeNum, b: SafeNum, c: SafeNum, d: SafeNum) -> Result<bool, &'static str> {
105 //!     Ok(safemath::Primitive::try_from(a)? < b.try_into()?
106 //!        && safemath::Primitive::try_from(c)? >= d.try_into()?)
107 //! }
108 //! ```
109 //!
110 //! ### Miscellaneous
111 //! SafeNums also provide helper methods to round up or down
112 //! to the nearest multiple of another number
113 //! and helper predicate methods that indicate whether the SafeNum
114 //! is valid or is tracking an error.
115 //!
116 //! Also provided are constants `SafeNum::MAX`, `SafeNum::MIN`, and `SafeNum::ZERO`.
117 //!
118 //! Warning: SafeNums can help prevent, isolate, and detect arithmetic overflow
119 //!          but they are not a panacea. In particular, chains of different operations
120 //!          are not guaranteed to be associative or commutative.
121 //!
122 //! E.g.
123 //! ```
124 //! let a = SafeNum::MAX - 1 + 1;
125 //! let b = SafeNum::MAX + 1 - 1;
126 //! assert_ne!(a, b);
127 //! assert!(a.is_valid());
128 //! assert!(b.has_error());
129 //!
130 //! let c = (SafeNum::MAX + 31) / 31;
131 //! let d = SafeNum::MAX / 31 + 31 / 31;
132 //! assert_ne!(c, d);
133 //! assert!(c.has_error());
134 //! assert!(d.is_valid());
135 //! ```
136 //!
137 //! Note:    SafeNum arithmetic is much slower than arithmetic on integer primitives.
138 //!          If you are concerned about performance, be sure to run benchmarks.
139 
140 #![cfg_attr(not(test), no_std)]
141 
142 use core::convert::TryFrom;
143 use core::fmt;
144 use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign};
145 use core::panic::Location;
146 
/// Unsigned integer type on which every [SafeNum] operation is ultimately performed.
pub type Primitive = u64;
/// Error produced by a failed [SafeNum] operation or conversion.
///
/// It wraps the source location that was recorded when the failure was first detected.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Error(&'static Location<'static>);
152 
153 impl From<&'static Location<'static>> for Error {
from(loc: &'static Location<'static>) -> Self154     fn from(loc: &'static Location<'static>) -> Self {
155         Self(loc)
156     }
157 }
158 
159 impl From<Error> for &'static Location<'static> {
from(err: Error) -> Self160     fn from(err: Error) -> Self {
161         err.0
162     }
163 }
164 
165 impl From<core::num::TryFromIntError> for Error {
166     #[track_caller]
from(_err: core::num::TryFromIntError) -> Self167     fn from(_err: core::num::TryFromIntError) -> Self {
168         Self(Location::caller())
169     }
170 }
171 
172 /// Wraps a raw [Primitive] type for safe-by-default math. See module docs for info and usage.
173 #[derive(Copy, Clone, PartialEq, Eq)]
174 pub struct SafeNum(Result<Primitive, Error>);
175 
176 impl fmt::Debug for SafeNum {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result177     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178         match self.0 {
179             Ok(val) => write!(f, "{}", val),
180             Err(location) => write!(f, "error at {}", location),
181         }
182     }
183 }
184 
185 impl fmt::Display for Error {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result186     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187         self.0.fmt(f)
188     }
189 }
190 
191 impl SafeNum {
192     /// The maximum [SafeNum].
193     pub const MAX: SafeNum = SafeNum(Ok(u64::MAX));
194     /// The minimum [SafeNum].
195     pub const MIN: SafeNum = SafeNum(Ok(u64::MIN));
196     /// Zero as a [SafeNum].
197     pub const ZERO: SafeNum = SafeNum(Ok(0));
198 
199     /// Round `self` down to the nearest multiple of `rhs`.
200     #[track_caller]
round_down<T>(self, rhs: T) -> Self where Self: Rem<T, Output = Self>,201     pub fn round_down<T>(self, rhs: T) -> Self
202     where
203         Self: Rem<T, Output = Self>,
204     {
205         self - (self % rhs)
206     }
207 
208     /// Round `self` up to the nearest multiple of `rhs`.
209     #[track_caller]
round_up<T>(self, rhs: T) -> Self where Self: Add<T, Output = Self>, T: Copy + Into<Self>,210     pub fn round_up<T>(self, rhs: T) -> Self
211     where
212         Self: Add<T, Output = Self>,
213         T: Copy + Into<Self>,
214     {
215         ((self + rhs) - 1).round_down(rhs)
216     }
217 
218     /// Returns whether self is the result of an operation that has errored.
has_error(&self) -> bool219     pub const fn has_error(&self) -> bool {
220         self.0.is_err()
221     }
222 
223     /// Returns whether self represents a valid, non-overflowed integer.
is_valid(&self) -> bool224     pub const fn is_valid(&self) -> bool {
225         self.0.is_ok()
226     }
227 }
228 
/// Implements `TryFrom<SafeNum>` for the given integer type.
///
/// Conversion returns the stored [Error] unchanged if the [SafeNum] already holds one. It returns
/// a new [Error] tagged with the caller's location if the value does not fit the target type.
macro_rules! try_conversion_func {
    ($other_type:tt) => {
        impl TryFrom<SafeNum> for $other_type {
            type Error = Error;

            #[track_caller]
            fn try_from(val: SafeNum) -> Result<Self, Self::Error> {
                // Capture the location here, not inside the `map_err` closure: closures do not
                // inherit `#[track_caller]`, so the closure would record a location in this crate
                // instead of the caller's.
                let location = Location::caller();
                Self::try_from(val.0?).map_err(|_| Error::from(location))
            }
        }
    };
}
241 
/// Implements infallible `From<T> for SafeNum` plus `TryFrom<SafeNum> for T`.
///
/// Only usable for types with a lossless `From` conversion into [Primitive].
macro_rules! conversion_func {
    ($from_type:tt) => {
        impl From<$from_type> for SafeNum {
            fn from(val: $from_type) -> SafeNum {
                SafeNum(Ok(Primitive::from(val)))
            }
        }

        // The narrowing direction can still fail, so it is exposed only through `TryFrom`.
        try_conversion_func!($from_type);
    };
}
253 
/// Implements `From<T> for SafeNum` for types that may not fit in a [Primitive], plus
/// `TryFrom<SafeNum> for T`.
///
/// A value that does not fit produces an error-tagged [SafeNum] carrying the caller's location.
macro_rules! conversion_func_maybe_error {
    ($from_type:tt) => {
        impl From<$from_type> for SafeNum {
            #[track_caller]
            fn from(val: $from_type) -> Self {
                // Capture the location in this `#[track_caller]` body; a closure would not
                // inherit it and would record a location inside this crate.
                let location = Location::caller();
                Self(Primitive::try_from(val).map_err(|_| Error::from(location)))
            }
        }

        try_conversion_func!($from_type);
    };
}
266 
/// Implements a checked binary operator trait and its compound-assignment companion for
/// [SafeNum], delegating to the given `checked_*` method on [Primitive].
macro_rules! arithmetic_impl {
    ($trait_name:ident, $op:ident, $assign_trait_name:ident, $assign_op:ident, $func:ident) => {
        impl<T: Into<SafeNum>> $trait_name<T> for SafeNum {
            type Output = Self;

            /// Applies the checked operation and propagates the first error encountered.
            ///
            /// An error already held by `self` wins over one held by `rhs`. A new failure
            /// (overflow, division by zero) is tagged with the caller's location.
            #[track_caller]
            fn $op(self, rhs: T) -> Self {
                let rhs: Self = rhs.into();

                match (self.0, rhs.0) {
                    (Err(_), _) => self,
                    (_, Err(_)) => rhs,
                    // `Location::caller()` must run directly in this `#[track_caller]` body.
                    // Inside a closure (e.g. `ok_or_else(|| ...)`) it would report a location
                    // within this crate rather than the caller's operator expression.
                    (Ok(lhs), Ok(rhs)) => match lhs.$func(rhs) {
                        Some(result) => Self(Ok(result)),
                        None => Self(Err(Error::from(Location::caller()))),
                    },
                }
            }
        }

        impl<T> $assign_trait_name<T> for SafeNum
        where
            Self: $trait_name<T, Output = Self>,
        {
            #[track_caller]
            fn $assign_op(&mut self, rhs: T) {
                *self = self.$op(rhs)
            }
        }
    };
}
296 
297 conversion_func!(u8);
298 conversion_func!(u16);
299 conversion_func!(u32);
300 conversion_func!(u64);
301 conversion_func_maybe_error!(usize);
302 conversion_func_maybe_error!(u128);
303 conversion_func_maybe_error!(i8);
304 conversion_func_maybe_error!(i16);
305 conversion_func_maybe_error!(i32);
306 conversion_func_maybe_error!(i64);
307 conversion_func_maybe_error!(i128);
308 conversion_func_maybe_error!(isize);
309 arithmetic_impl!(Add, add, AddAssign, add_assign, checked_add);
310 arithmetic_impl!(Sub, sub, SubAssign, sub_assign, checked_sub);
311 arithmetic_impl!(Mul, mul, MulAssign, mul_assign, checked_mul);
312 arithmetic_impl!(Div, div, DivAssign, div_assign, checked_div);
313 arithmetic_impl!(Rem, rem, RemAssign, rem_assign, checked_rem);
314 
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that `poison` holds an error and that every value derived from it is exactly
    /// the same error, i.e. the first failure is the one that propagates.
    fn assert_poison_propagates(poison: SafeNum, followups: [SafeNum; 2]) {
        assert!(u64::try_from(poison).is_err());
        for derived in followups {
            assert_eq!(derived, poison);
        }
    }

    #[test]
    fn test_addition() {
        let sum = SafeNum::from(2100) + SafeNum::from(12);
        assert_eq!(sum, SafeNum::from(2112));
    }

    #[test]
    fn test_subtraction() {
        let difference = SafeNum::from(667) - SafeNum::from(1);
        assert_eq!(difference, SafeNum::from(666));
    }

    #[test]
    fn test_multiplication() {
        let product = SafeNum::from(17) * SafeNum::from(3);
        assert_eq!(product, SafeNum::from(51));
    }

    #[test]
    fn test_division() {
        let quotient = SafeNum::from(1066) / SafeNum::from(41);
        assert_eq!(quotient, SafeNum::from(26));
    }

    #[test]
    fn test_remainder() {
        let remainder = SafeNum::from(613) % SafeNum::from(10);
        assert_eq!(remainder, SafeNum::from(3));
    }

    #[test]
    fn test_addition_poison() {
        let poison = SafeNum::from(2) + SafeNum::MAX;
        assert_poison_propagates(poison, [poison - 1, poison - 2]);
    }

    #[test]
    fn test_subtraction_poison() {
        let poison = SafeNum::from(2) - SafeNum::MAX;
        assert_poison_propagates(poison, [poison + 1, poison + 2]);
    }

    #[test]
    fn test_multiplication_poison() {
        let poison = SafeNum::from(2) * SafeNum::MAX;
        assert_poison_propagates(poison, [poison / 2, poison / 4]);
    }

    #[test]
    fn test_division_poison() {
        let poison = SafeNum::from(2) / 0;
        assert_poison_propagates(poison, [poison * 2, poison * 4]);
    }

    #[test]
    fn test_remainder_poison() {
        let poison = SafeNum::from(2) % 0;
        assert_poison_propagates(poison, [poison * 2, poison * 4]);
    }

    /// Generates round-trip and arithmetic tests for converting `$name` to and from SafeNum.
    macro_rules! conversion_test {
        ($name:ident) => {
            mod $name {
                use super::*;
                use core::convert::TryInto;

                #[test]
                fn test_between_safenum() {
                    let original: $name = 16;
                    let round_trip: $name = SafeNum::from(original).try_into().unwrap();
                    assert_eq!(round_trip, original);
                }

                #[test]
                fn test_arithmetic_safenum() {
                    let expected: $name = ((((0 + 11) * 11) / 3) % 32) - 3;
                    let operand = |v: u8| $name::try_from(v).unwrap();
                    let safe = ((((SafeNum::ZERO + operand(11)) * operand(11)) / operand(3))
                        % operand(32))
                        - operand(3);
                    assert_eq!($name::try_from(safe).unwrap(), expected);
                }
            }
        };
    }

    conversion_test!(u8);
    conversion_test!(u16);
    conversion_test!(u32);
    conversion_test!(u64);
    conversion_test!(u128);
    conversion_test!(usize);
    conversion_test!(i8);
    conversion_test!(i16);
    conversion_test!(i32);
    conversion_test!(i64);
    conversion_test!(i128);
    conversion_test!(isize);

    /// Generates tests checking that a SafeNum operator matches the primitive operator and
    /// that its compound-assignment form matches the binary form.
    macro_rules! correctness_tests {
        ($name:ident, $operation:ident, $assign_operation:ident) => {
            mod $operation {
                use super::*;
                use core::ops::$name;

                #[test]
                fn test_correctness() {
                    let (plain, rhs) = (300u64, 7u64);
                    let checked = SafeNum::from(plain).$operation(rhs);
                    assert_eq!(u64::try_from(checked).unwrap(), plain.$operation(rhs));
                }

                #[test]
                fn test_assign() {
                    let rhs = 666u64;
                    let mut accumulated = SafeNum::from(2112);
                    let via_binary_op = accumulated.$operation(rhs);
                    accumulated.$assign_operation(rhs);
                    assert_eq!(accumulated, via_binary_op);
                }

                #[test]
                fn test_assign_poison() {
                    let mut poisoned = SafeNum::MIN - 1;
                    let snapshot = poisoned - 1;
                    poisoned.$assign_operation(2);
                    // An errored value is sticky: further operations leave it unchanged.
                    assert_eq!(poisoned, snapshot);
                }
            }
        };
    }

    correctness_tests!(Add, add, add_assign);
    correctness_tests!(Sub, sub, sub_assign);
    correctness_tests!(Mul, mul, mul_assign);
    correctness_tests!(Div, div, div_assign);
    correctness_tests!(Rem, rem, rem_assign);

    #[test]
    fn test_round_down() {
        let value = SafeNum::from(255);
        let cases = [(value, 32, 224), (value + 1, 64, 256), (value, 256, 0)];
        for (input, multiple, expected) in cases {
            assert_eq!(input.round_down(multiple), SafeNum::from(expected));
        }
        assert!(value.round_down(SafeNum::MIN).has_error());
    }

    #[test]
    fn test_round_up() {
        let value = SafeNum::from(255);
        assert_eq!(value.round_up(32), SafeNum::from(256));
        // 255 is already a multiple of 51, so it is returned unchanged.
        assert_eq!(value.round_up(51), value);
        assert_eq!(SafeNum::ZERO.round_up(value), SafeNum::ZERO);
        assert!(SafeNum::MAX.round_up(32).has_error());
    }
}
522