1 use alloc::boxed::Box;
2 use alloc::vec::Vec;
3 use std::fmt;
4 use std::iter::FusedIterator;
5 
6 use super::lazy_buffer::LazyBuffer;
7 use crate::adaptors::checked_binomial;
8 
9 /// An iterator to iterate through all the `n`-length combinations in an iterator, with replacement.
10 ///
11 /// See [`.combinations_with_replacement()`](crate::Itertools::combinations_with_replacement)
12 /// for more information.
13 #[derive(Clone)]
14 #[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
15 pub struct CombinationsWithReplacement<I>
16 where
17     I: Iterator,
18     I::Item: Clone,
19 {
20     indices: Box<[usize]>,
21     pool: LazyBuffer<I>,
22     first: bool,
23 }
24 
25 impl<I> fmt::Debug for CombinationsWithReplacement<I>
26 where
27     I: Iterator + fmt::Debug,
28     I::Item: fmt::Debug + Clone,
29 {
30     debug_fmt_fields!(CombinationsWithReplacement, indices, pool, first);
31 }
32 
33 /// Create a new `CombinationsWithReplacement` from a clonable iterator.
combinations_with_replacement<I>(iter: I, k: usize) -> CombinationsWithReplacement<I> where I: Iterator, I::Item: Clone,34 pub fn combinations_with_replacement<I>(iter: I, k: usize) -> CombinationsWithReplacement<I>
35 where
36     I: Iterator,
37     I::Item: Clone,
38 {
39     let indices = alloc::vec![0; k].into_boxed_slice();
40     let pool: LazyBuffer<I> = LazyBuffer::new(iter);
41 
42     CombinationsWithReplacement {
43         indices,
44         pool,
45         first: true,
46     }
47 }
48 
49 impl<I> CombinationsWithReplacement<I>
50 where
51     I: Iterator,
52     I::Item: Clone,
53 {
54     /// Increments indices representing the combination to advance to the next
55     /// (in lexicographic order by increasing sequence) combination.
56     ///
57     /// Returns true if we've run out of combinations, false otherwise.
increment_indices(&mut self) -> bool58     fn increment_indices(&mut self) -> bool {
59         // Check if we need to consume more from the iterator
60         // This will run while we increment our first index digit
61         self.pool.get_next();
62 
63         // Work out where we need to update our indices
64         let mut increment = None;
65         for (i, indices_int) in self.indices.iter().enumerate().rev() {
66             if *indices_int < self.pool.len() - 1 {
67                 increment = Some((i, indices_int + 1));
68                 break;
69             }
70         }
71         match increment {
72             // If we can update the indices further
73             Some((increment_from, increment_value)) => {
74                 // We need to update the rightmost non-max value
75                 // and all those to the right
76                 for i in &mut self.indices[increment_from..] {
77                     *i = increment_value;
78                 }
79                 // TODO: once MSRV >= 1.50, use `fill` instead:
80                 // self.indices[increment_from..].fill(increment_value);
81                 false
82             }
83             // Otherwise, we're done
84             None => true,
85         }
86     }
87 }
88 
89 impl<I> Iterator for CombinationsWithReplacement<I>
90 where
91     I: Iterator,
92     I::Item: Clone,
93 {
94     type Item = Vec<I::Item>;
95 
next(&mut self) -> Option<Self::Item>96     fn next(&mut self) -> Option<Self::Item> {
97         if self.first {
98             // In empty edge cases, stop iterating immediately
99             if !(self.indices.is_empty() || self.pool.get_next()) {
100                 return None;
101             }
102             self.first = false;
103         } else if self.increment_indices() {
104             return None;
105         }
106         Some(self.pool.get_at(&self.indices))
107     }
108 
nth(&mut self, n: usize) -> Option<Self::Item>109     fn nth(&mut self, n: usize) -> Option<Self::Item> {
110         if self.first {
111             // In empty edge cases, stop iterating immediately
112             if !(self.indices.is_empty() || self.pool.get_next()) {
113                 return None;
114             }
115             self.first = false;
116         } else if self.increment_indices() {
117             return None;
118         }
119         for _ in 0..n {
120             if self.increment_indices() {
121                 return None;
122             }
123         }
124         Some(self.pool.get_at(&self.indices))
125     }
126 
size_hint(&self) -> (usize, Option<usize>)127     fn size_hint(&self) -> (usize, Option<usize>) {
128         let (mut low, mut upp) = self.pool.size_hint();
129         low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX);
130         upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices));
131         (low, upp)
132     }
133 
count(self) -> usize134     fn count(self) -> usize {
135         let Self {
136             indices,
137             pool,
138             first,
139         } = self;
140         let n = pool.count();
141         remaining_for(n, first, &indices).unwrap()
142     }
143 }
144 
145 impl<I> FusedIterator for CombinationsWithReplacement<I>
146 where
147     I: Iterator,
148     I::Item: Clone,
149 {
150 }
151 
152 /// For a given size `n`, return the count of remaining combinations with replacement or None if it would overflow.
remaining_for(n: usize, first: bool, indices: &[usize]) -> Option<usize>153 fn remaining_for(n: usize, first: bool, indices: &[usize]) -> Option<usize> {
154     // With a "stars and bars" representation, choose k values with replacement from n values is
155     // like choosing k out of k + n − 1 positions (hence binomial(k + n - 1, k) possibilities)
156     // to place k stars and therefore n - 1 bars.
157     // Example (n=4, k=6): ***|*||** represents [0,0,0,1,3,3].
158     let count = |n: usize, k: usize| {
159         let positions = if n == 0 {
160             k.saturating_sub(1)
161         } else {
162             (n - 1).checked_add(k)?
163         };
164         checked_binomial(positions, k)
165     };
166     let k = indices.len();
167     if first {
168         count(n, k)
169     } else {
170         // The algorithm is similar to the one for combinations *without replacement*,
171         // except we choose values *with replacement* and indices are *non-strictly* monotonically sorted.
172 
173         // The combinations generated after the current one can be counted by counting as follows:
174         // - The subsequent combinations that differ in indices[0]:
175         //   If subsequent combinations differ in indices[0], then their value for indices[0]
176         //   must be at least 1 greater than the current indices[0].
177         //   As indices is monotonically sorted, this means we can effectively choose k values with
178         //   replacement from (n - 1 - indices[0]), leading to count(n - 1 - indices[0], k) possibilities.
179         // - The subsequent combinations with same indices[0], but differing indices[1]:
180         //   Here we can choose k - 1 values with replacement from (n - 1 - indices[1]) values,
181         //   leading to count(n - 1 - indices[1], k - 1) possibilities.
182         // - (...)
183         // - The subsequent combinations with same indices[0..=i], but differing indices[i]:
184         //   Here we can choose k - i values with replacement from (n - 1 - indices[i]) values: count(n - 1 - indices[i], k - i).
185         //   Since subsequent combinations can in any index, we must sum up the aforementioned binomial coefficients.
186 
187         // Below, `n0` resembles indices[i].
188         indices.iter().enumerate().try_fold(0usize, |sum, (i, n0)| {
189             sum.checked_add(count(n - 1 - *n0, k - i)?)
190         })
191     }
192 }
193