1 #[cfg(not(feature = "std"))]
2 use core as std;
3 
4 use crate::{Block, FixedBitSet, BYTES};
5 use alloc::vec::Vec;
6 use core::{convert::TryFrom, fmt};
7 use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor};
8 use serde::ser::{Serialize, SerializeStruct, Serializer};
9 
10 struct BitSetByteSerializer<'a>(&'a FixedBitSet);
11 
12 impl Serialize for FixedBitSet {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,13     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
14     where
15         S: Serializer,
16     {
17         let mut struct_serializer = serializer.serialize_struct("FixedBitset", 2)?;
18         struct_serializer.serialize_field("length", &(self.length as u64))?;
19         struct_serializer.serialize_field("data", &BitSetByteSerializer(self))?;
20         struct_serializer.end()
21     }
22 }
23 
24 impl<'a> Serialize for BitSetByteSerializer<'a> {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,25     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
26     where
27         S: Serializer,
28     {
29         let len = self.0.as_slice().len() * BYTES;
30         // PERF: Figure out a way to do this without allocating.
31         let mut temp = Vec::with_capacity(len);
32         for block in self.0.as_slice() {
33             temp.extend(&block.to_le_bytes());
34         }
35         serializer.serialize_bytes(&temp)
36     }
37 }
38 
39 impl<'de> Deserialize<'de> for FixedBitSet {
deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: Deserializer<'de>,40     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41     where
42         D: Deserializer<'de>,
43     {
44         enum Field {
45             Length,
46             Data,
47         }
48 
49         fn bytes_to_data(length: usize, input: &[u8]) -> Vec<Block> {
50             let block_len = length / BYTES + 1;
51             let mut data = Vec::with_capacity(block_len);
52             for chunk in input.chunks(BYTES) {
53                 match <&[u8; BYTES]>::try_from(chunk) {
54                     Ok(bytes) => data.push(usize::from_le_bytes(*bytes)),
55                     Err(_) => {
56                         let mut bytes = [0u8; BYTES];
57                         bytes[0..BYTES].copy_from_slice(chunk);
58                         data.push(usize::from_le_bytes(bytes));
59                     }
60                 }
61             }
62             data
63         }
64 
65         impl<'de> Deserialize<'de> for Field {
66             fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
67             where
68                 D: Deserializer<'de>,
69             {
70                 struct FieldVisitor;
71 
72                 impl<'de> Visitor<'de> for FieldVisitor {
73                     type Value = Field;
74 
75                     fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
76                         formatter.write_str("`length` or `data`")
77                     }
78 
79                     fn visit_str<E>(self, value: &str) -> Result<Field, E>
80                     where
81                         E: de::Error,
82                     {
83                         match value {
84                             "length" => Ok(Field::Length),
85                             "data" => Ok(Field::Data),
86                             _ => Err(de::Error::unknown_field(value, FIELDS)),
87                         }
88                     }
89                 }
90 
91                 deserializer.deserialize_identifier(FieldVisitor)
92             }
93         }
94 
95         struct FixedBitSetVisitor;
96 
97         impl<'de> Visitor<'de> for FixedBitSetVisitor {
98             type Value = FixedBitSet;
99 
100             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
101                 formatter.write_str("struct Duration")
102             }
103 
104             fn visit_seq<V>(self, mut seq: V) -> Result<FixedBitSet, V::Error>
105             where
106                 V: SeqAccess<'de>,
107             {
108                 let length = seq
109                     .next_element()?
110                     .ok_or_else(|| de::Error::invalid_length(0, &self))?;
111                 let data: &[u8] = seq
112                     .next_element()?
113                     .ok_or_else(|| de::Error::invalid_length(1, &self))?;
114                 let data = bytes_to_data(length, data);
115                 Ok(FixedBitSet::with_capacity_and_blocks(length, data))
116             }
117 
118             fn visit_map<V>(self, mut map: V) -> Result<FixedBitSet, V::Error>
119             where
120                 V: MapAccess<'de>,
121             {
122                 let mut length = None;
123                 let mut temp: Option<&[u8]> = None;
124                 while let Some(key) = map.next_key()? {
125                     match key {
126                         Field::Length => {
127                             if length.is_some() {
128                                 return Err(de::Error::duplicate_field("length"));
129                             }
130                             length = Some(map.next_value()?);
131                         }
132                         Field::Data => {
133                             if temp.is_some() {
134                                 return Err(de::Error::duplicate_field("data"));
135                             }
136                             temp = Some(map.next_value()?);
137                         }
138                     }
139                 }
140                 let length = length.ok_or_else(|| de::Error::missing_field("length"))?;
141                 let data = temp.ok_or_else(|| de::Error::missing_field("data"))?;
142                 let data = bytes_to_data(length, data);
143                 Ok(FixedBitSet::with_capacity_and_blocks(length, data))
144             }
145         }
146 
147         const FIELDS: &'static [&'static str] = &["length", "data"];
148         deserializer.deserialize_struct("Duration", FIELDS, FixedBitSetVisitor)
149     }
150 }
151