1 use crate::util::AnyValueId;
2 use crate::util::FlatMap;
3 
/// Collection of boxed [`Extension`] values, keyed by an [`AnyValueId`].
///
/// The accessor methods derive the key from the requested type `T` via
/// `AnyValueId::of::<T>()`, so lookups are by type rather than by name.
#[derive(Default, Clone, Debug)]
pub(crate) struct Extensions {
    // Backing map from id to the type-erased, cloneable value.
    extensions: FlatMap<AnyValueId, BoxedExtension>,
}
8 
9 impl Extensions {
10     #[allow(dead_code)]
get<T: Extension>(&self) -> Option<&T>11     pub(crate) fn get<T: Extension>(&self) -> Option<&T> {
12         let id = AnyValueId::of::<T>();
13         self.extensions.get(&id).map(|e| e.as_ref::<T>())
14     }
15 
16     #[allow(dead_code)]
get_mut<T: Extension>(&mut self) -> Option<&mut T>17     pub(crate) fn get_mut<T: Extension>(&mut self) -> Option<&mut T> {
18         let id = AnyValueId::of::<T>();
19         self.extensions.get_mut(&id).map(|e| e.as_mut::<T>())
20     }
21 
22     #[allow(dead_code)]
get_or_insert_default<T: Extension + Default>(&mut self) -> &mut T23     pub(crate) fn get_or_insert_default<T: Extension + Default>(&mut self) -> &mut T {
24         let id = AnyValueId::of::<T>();
25         self.extensions
26             .entry(id)
27             .or_insert_with(|| BoxedExtension::new(T::default()))
28             .as_mut::<T>()
29     }
30 
31     #[allow(dead_code)]
set<T: Extension + Into<BoxedEntry>>(&mut self, tagged: T) -> bool32     pub(crate) fn set<T: Extension + Into<BoxedEntry>>(&mut self, tagged: T) -> bool {
33         let BoxedEntry { id, value } = tagged.into();
34         self.extensions.insert(id, value).is_some()
35     }
36 
37     #[allow(dead_code)]
remove<T: Extension>(&mut self) -> Option<Box<dyn Extension>>38     pub(crate) fn remove<T: Extension>(&mut self) -> Option<Box<dyn Extension>> {
39         let id = AnyValueId::of::<T>();
40         self.extensions.remove(&id).map(BoxedExtension::into_inner)
41     }
42 
update(&mut self, other: &Self)43     pub(crate) fn update(&mut self, other: &Self) {
44         for (key, value) in other.extensions.iter() {
45             self.extensions.insert(*key, value.clone());
46         }
47     }
48 }
49 
/// Trait for values that can be stored as `Box<dyn Extension>`.
///
/// Supports conversion to `Any` (so the concrete type can be recovered by
/// downcasting) and cloning through the trait object.
///
/// NOTE(review): the previous wording said traits "to be extended by
/// `impl_downcast!` must extend `Extension`"; no such macro is visible in this
/// file, so that reference may be stale — confirm.
pub(crate) trait Extension: std::fmt::Debug + Send + Sync + 'static {
    /// Convert `Box<dyn Trait>` (where `Trait: Extension`) to `Box<dyn Any>`.
    ///
    /// `Box<dyn Any>` can then be further `downcast` into
    /// `Box<ConcreteType>` where `ConcreteType` implements `Trait`.
    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any>;
    /// Clone the value behind `&Trait` (where `Trait: Extension`) into a new
    /// `Box<dyn Extension>`.
    ///
    /// The `Clone` impl for `Box<dyn Extension>` in this file is built on top
    /// of this method.
    fn clone_extension(&self) -> Box<dyn Extension>;
    /// Convert `&Trait` (where `Trait: Extension`) to `&Any`.
    ///
    /// This is needed since Rust cannot generate `&Any`'s vtable from
    /// `&Trait`'s.
    fn as_any(&self) -> &dyn std::any::Any;
    /// Convert `&mut Trait` (where `Trait: Extension`) to `&mut Any`.
    ///
    /// This is needed since Rust cannot generate `&mut Any`'s vtable from
    /// `&mut Trait`'s.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
73 
74 impl<T> Extension for T
75 where
76     T: Clone + std::fmt::Debug + Send + Sync + 'static,
77 {
into_any(self: Box<Self>) -> Box<dyn std::any::Any>78     fn into_any(self: Box<Self>) -> Box<dyn std::any::Any> {
79         self
80     }
clone_extension(&self) -> Box<dyn Extension>81     fn clone_extension(&self) -> Box<dyn Extension> {
82         Box::new(self.clone())
83     }
as_any(&self) -> &dyn std::any::Any84     fn as_any(&self) -> &dyn std::any::Any {
85         self
86     }
as_any_mut(&mut self) -> &mut dyn std::any::Any87     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
88         self
89     }
90 }
91 
92 impl Clone for Box<dyn Extension> {
clone(&self) -> Self93     fn clone(&self) -> Self {
94         self.as_ref().clone_extension()
95     }
96 }
97 
/// Newtype around a boxed [`Extension`] trait object.
///
/// Its inherent methods construct the box from a concrete value and downcast
/// it back to a concrete type. The derived `Clone` relies on the `Clone` impl
/// for `Box<dyn Extension>`.
#[derive(Clone)]
#[repr(transparent)]
struct BoxedExtension(Box<dyn Extension>);
101 
102 impl BoxedExtension {
new<T: Extension>(inner: T) -> Self103     fn new<T: Extension>(inner: T) -> Self {
104         Self(Box::new(inner))
105     }
106 
into_inner(self) -> Box<dyn Extension>107     fn into_inner(self) -> Box<dyn Extension> {
108         self.0
109     }
110 
as_ref<T: Extension>(&self) -> &T111     fn as_ref<T: Extension>(&self) -> &T {
112         self.0.as_ref().as_any().downcast_ref::<T>().unwrap()
113     }
114 
as_mut<T: Extension>(&mut self) -> &mut T115     fn as_mut<T: Extension>(&mut self) -> &mut T {
116         self.0.as_mut().as_any_mut().downcast_mut::<T>().unwrap()
117     }
118 }
119 
120 impl std::fmt::Debug for BoxedExtension {
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error>121     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
122         self.0.fmt(f)
123     }
124 }
125 
/// A type-erased extension value paired with the id it is stored under.
///
/// `Extensions::set` destructures this into the key and value it inserts.
#[derive(Clone)]
pub(crate) struct BoxedEntry {
    // Key for the map; derived from the value by `BoxedEntry::new`.
    id: AnyValueId,
    // The boxed, type-erased value itself.
    value: BoxedExtension,
}
131 
132 impl BoxedEntry {
new(r: impl Extension) -> Self133     pub(crate) fn new(r: impl Extension) -> Self {
134         let id = AnyValueId::from(&r);
135         let value = BoxedExtension::new(r);
136         BoxedEntry { id, value }
137     }
138 }
139 
140 impl<R: Extension> From<R> for BoxedEntry {
from(inner: R) -> Self141     fn from(inner: R) -> Self {
142         BoxedEntry::new(inner)
143     }
144 }
145 
// Unit tests for `Extensions`, using a small `Copy` newtype as the stored value.
#[cfg(test)]
mod test {
    use super::*;

    // Minimal extension type; `Default` yields `Number(0)`.
    #[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
    struct Number(usize);

    // A value that was `set` can be read back with `get`.
    #[test]
    fn get() {
        let mut ext = Extensions::default();
        ext.set(Number(10));
        assert_eq!(ext.get::<Number>(), Some(&Number(10)));
    }

    // Writing through `get_mut` is visible to a later `get`.
    #[test]
    fn get_mut() {
        let mut ext = Extensions::default();
        ext.set(Number(10));
        *ext.get_mut::<Number>().unwrap() = Number(20);
        assert_eq!(ext.get::<Number>(), Some(&Number(20)));
    }

    // On an empty map, `get_or_insert_default` yields the default value.
    #[test]
    fn get_or_insert_default_empty() {
        let mut ext = Extensions::default();
        assert_eq!(ext.get_or_insert_default::<Number>(), &Number(0));
    }

    // With a value already present, `get_or_insert_default` returns it unchanged.
    #[test]
    fn get_or_insert_default_full() {
        let mut ext = Extensions::default();
        ext.set(Number(10));
        assert_eq!(ext.get_or_insert_default::<Number>(), &Number(10));
    }

    // `set` returns `false` on first insert and `true` when replacing a value.
    #[test]
    fn set() {
        let mut ext = Extensions::default();
        assert!(!ext.set(Number(10)));
        assert_eq!(ext.get::<Number>(), Some(&Number(10)));
        assert!(ext.set(Number(20)));
        assert_eq!(ext.get::<Number>(), Some(&Number(20)));
    }

    // Exercises `remove`: it is `None` on a missing entry and clears a present one.
    #[test]
    fn reset() {
        let mut ext = Extensions::default();
        assert_eq!(ext.get::<Number>(), None);

        assert!(ext.remove::<Number>().is_none());
        assert_eq!(ext.get::<Number>(), None);

        assert!(!ext.set(Number(10)));
        assert_eq!(ext.get::<Number>(), Some(&Number(10)));

        assert!(ext.remove::<Number>().is_some());
        assert_eq!(ext.get::<Number>(), None);
    }

    // `update` copies entries from another `Extensions` into an empty one.
    #[test]
    fn update() {
        let mut ext = Extensions::default();
        assert_eq!(ext.get::<Number>(), None);

        let mut new = Extensions::default();
        assert!(!new.set(Number(10)));

        ext.update(&new);
        assert_eq!(ext.get::<Number>(), Some(&Number(10)));
    }
}
217