// Copyright (C) 2024 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! # safemath library
//!
//! This library provides an API to safely work with unsigned integers. At a high level, all math
//! operations are checked by default rather than having to remember to call specific `checked_*`
//! functions, so that the burden is on the programmer if they want to perform unchecked math
//! rather than the other way around:
//!
//! ```
//! use safemath::SafeNum;
//!
//! let safe = SafeNum::from(0);
//! let result = safe - 1;
//! assert!(u32::try_from(result).is_err());
//!
//! let safe_chain = (SafeNum::from(BIG_NUMBER) * HUGE_NUMBER) / MAYBE_ZERO;
//! // If any operation would have caused an overflow or division by zero,
//! // the number is flagged and the lexical location is specified for logging.
//! if safe_chain.has_error() {
//!     eprintln!("safe_chain error = {:#?}", safe_chain);
//! }
//! ```
//!
//! In addition to checked-by-default arithmetic, the API exposed here supports
//! more natural usage than the `checked_*` functions by allowing chaining
//! of operations without having to check the result at each step.
//! This is similar to how floating-point `NaN` works - you can continue to use the
//! value, but continued operations will just propagate `NaN`.
//!
//! ## Supported Operations
//!
//! ### Arithmetic
//! The basic arithmetic operations are supported:
//! addition, subtraction, multiplication, division, and remainder.
//! The right hand side may be another SafeNum or any integer,
//! and the result is always another SafeNum.
//! If the operation would result in an overflow or division by zero,
//! or if converting the right hand element to a `u64` would cause an error,
//! the result is an error-tagged SafeNum that tracks the lexical origin of the error.
//!
//! ### Conversion from and to SafeNum
//! SafeNums support conversion to and from all integer types.
//! Conversion to SafeNum from signed integers and from usize and u128
//! can fail, generating an error value that is then propagated.
//! Conversion from SafeNum to all integers is only exposed via `try_from`
//! in order to force the user to handle potential resultant errors.
//!
//! E.g.
//! ```
//! fn call_func(_: u32, _: u32) {
//! }
//!
//! fn do_a_thing(a: SafeNum) -> Result<(), safemath::Error> {
//!     call_func(16, a.try_into()?);
//!     Ok(())
//! }
//! ```
//!
//! ### Comparison
//! SafeNums can be checked for equality against each other.
//! Valid numbers are equal to other numbers of the same magnitude.
//! Errored SafeNums are only equal to themselves.
//! Note that because errors propagate from their first introduction in an
//! arithmetic chain this can lead to surprising results.
//!
//! E.g.
//! ```
//! let overflow = SafeNum::MAX + 1;
//! let otherflow = SafeNum::MAX + 1;
//!
//! assert_ne!(overflow, otherflow);
//! assert_eq!(overflow + otherflow, overflow);
//! assert_eq!(otherflow + overflow, otherflow);
//! ```
//!
//! Inequality comparison operators are deliberately not provided.
//! By necessity they would have similar caveats to floating point comparisons,
//! which are easy to use incorrectly and unintuitive to use correctly.
//!
//! The required alternative is to convert to a real integer type before comparing,
//! forcing any errors upwards.
//!
//! E.g.
//! ```
//! impl From<safemath::Error> for &'static str {
//!     fn from(_: safemath::Error) -> Self {
//!         "checked arithmetic error"
//!     }
//! }
//!
//! fn my_op(a: SafeNum, b: SafeNum, c: SafeNum, d: SafeNum) -> Result<bool, &'static str> {
//!     Ok(safemath::Primitive::try_from(a)? < b.try_into()?
//!         && safemath::Primitive::try_from(c)? >= d.try_into()?)
//! }
//! ```
//!
//! ### Miscellaneous
//! SafeNums also provide helper methods to round up or down
//! to the nearest multiple of another number
//! and helper predicate methods that indicate whether the SafeNum
//! is valid or is tracking an error.
//!
//! Also provided are constants `SafeNum::MAX`, `SafeNum::MIN`, and `SafeNum::ZERO`.
//!
//! Warning: SafeNums can help prevent, isolate, and detect arithmetic overflow
//! but they are not a panacea. In particular, chains of different operations
//! are not guaranteed to be associative or commutative.
//!
//! E.g.
//! ```
//! let a = SafeNum::MAX - 1 + 1;
//! let b = SafeNum::MAX + 1 - 1;
//! assert_ne!(a, b);
//! assert!(a.is_valid());
//! assert!(b.has_error());
//!
//! let c = (SafeNum::MAX + 31) / 31;
//! let d = SafeNum::MAX / 31 + 31 / 31;
//! assert_ne!(c, d);
//! assert!(c.has_error());
//! assert!(d.is_valid());
//! ```
//!
//! Note: SafeNum arithmetic is much slower than arithmetic on integer primitives.
//! If you are concerned about performance, be sure to run benchmarks.
139 140 #![cfg_attr(not(test), no_std)] 141 142 use core::convert::TryFrom; 143 use core::fmt; 144 use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign}; 145 use core::panic::Location; 146 147 /// The underlying primitive type used for [SafeNum] operations. 148 pub type Primitive = u64; 149 /// Safe math error type, which points to the location of the original failed operation. 150 #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] 151 pub struct Error(&'static Location<'static>); 152 153 impl From<&'static Location<'static>> for Error { from(loc: &'static Location<'static>) -> Self154 fn from(loc: &'static Location<'static>) -> Self { 155 Self(loc) 156 } 157 } 158 159 impl From<Error> for &'static Location<'static> { from(err: Error) -> Self160 fn from(err: Error) -> Self { 161 err.0 162 } 163 } 164 165 impl From<core::num::TryFromIntError> for Error { 166 #[track_caller] from(_err: core::num::TryFromIntError) -> Self167 fn from(_err: core::num::TryFromIntError) -> Self { 168 Self(Location::caller()) 169 } 170 } 171 172 /// Wraps a raw [Primitive] type for safe-by-default math. See module docs for info and usage. 173 #[derive(Copy, Clone, PartialEq, Eq)] 174 pub struct SafeNum(Result<Primitive, Error>); 175 176 impl fmt::Debug for SafeNum { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 178 match self.0 { 179 Ok(val) => write!(f, "{}", val), 180 Err(location) => write!(f, "error at {}", location), 181 } 182 } 183 } 184 185 impl fmt::Display for Error { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 187 self.0.fmt(f) 188 } 189 } 190 191 impl SafeNum { 192 /// The maximum [SafeNum]. 193 pub const MAX: SafeNum = SafeNum(Ok(u64::MAX)); 194 /// The minimum [SafeNum]. 195 pub const MIN: SafeNum = SafeNum(Ok(u64::MIN)); 196 /// Zero as a [SafeNum]. 
197 pub const ZERO: SafeNum = SafeNum(Ok(0)); 198 199 /// Round `self` down to the nearest multiple of `rhs`. 200 #[track_caller] round_down<T>(self, rhs: T) -> Self where Self: Rem<T, Output = Self>,201 pub fn round_down<T>(self, rhs: T) -> Self 202 where 203 Self: Rem<T, Output = Self>, 204 { 205 self - (self % rhs) 206 } 207 208 /// Round `self` up to the nearest multiple of `rhs`. 209 #[track_caller] round_up<T>(self, rhs: T) -> Self where Self: Add<T, Output = Self>, T: Copy + Into<Self>,210 pub fn round_up<T>(self, rhs: T) -> Self 211 where 212 Self: Add<T, Output = Self>, 213 T: Copy + Into<Self>, 214 { 215 ((self + rhs) - 1).round_down(rhs) 216 } 217 218 /// Returns whether self is the result of an operation that has errored. has_error(&self) -> bool219 pub const fn has_error(&self) -> bool { 220 self.0.is_err() 221 } 222 223 /// Returns whether self represents a valid, non-overflowed integer. is_valid(&self) -> bool224 pub const fn is_valid(&self) -> bool { 225 self.0.is_ok() 226 } 227 } 228 229 macro_rules! try_conversion_func { 230 ($other_type:tt) => { 231 impl TryFrom<SafeNum> for $other_type { 232 type Error = Error; 233 234 #[track_caller] 235 fn try_from(val: SafeNum) -> Result<Self, Self::Error> { 236 Self::try_from(val.0?).map_err(|_| Location::caller().into()) 237 } 238 } 239 }; 240 } 241 242 macro_rules! conversion_func { 243 ($from_type:tt) => { 244 impl From<$from_type> for SafeNum { 245 fn from(val: $from_type) -> SafeNum { 246 Self(Ok(val.into())) 247 } 248 } 249 250 try_conversion_func!($from_type); 251 }; 252 } 253 254 macro_rules! conversion_func_maybe_error { 255 ($from_type:tt) => { 256 impl From<$from_type> for SafeNum { 257 #[track_caller] 258 fn from(val: $from_type) -> Self { 259 Self(Primitive::try_from(val).map_err(|_| Location::caller().into())) 260 } 261 } 262 263 try_conversion_func!($from_type); 264 }; 265 } 266 267 macro_rules! 
arithmetic_impl { 268 ($trait_name:ident, $op:ident, $assign_trait_name:ident, $assign_op:ident, $func:ident) => { 269 impl<T: Into<SafeNum>> $trait_name<T> for SafeNum { 270 type Output = Self; 271 #[track_caller] 272 fn $op(self, rhs: T) -> Self { 273 let rhs: Self = rhs.into(); 274 275 match (self.0, rhs.0) { 276 (Err(_), _) => self, 277 (_, Err(_)) => rhs, 278 (Ok(lhs), Ok(rhs)) => { 279 Self(lhs.$func(rhs).ok_or_else(|| Location::caller().into())) 280 } 281 } 282 } 283 } 284 285 impl<T> $assign_trait_name<T> for SafeNum 286 where 287 Self: $trait_name<T, Output = Self>, 288 { 289 #[track_caller] 290 fn $assign_op(&mut self, rhs: T) { 291 *self = self.$op(rhs) 292 } 293 } 294 }; 295 } 296 297 conversion_func!(u8); 298 conversion_func!(u16); 299 conversion_func!(u32); 300 conversion_func!(u64); 301 conversion_func_maybe_error!(usize); 302 conversion_func_maybe_error!(u128); 303 conversion_func_maybe_error!(i8); 304 conversion_func_maybe_error!(i16); 305 conversion_func_maybe_error!(i32); 306 conversion_func_maybe_error!(i64); 307 conversion_func_maybe_error!(i128); 308 conversion_func_maybe_error!(isize); 309 arithmetic_impl!(Add, add, AddAssign, add_assign, checked_add); 310 arithmetic_impl!(Sub, sub, SubAssign, sub_assign, checked_sub); 311 arithmetic_impl!(Mul, mul, MulAssign, mul_assign, checked_mul); 312 arithmetic_impl!(Div, div, DivAssign, div_assign, checked_div); 313 arithmetic_impl!(Rem, rem, RemAssign, rem_assign, checked_rem); 314 315 #[cfg(test)] 316 mod test { 317 use super::*; 318 319 #[test] test_addition()320 fn test_addition() { 321 let a: SafeNum = 2100.into(); 322 let b: SafeNum = 12.into(); 323 assert_eq!(a + b, 2112.into()); 324 } 325 326 #[test] test_subtraction()327 fn test_subtraction() { 328 let a: SafeNum = 667.into(); 329 let b: SafeNum = 1.into(); 330 assert_eq!(a - b, 666.into()); 331 } 332 333 #[test] test_multiplication()334 fn test_multiplication() { 335 let a: SafeNum = 17.into(); 336 let b: SafeNum = 3.into(); 337 assert_eq!(a 
* b, 51.into()); 338 } 339 340 #[test] test_division()341 fn test_division() { 342 let a: SafeNum = 1066.into(); 343 let b: SafeNum = 41.into(); 344 assert_eq!(a / b, 26.into()); 345 } 346 347 #[test] test_remainder()348 fn test_remainder() { 349 let a: SafeNum = 613.into(); 350 let b: SafeNum = 10.into(); 351 assert_eq!(a % b, 3.into()); 352 } 353 354 #[test] test_addition_poison()355 fn test_addition_poison() { 356 let base: SafeNum = 2.into(); 357 let poison = base + SafeNum::MAX; 358 assert!(u64::try_from(poison).is_err()); 359 360 let a = poison - 1; 361 let b = poison - 2; 362 363 assert_eq!(a, poison); 364 assert_eq!(b, poison); 365 } 366 367 #[test] test_subtraction_poison()368 fn test_subtraction_poison() { 369 let base: SafeNum = 2.into(); 370 let poison = base - SafeNum::MAX; 371 assert!(u64::try_from(poison).is_err()); 372 373 let a = poison + 1; 374 let b = poison + 2; 375 376 assert_eq!(a, poison); 377 assert_eq!(b, poison); 378 } 379 380 #[test] test_multiplication_poison()381 fn test_multiplication_poison() { 382 let base: SafeNum = 2.into(); 383 let poison = base * SafeNum::MAX; 384 assert!(u64::try_from(poison).is_err()); 385 386 let a = poison / 2; 387 let b = poison / 4; 388 389 assert_eq!(a, poison); 390 assert_eq!(b, poison); 391 } 392 393 #[test] test_division_poison()394 fn test_division_poison() { 395 let base: SafeNum = 2.into(); 396 let poison = base / 0; 397 assert!(u64::try_from(poison).is_err()); 398 399 let a = poison * 2; 400 let b = poison * 4; 401 402 assert_eq!(a, poison); 403 assert_eq!(b, poison); 404 } 405 406 #[test] test_remainder_poison()407 fn test_remainder_poison() { 408 let base: SafeNum = 2.into(); 409 let poison = base % 0; 410 assert!(u64::try_from(poison).is_err()); 411 412 let a = poison * 2; 413 let b = poison * 4; 414 415 assert_eq!(a, poison); 416 assert_eq!(b, poison); 417 } 418 419 macro_rules! 
conversion_test { 420 ($name:ident) => { 421 mod $name { 422 use super::*; 423 use core::convert::TryInto; 424 425 #[test] 426 fn test_between_safenum() { 427 let var: $name = 16; 428 let sn: SafeNum = var.into(); 429 let res: $name = sn.try_into().unwrap(); 430 assert_eq!(var, res); 431 } 432 433 #[test] 434 fn test_arithmetic_safenum() { 435 let primitive: $name = ((((0 + 11) * 11) / 3) % 32) - 3; 436 let safe = ((((SafeNum::ZERO + $name::try_from(11u8).unwrap()) 437 * $name::try_from(11u8).unwrap()) 438 / $name::try_from(3u8).unwrap()) 439 % $name::try_from(32u8).unwrap()) 440 - $name::try_from(3u8).unwrap(); 441 assert_eq!($name::try_from(safe).unwrap(), primitive); 442 } 443 } 444 }; 445 } 446 447 conversion_test!(u8); 448 conversion_test!(u16); 449 conversion_test!(u32); 450 conversion_test!(u64); 451 conversion_test!(u128); 452 conversion_test!(usize); 453 conversion_test!(i8); 454 conversion_test!(i16); 455 conversion_test!(i32); 456 conversion_test!(i64); 457 conversion_test!(i128); 458 conversion_test!(isize); 459 460 macro_rules! 
correctness_tests { 461 ($name:ident, $operation:ident, $assign_operation:ident) => { 462 mod $operation { 463 use super::*; 464 use core::ops::$name; 465 466 #[test] 467 fn test_correctness() { 468 let normal = 300u64; 469 let safe: SafeNum = normal.into(); 470 let rhs = 7u64; 471 assert_eq!( 472 u64::try_from(safe.$operation(rhs)).unwrap(), 473 normal.$operation(rhs) 474 ); 475 } 476 477 #[test] 478 fn test_assign() { 479 let mut var: SafeNum = 2112.into(); 480 let rhs = 666u64; 481 let expect = var.$operation(rhs); 482 var.$assign_operation(rhs); 483 assert_eq!(var, expect); 484 } 485 486 #[test] 487 fn test_assign_poison() { 488 let mut var = SafeNum::MIN - 1; 489 let expected = var - 1; 490 var.$assign_operation(2); 491 // Poison saturates and doesn't perform additional changes 492 assert_eq!(var, expected); 493 } 494 } 495 }; 496 } 497 498 correctness_tests!(Add, add, add_assign); 499 correctness_tests!(Sub, sub, sub_assign); 500 correctness_tests!(Mul, mul, mul_assign); 501 correctness_tests!(Div, div, div_assign); 502 correctness_tests!(Rem, rem, rem_assign); 503 504 #[test] test_round_down()505 fn test_round_down() { 506 let x: SafeNum = 255.into(); 507 assert_eq!(x.round_down(32), 224.into()); 508 assert_eq!((x + 1).round_down(64), 256.into()); 509 assert_eq!(x.round_down(256), SafeNum::ZERO); 510 assert!(x.round_down(SafeNum::MIN).has_error()); 511 } 512 513 #[test] test_round_up()514 fn test_round_up() { 515 let x: SafeNum = 255.into(); 516 assert_eq!(x.round_up(32), 256.into()); 517 assert_eq!(x.round_up(51), x); 518 assert_eq!(SafeNum::ZERO.round_up(x), SafeNum::ZERO); 519 assert!(SafeNum::MAX.round_up(32).has_error()); 520 } 521 } 522