1 use rayon::prelude::*;
2 
3 use std::panic;
4 use std::sync::atomic::AtomicUsize;
5 use std::sync::atomic::Ordering;
6 use std::sync::Mutex;
7 
/// Checks that `collect_into_vec` drops exactly the elements it produced. It
/// runs the pipeline once with a panic injected part-way through and once
/// without.
///
/// Every value produced by the `map` closure is logged in `inserts`. Every
/// `Recorddrop` logs its value in `drops` when it is dropped. After each run,
/// the two logs must contain the same multiset of values.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore)]
fn collect_drop_on_unwind() {
    /// Pushes its value into the shared drop log when dropped.
    struct Recorddrop<'a>(i64, &'a Mutex<Vec<i64>>);

    impl<'a> Drop for Recorddrop<'a> {
        fn drop(&mut self) {
            self.1.lock().unwrap().push(self.0);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        let test_vec_len = 1024;
        // When `will_panic` is set, any item whose value exceeds this panics.
        // This interrupts the collection after only part of the output has
        // been written.
        let panic_point = 740;

        let mut inserts = Mutex::new(Vec::new());
        let mut drops = Mutex::new(Vec::new());

        let mut a = (0..test_vec_len).collect::<Vec<_>>();
        let b = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter_mut()
                .zip(&b)
                .map(|(&mut a, &b)| {
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    let elt = a + b;
                    inserts.lock().unwrap().push(elt);
                    Recorddrop(elt, &drops)
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // `catch_unwind` would otherwise swallow a failing assertion inside
        // the closure. It would also let a run that never panicked pass
        // vacuously. So the outcome must match what this run expects.
        assert_eq!(
            outcome.is_err(),
            will_panic,
            "catch_unwind outcome did not match will_panic"
        );

        let inserts = inserts.get_mut().unwrap();
        let drops = drops.get_mut().unwrap();
        println!("{:?}", inserts);
        println!("{:?}", drops);

        assert_eq!(inserts.len(), drops.len(), "Incorrect number of drops");
        // A run that completed must have produced, and later dropped, every element.
        assert!(will_panic || drops.len() == a.len());
        // sort to normalize order
        inserts.sort();
        drops.sort();
        assert_eq!(inserts, drops, "Incorrect elements were dropped");
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
63 
/// Zero-sized-type variant of the drop-on-unwind check for `collect_into_vec`.
///
/// The `map` closure bumps `INSERTS` once per element it produces, and each
/// `RecorddropZst` bumps `DROPS` when dropped. The two counters must agree
/// after both the panicking run and the completing run. In the completing
/// run, both must equal the input length.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore)]
fn collect_drop_on_unwind_zst() {
    static INSERTS: AtomicUsize = AtomicUsize::new(0);
    static DROPS: AtomicUsize = AtomicUsize::new(0);

    /// Increments `DROPS` when dropped.
    struct RecorddropZst;

    impl Drop for RecorddropZst {
        fn drop(&mut self) {
            DROPS.fetch_add(1, Ordering::SeqCst);
        }
    }

    let test_collect_panic = |will_panic: bool| {
        // Both runs below share the same statics, so reset them first.
        INSERTS.store(0, Ordering::SeqCst);
        DROPS.store(0, Ordering::SeqCst);

        let test_vec_len = 1024;
        // When `will_panic` is set, any item whose value exceeds this panics.
        let panic_point = 740;

        let a = (0..test_vec_len).collect::<Vec<_>>();

        let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            let mut result = Vec::new();
            a.par_iter()
                .map(|&a| {
                    if a > panic_point && will_panic {
                        panic!("unwinding for test");
                    }
                    INSERTS.fetch_add(1, Ordering::SeqCst);
                    RecorddropZst
                })
                .collect_into_vec(&mut result);

            // If we reach this point, this must pass
            assert_eq!(a.len(), result.len());
        }));

        // `catch_unwind` would otherwise swallow a failing assertion inside
        // the closure. It would also let a run that never panicked pass
        // vacuously. So the outcome must match what this run expects.
        assert_eq!(
            outcome.is_err(),
            will_panic,
            "catch_unwind outcome did not match will_panic"
        );

        let inserts = INSERTS.load(Ordering::SeqCst);
        let drops = DROPS.load(Ordering::SeqCst);

        assert_eq!(inserts, drops, "Incorrect number of drops");
        assert!(will_panic || drops == test_vec_len);
    };

    for &should_panic in &[true, false] {
        test_collect_panic(should_panic);
    }
}
114