xref: /aosp_15_r20/external/cronet/third_party/rust/chromium_crates_io/vendor/fend-core-1.4.6/src/num/dist.rs (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 use crate::error::{FendError, Interrupt};
2 use crate::interrupt::{test_int, Never};
3 use crate::num::bigrat::BigRat;
4 use crate::num::complex::{self, Complex};
5 use crate::result::FResult;
6 use crate::serialize::{Deserialize, Serialize};
7 use std::cmp::Ordering;
8 use std::fmt::Write;
9 use std::ops::Neg;
10 use std::{fmt, io};
11 
12 use super::real::Real;
13 use super::{Base, Exact, FormattingStyle};
14 
15 #[derive(Clone)]
16 pub(crate) struct Dist {
17 	// invariant: probabilities must sum to 1
18 	parts: Vec<(Complex, BigRat)>,
19 }
20 
21 impl Dist {
serialize(&self, write: &mut impl io::Write) -> FResult<()>22 	pub(crate) fn serialize(&self, write: &mut impl io::Write) -> FResult<()> {
23 		self.parts.len().serialize(write)?;
24 		for (a, b) in &self.parts {
25 			a.serialize(write)?;
26 			b.serialize(write)?;
27 		}
28 		Ok(())
29 	}
30 
deserialize(read: &mut impl io::Read) -> FResult<Self>31 	pub(crate) fn deserialize(read: &mut impl io::Read) -> FResult<Self> {
32 		let len = usize::deserialize(read)?;
33 		let mut parts = Vec::with_capacity(len);
34 		for _ in 0..len {
35 			let k = Complex::deserialize(read)?;
36 			let v = BigRat::deserialize(read)?;
37 			parts.push((k, v));
38 		}
39 		Ok(Self { parts })
40 	}
41 
one_point(self) -> FResult<Complex>42 	pub(crate) fn one_point(self) -> FResult<Complex> {
43 		if self.parts.len() == 1 {
44 			Ok(self.parts.into_iter().next().unwrap().0)
45 		} else {
46 			Err(FendError::ProbabilityDistributionsNotAllowed)
47 		}
48 	}
49 
one_point_ref(&self) -> FResult<&Complex>50 	pub(crate) fn one_point_ref(&self) -> FResult<&Complex> {
51 		if self.parts.len() == 1 {
52 			Ok(&self.parts[0].0)
53 		} else {
54 			Err(FendError::ProbabilityDistributionsNotAllowed)
55 		}
56 	}
57 
new_die<I: Interrupt>(count: u32, faces: u32, int: &I) -> FResult<Self>58 	pub(crate) fn new_die<I: Interrupt>(count: u32, faces: u32, int: &I) -> FResult<Self> {
59 		assert!(count != 0);
60 		assert!(faces != 0);
61 		if count > 1 {
62 			let mut result = Self::new_die(1, faces, int)?;
63 			for _ in 1..count {
64 				test_int(int)?;
65 				result = Exact::new(result, true)
66 					.add(&Exact::new(Self::new_die(1, faces, int)?, true), int)?
67 					.value;
68 			}
69 			return Ok(result);
70 		}
71 		let mut parts = Vec::new();
72 		let probability = BigRat::from(1).div(&BigRat::from(u64::from(faces)), int)?;
73 		for face in 1..=faces {
74 			test_int(int)?;
75 			parts.push((Complex::from(u64::from(face)), probability.clone()));
76 		}
77 		Ok(Self { parts })
78 	}
79 
equals_int<I: Interrupt>(&self, val: u64, int: &I) -> FResult<bool>80 	pub(crate) fn equals_int<I: Interrupt>(&self, val: u64, int: &I) -> FResult<bool> {
81 		Ok(self.parts.len() == 1
82 			&& self.parts[0].0.compare(&val.into(), int)? == Some(Ordering::Equal))
83 	}
84 
85 	#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
sample<I: Interrupt>(self, ctx: &crate::Context, int: &I) -> FResult<Self>86 	pub(crate) fn sample<I: Interrupt>(self, ctx: &crate::Context, int: &I) -> FResult<Self> {
87 		if self.parts.len() == 1 {
88 			return Ok(self);
89 		}
90 		let mut random = ctx.random_u32.ok_or(FendError::RandomNumbersNotAvailable)?();
91 		let mut res = None;
92 		for (k, v) in self.parts {
93 			random = random.saturating_sub((v.into_f64(int)? * f64::from(u32::MAX)) as u32);
94 			if random == 0 {
95 				return Ok(Self::from(k));
96 			}
97 			res = Some(Self::from(k));
98 		}
99 		res.ok_or(FendError::EmptyDistribution)
100 	}
101 
mean<I: Interrupt>(self, int: &I) -> FResult<Self>102 	pub(crate) fn mean<I: Interrupt>(self, int: &I) -> FResult<Self> {
103 		if self.parts.is_empty() {
104 			return Err(FendError::EmptyDistribution);
105 		} else if self.parts.len() == 1 {
106 			return Ok(self);
107 		}
108 
109 		let mut result = Exact::new(Complex::from(0), true);
110 		for (k, v) in self.parts {
111 			result = Exact::new(k, true)
112 				.mul(&Exact::new(Complex::from(Real::from(v)), true), int)?
113 				.add(result, int)?;
114 		}
115 
116 		Ok(Self::from(result.value))
117 	}
118 
119 	#[allow(
120 		clippy::cast_possible_truncation,
121 		clippy::cast_sign_loss,
122 		clippy::too_many_arguments
123 	)]
format<I: Interrupt>( &self, exact: bool, style: FormattingStyle, base: Base, use_parentheses: complex::UseParentheses, out: &mut String, ctx: &crate::Context, int: &I, ) -> FResult<Exact<()>>124 	pub(crate) fn format<I: Interrupt>(
125 		&self,
126 		exact: bool,
127 		style: FormattingStyle,
128 		base: Base,
129 		use_parentheses: complex::UseParentheses,
130 		out: &mut String,
131 		ctx: &crate::Context,
132 		int: &I,
133 	) -> FResult<Exact<()>> {
134 		if self.parts.len() == 1 {
135 			let res = self.parts[0]
136 				.0
137 				.format(exact, style, base, use_parentheses, int)?;
138 			write!(out, "{}", res.value)?;
139 			Ok(Exact::new((), res.exact))
140 		} else {
141 			let mut ordered_kvs = vec![];
142 			let mut max_prob = 0.0;
143 			for (n, prob) in &self.parts {
144 				let prob_f64 = prob.clone().into_f64(int)?;
145 				if prob_f64 > max_prob {
146 					max_prob = prob_f64;
147 				}
148 				ordered_kvs.push((n, prob, prob_f64));
149 			}
150 			ordered_kvs.sort_unstable_by(|(a, _, _), (b, _, _)| {
151 				a.compare(b, &Never).unwrap().unwrap_or(Ordering::Equal)
152 			});
153 			if ctx.output_mode == crate::OutputMode::SimpleText {
154 				write!(out, "{{ ")?;
155 			}
156 			let mut first = true;
157 			for (num, _prob, prob_f64) in ordered_kvs {
158 				let num = num
159 					.format(exact, style, base, use_parentheses, int)?
160 					.value
161 					.to_string();
162 				let prob_percentage = prob_f64 * 100.0;
163 				if ctx.output_mode == crate::OutputMode::TerminalFixedWidth {
164 					if !first {
165 						writeln!(out)?;
166 					}
167 					let mut bar = String::new();
168 					for _ in 0..(prob_f64 / max_prob * 30.0).min(30.0) as u32 {
169 						bar.push('#');
170 					}
171 					write!(out, "{num:>3}: {prob_percentage:>5.2}%  {bar}")?;
172 				} else {
173 					if !first {
174 						write!(out, ", ")?;
175 					}
176 					write!(out, "{num}: {prob_percentage:.2}%")?;
177 				}
178 				if first {
179 					first = false;
180 				}
181 			}
182 			if ctx.output_mode == crate::OutputMode::SimpleText {
183 				write!(out, " }}")?;
184 			}
185 			// TODO check exactness
186 			Ok(Exact::new((), true))
187 		}
188 	}
189 
bop<I: Interrupt>( self, rhs: &Self, mut f: impl FnMut(&Complex, &Complex, &I) -> FResult<Complex>, int: &I, ) -> FResult<Self>190 	fn bop<I: Interrupt>(
191 		self,
192 		rhs: &Self,
193 		mut f: impl FnMut(&Complex, &Complex, &I) -> FResult<Complex>,
194 		int: &I,
195 	) -> FResult<Self> {
196 		let mut parts = Vec::<(Complex, BigRat)>::new();
197 		for (n1, p1) in &self.parts {
198 			for (n2, p2) in &rhs.parts {
199 				let n = f(n1, n2, int)?;
200 				let p = p1.clone().mul(p2, int)?;
201 				let mut found = false;
202 				for (k, prob) in &mut parts {
203 					if k.compare(&n, int)? == Some(Ordering::Equal) {
204 						*prob = prob.clone().add(p.clone(), int)?;
205 						found = true;
206 						break;
207 					}
208 				}
209 				if !found {
210 					parts.push((n, p));
211 				}
212 			}
213 		}
214 		Ok(Self { parts })
215 	}
216 }
217 
218 impl Exact<Dist> {
add<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self>219 	pub(crate) fn add<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> {
220 		let self_exact = self.exact;
221 		let rhs_exact = rhs.exact;
222 		let mut exact = true;
223 		Ok(Self::new(
224 			self.value.bop(
225 				&rhs.value,
226 				|a, b, int| {
227 					let sum = Exact::new(a.clone(), self_exact)
228 						.add(Exact::new(b.clone(), rhs_exact), int)?;
229 					exact &= sum.exact;
230 					Ok(sum.value)
231 				},
232 				int,
233 			)?,
234 			exact,
235 		))
236 	}
237 
mul<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self>238 	pub(crate) fn mul<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> {
239 		let self_exact = self.exact;
240 		let rhs_exact = rhs.exact;
241 		let mut exact = true;
242 		Ok(Self::new(
243 			self.value.bop(
244 				&rhs.value,
245 				|a, b, int| {
246 					let sum = Exact::new(a.clone(), self_exact)
247 						.mul(&Exact::new(b.clone(), rhs_exact), int)?;
248 					exact &= sum.exact;
249 					Ok(sum.value)
250 				},
251 				int,
252 			)?,
253 			exact,
254 		))
255 	}
256 
div<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self>257 	pub(crate) fn div<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> {
258 		let self_exact = self.exact;
259 		let rhs_exact = rhs.exact;
260 		let mut exact = true;
261 		Ok(Self::new(
262 			self.value.bop(
263 				&rhs.value,
264 				|a, b, int| {
265 					let sum = Exact::new(a.clone(), self_exact)
266 						.div(Exact::new(b.clone(), rhs_exact), int)?;
267 					exact &= sum.exact;
268 					Ok(sum.value)
269 				},
270 				int,
271 			)?,
272 			exact,
273 		))
274 	}
275 }
276 
277 impl From<Complex> for Dist {
from(v: Complex) -> Self278 	fn from(v: Complex) -> Self {
279 		Self {
280 			parts: vec![(v, 1.into())],
281 		}
282 	}
283 }
284 
285 impl From<Real> for Dist {
from(v: Real) -> Self286 	fn from(v: Real) -> Self {
287 		Self::from(Complex::from(v))
288 	}
289 }
290 
291 impl From<u64> for Dist {
from(i: u64) -> Self292 	fn from(i: u64) -> Self {
293 		Self::from(Complex::from(i))
294 	}
295 }
296 
297 impl fmt::Debug for Dist {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result298 	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
299 		match self.one_point_ref() {
300 			Ok(complex) => write!(f, "{complex:?}"),
301 			Err(_) => write!(f, "dist {:?}", self.parts),
302 		}
303 	}
304 }
305 
306 impl Neg for Dist {
307 	type Output = Self;
neg(mut self) -> Self308 	fn neg(mut self) -> Self {
309 		for (k, _) in &mut self.parts {
310 			*k = -k.clone();
311 		}
312 		self
313 	}
314 }
315