1 //! Varint decode utilities.
2 
3 use crate::error::WireError;
4 use crate::varint::MAX_VARINT32_ENCODED_LEN;
5 use crate::varint::MAX_VARINT_ENCODED_LEN;
6 
/// Integer types the varint decoders in this module can produce.
///
/// Decoding accumulates into a `u64` and converts to `Self` at the end via
/// [`DecodeVarint::from_u64`]. The two associated constants bound how long an
/// encoding of `Self` may be and what its last permitted byte may contain.
trait DecodeVarint {
    /// Converts the accumulated 64-bit value into `Self`.
    fn from_u64(raw: u64) -> Self;

    /// Maximum number of bytes an encoding of `Self` may span.
    const MAX_ENCODED_LEN: usize;

    /// Largest value accepted for the byte at index `MAX_ENCODED_LEN - 1`;
    /// the decoder rejects anything larger as an incorrect varint.
    const LAST_BYTE_MAX_VALUE: u8;
}
13 
14 impl DecodeVarint for u64 {
15     const MAX_ENCODED_LEN: usize = MAX_VARINT_ENCODED_LEN;
16     const LAST_BYTE_MAX_VALUE: u8 = 0x01;
17 
from_u64(value: u64) -> Self18     fn from_u64(value: u64) -> Self {
19         value
20     }
21 }
22 
23 impl DecodeVarint for u32 {
24     const MAX_ENCODED_LEN: usize = MAX_VARINT32_ENCODED_LEN;
25     const LAST_BYTE_MAX_VALUE: u8 = 0x0f;
26 
from_u64(value: u64) -> Self27     fn from_u64(value: u64) -> Self {
28         value as u32
29     }
30 }
31 
32 /// Decode a varint, and return decoded value and decoded byte count.
33 #[inline]
decode_varint_full<D: DecodeVarint>(rem: &[u8]) -> crate::Result<Option<(D, usize)>>34 fn decode_varint_full<D: DecodeVarint>(rem: &[u8]) -> crate::Result<Option<(D, usize)>> {
35     let mut r: u64 = 0;
36     for (i, &b) in rem.iter().enumerate() {
37         if i == D::MAX_ENCODED_LEN - 1 {
38             if b > D::LAST_BYTE_MAX_VALUE {
39                 return Err(WireError::IncorrectVarint.into());
40             }
41             let r = r | ((b as u64) << (i as u64 * 7));
42             return Ok(Some((D::from_u64(r), i + 1)));
43         }
44 
45         r = r | (((b & 0x7f) as u64) << (i as u64 * 7));
46         if b < 0x80 {
47             return Ok(Some((D::from_u64(r), i + 1)));
48         }
49     }
50     Ok(None)
51 }
52 
53 #[inline]
decode_varint_impl<D: DecodeVarint>(buf: &[u8]) -> crate::Result<Option<(D, usize)>>54 fn decode_varint_impl<D: DecodeVarint>(buf: &[u8]) -> crate::Result<Option<(D, usize)>> {
55     if buf.len() >= 1 && buf[0] < 0x80 {
56         // The the most common case.
57         let ret = buf[0] as u64;
58         let consume = 1;
59         Ok(Some((D::from_u64(ret), consume)))
60     } else if buf.len() >= 2 && buf[1] < 0x80 {
61         // Handle the case of two bytes too.
62         let ret = (buf[0] & 0x7f) as u64 | (buf[1] as u64) << 7;
63         let consume = 2;
64         Ok(Some((D::from_u64(ret), consume)))
65     } else {
66         // Read from array when buf at at least 10 bytes,
67         // max len for varint.
68         decode_varint_full(buf)
69     }
70 }
71 
72 /// Try decode a varint. Return `None` if the buffer does not contain complete varint.
73 #[inline]
decode_varint64(buf: &[u8]) -> crate::Result<Option<(u64, usize)>>74 pub(crate) fn decode_varint64(buf: &[u8]) -> crate::Result<Option<(u64, usize)>> {
75     decode_varint_impl(buf)
76 }
77 
78 /// Try decode a varint. Return `None` if the buffer does not contain complete varint.
79 #[inline]
decode_varint32(buf: &[u8]) -> crate::Result<Option<(u32, usize)>>80 pub(crate) fn decode_varint32(buf: &[u8]) -> crate::Result<Option<(u32, usize)>> {
81     decode_varint_impl(buf)
82 }
83 
#[cfg(test)]
mod tests {
    use crate::hex::decode_hex;
    use crate::varint::decode::decode_varint32;
    use crate::varint::decode::decode_varint64;

    #[test]
    fn test_decode_varint64() {
        assert_eq!((0, 1), decode_varint64(&decode_hex("00")).unwrap().unwrap());
        assert_eq!(
            (u64::MAX, 10),
            decode_varint64(&decode_hex("ff ff ff ff ff ff ff ff ff 01"))
                .unwrap()
                .unwrap()
        );
        assert!(decode_varint64(&decode_hex("ff ff ff ff ff ff ff ff ff 02")).is_err());
    }

    /// One- and two-byte fast paths, plus the first length that takes the
    /// general path; trailing bytes must not be consumed.
    #[test]
    fn test_decode_varint64_short_encodings() {
        assert_eq!((127, 1), decode_varint64(&decode_hex("7f ff")).unwrap().unwrap());
        assert_eq!((300, 2), decode_varint64(&decode_hex("ac 02")).unwrap().unwrap());
        assert_eq!((300, 2), decode_varint64(&decode_hex("ac 02 01")).unwrap().unwrap());
        assert_eq!((16384, 3), decode_varint64(&decode_hex("80 80 01")).unwrap().unwrap());
    }

    /// Incomplete input is not an error; it yields `Ok(None)`.
    #[test]
    fn test_decode_varint64_incomplete() {
        assert_eq!(None, decode_varint64(&[]).unwrap());
        assert_eq!(None, decode_varint64(&decode_hex("80")).unwrap());
        assert_eq!(
            None,
            decode_varint64(&decode_hex("ff ff ff ff ff ff ff ff ff")).unwrap()
        );
    }

    #[test]
    fn test_decode_varint32() {
        assert_eq!((0, 1), decode_varint32(&decode_hex("00")).unwrap().unwrap());
        assert_eq!(
            (u32::MAX, 5),
            decode_varint32(&decode_hex("ff ff ff ff 0f"))
                .unwrap()
                .unwrap()
        );
        assert!(decode_varint32(&decode_hex("ff ff ff ff 10")).is_err());
    }

    /// Truncated 32-bit input yields `Ok(None)`; encodings longer than five
    /// bytes are rejected.
    #[test]
    fn test_decode_varint32_incomplete_and_overlong() {
        assert_eq!(None, decode_varint32(&[]).unwrap());
        assert_eq!(None, decode_varint32(&decode_hex("ff ff ff ff")).unwrap());
        assert!(decode_varint32(&decode_hex("ff ff ff ff ff 01")).is_err());
    }
}
114