1 use crate::error::{FendError, Interrupt}; 2 use crate::interrupt::{test_int, Never}; 3 use crate::num::bigrat::BigRat; 4 use crate::num::complex::{self, Complex}; 5 use crate::result::FResult; 6 use crate::serialize::{Deserialize, Serialize}; 7 use std::cmp::Ordering; 8 use std::fmt::Write; 9 use std::ops::Neg; 10 use std::{fmt, io}; 11 12 use super::real::Real; 13 use super::{Base, Exact, FormattingStyle}; 14 15 #[derive(Clone)] 16 pub(crate) struct Dist { 17 // invariant: probabilities must sum to 1 18 parts: Vec<(Complex, BigRat)>, 19 } 20 21 impl Dist { serialize(&self, write: &mut impl io::Write) -> FResult<()>22 pub(crate) fn serialize(&self, write: &mut impl io::Write) -> FResult<()> { 23 self.parts.len().serialize(write)?; 24 for (a, b) in &self.parts { 25 a.serialize(write)?; 26 b.serialize(write)?; 27 } 28 Ok(()) 29 } 30 deserialize(read: &mut impl io::Read) -> FResult<Self>31 pub(crate) fn deserialize(read: &mut impl io::Read) -> FResult<Self> { 32 let len = usize::deserialize(read)?; 33 let mut parts = Vec::with_capacity(len); 34 for _ in 0..len { 35 let k = Complex::deserialize(read)?; 36 let v = BigRat::deserialize(read)?; 37 parts.push((k, v)); 38 } 39 Ok(Self { parts }) 40 } 41 one_point(self) -> FResult<Complex>42 pub(crate) fn one_point(self) -> FResult<Complex> { 43 if self.parts.len() == 1 { 44 Ok(self.parts.into_iter().next().unwrap().0) 45 } else { 46 Err(FendError::ProbabilityDistributionsNotAllowed) 47 } 48 } 49 one_point_ref(&self) -> FResult<&Complex>50 pub(crate) fn one_point_ref(&self) -> FResult<&Complex> { 51 if self.parts.len() == 1 { 52 Ok(&self.parts[0].0) 53 } else { 54 Err(FendError::ProbabilityDistributionsNotAllowed) 55 } 56 } 57 new_die<I: Interrupt>(count: u32, faces: u32, int: &I) -> FResult<Self>58 pub(crate) fn new_die<I: Interrupt>(count: u32, faces: u32, int: &I) -> FResult<Self> { 59 assert!(count != 0); 60 assert!(faces != 0); 61 if count > 1 { 62 let mut result = Self::new_die(1, faces, int)?; 63 for _ in 1..count { 64 test_int(int)?; 65 result = Exact::new(result, true) 66 .add(&Exact::new(Self::new_die(1, faces, int)?, true), int)? 67 .value; 68 } 69 return Ok(result); 70 } 71 let mut parts = Vec::new(); 72 let probability = BigRat::from(1).div(&BigRat::from(u64::from(faces)), int)?; 73 for face in 1..=faces { 74 test_int(int)?; 75 parts.push((Complex::from(u64::from(face)), probability.clone())); 76 } 77 Ok(Self { parts }) 78 } 79 equals_int<I: Interrupt>(&self, val: u64, int: &I) -> FResult<bool>80 pub(crate) fn equals_int<I: Interrupt>(&self, val: u64, int: &I) -> FResult<bool> { 81 Ok(self.parts.len() == 1 82 && self.parts[0].0.compare(&val.into(), int)? == Some(Ordering::Equal)) 83 } 84 85 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] sample<I: Interrupt>(self, ctx: &crate::Context, int: &I) -> FResult<Self>86 pub(crate) fn sample<I: Interrupt>(self, ctx: &crate::Context, int: &I) -> FResult<Self> { 87 if self.parts.len() == 1 { 88 return Ok(self); 89 } 90 let mut random = ctx.random_u32.ok_or(FendError::RandomNumbersNotAvailable)?(); 91 let mut res = None; 92 for (k, v) in self.parts { 93 random = random.saturating_sub((v.into_f64(int)? * f64::from(u32::MAX)) as u32); 94 if random == 0 { 95 return Ok(Self::from(k)); 96 } 97 res = Some(Self::from(k)); 98 } 99 res.ok_or(FendError::EmptyDistribution) 100 } 101 mean<I: Interrupt>(self, int: &I) -> FResult<Self>102 pub(crate) fn mean<I: Interrupt>(self, int: &I) -> FResult<Self> { 103 if self.parts.is_empty() { 104 return Err(FendError::EmptyDistribution); 105 } else if self.parts.len() == 1 { 106 return Ok(self); 107 } 108 109 let mut result = Exact::new(Complex::from(0), true); 110 for (k, v) in self.parts { 111 result = Exact::new(k, true) 112 .mul(&Exact::new(Complex::from(Real::from(v)), true), int)? 113 .add(result, int)?; 114 } 115 116 Ok(Self::from(result.value)) 117 } 118 119 #[allow( 120 clippy::cast_possible_truncation, 121 clippy::cast_sign_loss, 122 clippy::too_many_arguments 123 )] format<I: Interrupt>( &self, exact: bool, style: FormattingStyle, base: Base, use_parentheses: complex::UseParentheses, out: &mut String, ctx: &crate::Context, int: &I, ) -> FResult<Exact<()>>124 pub(crate) fn format<I: Interrupt>( 125 &self, 126 exact: bool, 127 style: FormattingStyle, 128 base: Base, 129 use_parentheses: complex::UseParentheses, 130 out: &mut String, 131 ctx: &crate::Context, 132 int: &I, 133 ) -> FResult<Exact<()>> { 134 if self.parts.len() == 1 { 135 let res = self.parts[0] 136 .0 137 .format(exact, style, base, use_parentheses, int)?; 138 write!(out, "{}", res.value)?; 139 Ok(Exact::new((), res.exact)) 140 } else { 141 let mut ordered_kvs = vec![]; 142 let mut max_prob = 0.0; 143 for (n, prob) in &self.parts { 144 let prob_f64 = prob.clone().into_f64(int)?; 145 if prob_f64 > max_prob { 146 max_prob = prob_f64; 147 } 148 ordered_kvs.push((n, prob, prob_f64)); 149 } 150 ordered_kvs.sort_unstable_by(|(a, _, _), (b, _, _)| { 151 a.compare(b, &Never).unwrap().unwrap_or(Ordering::Equal) 152 }); 153 if ctx.output_mode == crate::OutputMode::SimpleText { 154 write!(out, "{{ ")?; 155 } 156 let mut first = true; 157 for (num, _prob, prob_f64) in ordered_kvs { 158 let num = num 159 .format(exact, style, base, use_parentheses, int)? 160 .value 161 .to_string(); 162 let prob_percentage = prob_f64 * 100.0; 163 if ctx.output_mode == crate::OutputMode::TerminalFixedWidth { 164 if !first { 165 writeln!(out)?; 166 } 167 let mut bar = String::new(); 168 for _ in 0..(prob_f64 / max_prob * 30.0).min(30.0) as u32 { 169 bar.push('#'); 170 } 171 write!(out, "{num:>3}: {prob_percentage:>5.2}% {bar}")?; 172 } else { 173 if !first { 174 write!(out, ", ")?; 175 } 176 write!(out, "{num}: {prob_percentage:.2}%")?; 177 } 178 if first { 179 first = false; 180 } 181 } 182 if ctx.output_mode == crate::OutputMode::SimpleText { 183 write!(out, " }}")?; 184 } 185 // TODO check exactness 186 Ok(Exact::new((), true)) 187 } 188 } 189 bop<I: Interrupt>( self, rhs: &Self, mut f: impl FnMut(&Complex, &Complex, &I) -> FResult<Complex>, int: &I, ) -> FResult<Self>190 fn bop<I: Interrupt>( 191 self, 192 rhs: &Self, 193 mut f: impl FnMut(&Complex, &Complex, &I) -> FResult<Complex>, 194 int: &I, 195 ) -> FResult<Self> { 196 let mut parts = Vec::<(Complex, BigRat)>::new(); 197 for (n1, p1) in &self.parts { 198 for (n2, p2) in &rhs.parts { 199 let n = f(n1, n2, int)?; 200 let p = p1.clone().mul(p2, int)?; 201 let mut found = false; 202 for (k, prob) in &mut parts { 203 if k.compare(&n, int)? == Some(Ordering::Equal) { 204 *prob = prob.clone().add(p.clone(), int)?; 205 found = true; 206 break; 207 } 208 } 209 if !found { 210 parts.push((n, p)); 211 } 212 } 213 } 214 Ok(Self { parts }) 215 } 216 } 217 218 impl Exact<Dist> { add<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self>219 pub(crate) fn add<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> { 220 let self_exact = self.exact; 221 let rhs_exact = rhs.exact; 222 let mut exact = true; 223 Ok(Self::new( 224 self.value.bop( 225 &rhs.value, 226 |a, b, int| { 227 let sum = Exact::new(a.clone(), self_exact) 228 .add(Exact::new(b.clone(), rhs_exact), int)?; 229 exact &= sum.exact; 230 Ok(sum.value) 231 }, 232 int, 233 )?, 234 exact, 235 )) 236 } 237 mul<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self>238 pub(crate) fn mul<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> { 239 let self_exact = self.exact; 240 let rhs_exact = rhs.exact; 241 let mut exact = true; 242 Ok(Self::new( 243 self.value.bop( 244 &rhs.value, 245 |a, b, int| { 246 let sum = Exact::new(a.clone(), self_exact) 247 .mul(&Exact::new(b.clone(), rhs_exact), int)?; 248 exact &= sum.exact; 249 Ok(sum.value) 250 }, 251 int, 252 )?, 253 exact, 254 )) 255 } 256 div<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self>257 pub(crate) fn div<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> { 258 let self_exact = self.exact; 259 let rhs_exact = rhs.exact; 260 let mut exact = true; 261 Ok(Self::new( 262 self.value.bop( 263 &rhs.value, 264 |a, b, int| { 265 let sum = Exact::new(a.clone(), self_exact) 266 .div(Exact::new(b.clone(), rhs_exact), int)?; 267 exact &= sum.exact; 268 Ok(sum.value) 269 }, 270 int, 271 )?, 272 exact, 273 )) 274 } 275 } 276 277 impl From<Complex> for Dist { from(v: Complex) -> Self278 fn from(v: Complex) -> Self { 279 Self { 280 parts: vec![(v, 1.into())], 281 } 282 } 283 } 284 285 impl From<Real> for Dist { from(v: Real) -> Self286 fn from(v: Real) -> Self { 287 Self::from(Complex::from(v)) 288 } 289 } 290 291 impl From<u64> for Dist { from(i: u64) -> Self292 fn from(i: u64) -> Self { 293 Self::from(Complex::from(i)) 294 } 295 } 296 297 impl fmt::Debug for Dist { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result298 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 299 match self.one_point_ref() { 300 Ok(complex) => write!(f, "{complex:?}"), 301 Err(_) => write!(f, "dist {:?}", self.parts), 302 } 303 } 304 } 305 306 impl Neg for Dist { 307 type Output = Self; neg(mut self) -> Self308 fn neg(mut self) -> Self { 309 for (k, _) in &mut self.parts { 310 *k = -k.clone(); 311 } 312 self 313 } 314 } 315