xref: /aosp_15_r20/external/pigweed/pw_stream/rust/pw_stream/cursor.rs (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 use core::cmp::min;
16 
17 use paste::paste;
18 use pw_status::{Error, Result};
19 use pw_varint::{VarintDecode, VarintEncode};
20 
21 use super::{Read, Seek, SeekFrom, Write};
22 
/// An in-memory stream: an <code>[AsRef]<[u8]></code> buffer paired with a
/// byte position, implementing [`Read`], [`Write`], and [`Seek`].
///
/// [`Write`] is only available when the wrapped type additionally implements
/// <code>[AsMut]<[u8]></code>.
pub struct Cursor<T: AsRef<[u8]>> {
    // The wrapped buffer.
    inner: T,
    // Current byte offset into `inner`.
    pos: usize,
}
35 
36 impl<T: AsRef<[u8]>> Cursor<T> {
37     /// Create a new Cursor wrapping `inner` with an initial position of 0.
38     ///
39     /// Semantics match [`std::io::Cursor::new()`].
new(inner: T) -> Self40     pub fn new(inner: T) -> Self {
41         Self { inner, pos: 0 }
42     }
43 
44     /// Consumes the cursor and returns the inner wrapped data.
into_inner(self) -> T45     pub fn into_inner(self) -> T {
46         self.inner
47     }
48 
49     /// Returns the number of remaining bytes in the Cursor.
remaining(&self) -> usize50     pub fn remaining(&self) -> usize {
51         self.len() - self.pos
52     }
53 
54     /// Returns the total length of the Cursor.
55     // Empty is ambiguous whether it should refer to len() or remaining() so
56     // we don't provide it.
57     #[allow(clippy::len_without_is_empty)]
len(&self) -> usize58     pub fn len(&self) -> usize {
59         self.inner.as_ref().len()
60     }
61 
62     /// Returns current IO position of the Cursor.
position(&self) -> usize63     pub fn position(&self) -> usize {
64         self.pos
65     }
66 
remaining_slice(&mut self) -> &[u8]67     fn remaining_slice(&mut self) -> &[u8] {
68         &self.inner.as_ref()[self.pos..]
69     }
70 }
71 
72 impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
remaining_mut(&mut self) -> &mut [u8]73     fn remaining_mut(&mut self) -> &mut [u8] {
74         &mut self.inner.as_mut()[self.pos..]
75     }
76 }
77 
78 // Implement `read()` as a concrete function to avoid extra monomorphization
79 // overhead.
read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize>80 fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
81     let remaining = inner.len() - *pos;
82     let read_len = min(remaining, buf.len());
83     buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
84     *pos += read_len;
85     Ok(read_len)
86 }
87 
88 impl<T: AsRef<[u8]>> Read for Cursor<T> {
read(&mut self, buf: &mut [u8]) -> Result<usize>89     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
90         read_impl(self.inner.as_ref(), &mut self.pos, buf)
91     }
92 }
93 
94 // Implement `write()` as a concrete function to avoid extra monomorphization
95 // overhead.
write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize>96 fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
97     let remaining = inner.len() - *pos;
98     let write_len = min(remaining, buf.len());
99     inner[*pos..(*pos + write_len)].copy_from_slice(&buf[0..write_len]);
100     *pos += write_len;
101     Ok(write_len)
102 }
103 
104 impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
write(&mut self, buf: &[u8]) -> Result<usize>105     fn write(&mut self, buf: &[u8]) -> Result<usize> {
106         write_impl(self.inner.as_mut(), &mut self.pos, buf)
107     }
108 
flush(&mut self) -> Result<()>109     fn flush(&mut self) -> Result<()> {
110         // Cursor does not provide any buffering so flush() is a noop.
111         Ok(())
112     }
113 }
114 
115 impl<T: AsRef<[u8]>> Seek for Cursor<T> {
seek(&mut self, pos: SeekFrom) -> Result<u64>116     fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
117         let new_pos = match pos {
118             SeekFrom::Start(pos) => pos,
119             SeekFrom::Current(pos) => (self.pos as u64)
120                 .checked_add_signed(pos)
121                 .ok_or(Error::OutOfRange)?,
122             SeekFrom::End(pos) => (self.len() as u64)
123                 .checked_add_signed(-pos)
124                 .ok_or(Error::OutOfRange)?,
125         };
126 
127         // Since Cursor operates on in memory buffers, it's limited by usize.
128         // Return an error if we are asked to seek beyond that limit.
129         let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
130 
131         if new_pos > self.len() {
132             Err(Error::OutOfRange)
133         } else {
134             self.pos = new_pos;
135             Ok(new_pos as u64)
136         }
137     }
138 
139     // Implement more efficient versions of rewind, stream_len, stream_position.
rewind(&mut self) -> Result<()>140     fn rewind(&mut self) -> Result<()> {
141         self.pos = 0;
142         Ok(())
143     }
144 
stream_len(&mut self) -> Result<u64>145     fn stream_len(&mut self) -> Result<u64> {
146         Ok(self.len() as u64)
147     }
148 
stream_position(&mut self) -> Result<u64>149     fn stream_position(&mut self) -> Result<u64> {
150         Ok(self.pos as u64)
151     }
152 }
153 
// Generates `read_<ty>_<endian>()`, which decodes one `$ty` from the bytes at
// the current position and advances the position past them.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
            const WIDTH: usize = $ty::BITS as usize / 8;
            if self.remaining() < WIDTH {
                return Err(Error::OutOfRange);
            }
            let start = self.pos;
            let end = start + WIDTH;
            let window = self
                .inner
                .as_ref()
                .get(start..end)
                .ok_or_else(|| Error::InvalidArgument)?;
            // To keep code size down, convert `window` into a fixed size
            // array infallibly rather than through `.try_into()?`.
            //
            // Safety: `window` was obtained from a checked `get()` over a
            // range that is exactly `WIDTH` bytes long.
            let raw: &[u8; WIDTH] = unsafe { &*(window.as_ptr() as *const [u8; WIDTH]) };
            let decoded = $ty::[<from_ $endian _bytes>](*raw);

            self.pos = end;
            Ok(decoded)
          }
        }
    };
}
182 
// Generates the little- and big-endian readers for both the signed and the
// unsigned integer type that is `$bits` wide.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
          cursor_read_type_impl!([<u $bits>], le);
          cursor_read_type_impl!([<u $bits>], be);
          cursor_read_type_impl!([<i $bits>], le);
          cursor_read_type_impl!([<i $bits>], be);
        }
    };
}
193 
// Generates `write_<ty>_<endian>()`, which encodes one `$ty` into the bytes
// at the current position and advances the position past them.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
          fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
            const WIDTH: usize = $ty::BITS as usize / 8;
            if self.remaining() < WIDTH {
                return Err(Error::OutOfRange);
            }
            let encoded = $ty::[<to_ $endian _bytes>](*value);
            let start = self.pos;
            let end = start + WIDTH;
            let window = self
                .inner
                .as_mut()
                .get_mut(start..end)
                .ok_or_else(|| Error::InvalidArgument)?;

            window.copy_from_slice(&encoded[..]);

            self.pos = end;
            Ok(())
          }
        }
    };
}
217 
// Generates the little- and big-endian writers for both the signed and the
// unsigned integer type that is `$bits` wide.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
          cursor_write_type_impl!([<u $bits>], le);
          cursor_write_type_impl!([<u $bits>], be);
          cursor_write_type_impl!([<i $bits>], le);
          cursor_write_type_impl!([<i $bits>], be);
        }
    };
}
228 
229 impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
230     cursor_read_bits_impl!(8);
231     cursor_read_bits_impl!(16);
232     cursor_read_bits_impl!(32);
233     cursor_read_bits_impl!(64);
234     cursor_read_bits_impl!(128);
235 }
236 
237 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
238     cursor_write_bits_impl!(8);
239     cursor_write_bits_impl!(16);
240     cursor_write_bits_impl!(32);
241     cursor_write_bits_impl!(64);
242     cursor_write_bits_impl!(128);
243 }
244 
245 impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
read_varint(&mut self) -> Result<u64>246     fn read_varint(&mut self) -> Result<u64> {
247         let (len, value) = u64::varint_decode(self.remaining_slice())?;
248         self.pos += len;
249         Ok(value)
250     }
251 
read_signed_varint(&mut self) -> Result<i64>252     fn read_signed_varint(&mut self) -> Result<i64> {
253         let (len, value) = i64::varint_decode(self.remaining_slice())?;
254         self.pos += len;
255         Ok(value)
256     }
257 }
258 
259 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
write_varint(&mut self, value: u64) -> Result<()>260     fn write_varint(&mut self, value: u64) -> Result<()> {
261         let encoded_len = value.varint_encode(self.remaining_mut())?;
262         self.pos += encoded_len;
263         Ok(())
264     }
265 
write_signed_varint(&mut self, value: i64) -> Result<()>266     fn write_signed_varint(&mut self, value: i64) -> Result<()> {
267         let encoded_len = value.varint_encode(self.remaining_mut())?;
268         self.pos += encoded_len;
269         Ok(())
270     }
271 }
272 
273 #[cfg(test)]
274 mod tests {
275     use super::*;
276     use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};
277 
278     #[test]
cursor_len_returns_total_bytes()279     fn cursor_len_returns_total_bytes() {
280         let cursor = Cursor {
281             inner: &[0u8; 64],
282             pos: 31,
283         };
284         assert_eq!(cursor.len(), 64);
285     }
286 
287     #[test]
cursor_remaining_returns_remaining_bytes()288     fn cursor_remaining_returns_remaining_bytes() {
289         let cursor = Cursor {
290             inner: &[0u8; 64],
291             pos: 31,
292         };
293         assert_eq!(cursor.remaining(), 33);
294     }
295 
296     #[test]
cursor_position_returns_current_position()297     fn cursor_position_returns_current_position() {
298         let cursor = Cursor {
299             inner: &[0u8; 64],
300             pos: 31,
301         };
302         assert_eq!(cursor.position(), 31);
303     }
304 
305     #[test]
cursor_read_of_partial_buffer_reads_correct_data()306     fn cursor_read_of_partial_buffer_reads_correct_data() {
307         let mut cursor = Cursor {
308             inner: &[1, 2, 3, 4, 5, 6, 7, 8],
309             pos: 4,
310         };
311         let mut buf = [0u8; 8];
312         assert_eq!(cursor.read(&mut buf), Ok(4));
313         assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
314     }
315 
316     #[test]
cursor_write_of_partial_buffer_writes_correct_data()317     fn cursor_write_of_partial_buffer_writes_correct_data() {
318         let mut cursor = Cursor {
319             inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
320             pos: 4,
321         };
322         let buf = [1, 2, 3, 4, 5, 6, 7, 8];
323         assert_eq!(cursor.write(&buf), Ok(4));
324         assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
325     }
326 
327     #[test]
cursor_rewind_resets_position_to_zero()328     fn cursor_rewind_resets_position_to_zero() {
329         test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
330     }
331 
332     #[test]
cursor_stream_pos_reports_correct_position()333     fn cursor_stream_pos_reports_correct_position() {
334         test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
335     }
336 
337     #[test]
cursor_stream_len_reports_correct_length()338     fn cursor_stream_len_reports_correct_length() {
339         test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
340     }
341 
342     macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
343         ($bits:literal) => {
344             paste! {
345               #[test]
346               fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
347                   let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
348                   let mut cursor = Cursor::new(&bytes);
349 
350                   assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
351                   assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
352                   assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
353                   assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
354               }
355             }
356         };
357     }
358 
359     macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
360         ($bits:literal) => {
361             paste! {
362               #[test]
363               fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
364                   let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
365                   let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
366                   cursor.[<write_i $bits _le>](&values.0).unwrap();
367                   cursor.[<write_u $bits _le>](&values.1).unwrap();
368                   cursor.[<write_i $bits _be>](&values.2).unwrap();
369                   cursor.[<write_u $bits _be>](&values.3).unwrap();
370 
371                   let result_bytes: Vec<u8> = cursor.into_inner().into();
372 
373                   assert_eq!(result_bytes, expected_bytes);
374               }
375             }
376         };
377     }
378 
integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8))379     fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
380         (
381             vec![
382                 0x0, // le i8
383                 0x1, // le u8
384                 0x2, // be i8
385                 0x3, // be u8
386             ],
387             (0, 1, 2, 3),
388         )
389     }
390 
391     cursor_read_n_bit_integers_unpacks_data_correctly!(8);
392     cursor_write_n_bit_integers_packs_data_correctly!(8);
393 
integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16))394     fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
395         (
396             vec![
397                 0x0, 0x80, // le i16
398                 0x1, 0x80, // le u16
399                 0x80, 0x2, // be i16
400                 0x80, 0x3, // be u16
401             ],
402             (
403                 i16::from_le_bytes([0x0, 0x80]),
404                 0x8001,
405                 i16::from_be_bytes([0x80, 0x2]),
406                 0x8003,
407             ),
408         )
409     }
410 
411     cursor_read_n_bit_integers_unpacks_data_correctly!(16);
412     cursor_write_n_bit_integers_packs_data_correctly!(16);
413 
integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32))414     fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
415         (
416             vec![
417                 0x0, 0x1, 0x2, 0x80, // le i32
418                 0x3, 0x4, 0x5, 0x80, // le u32
419                 0x80, 0x6, 0x7, 0x8, // be i32
420                 0x80, 0x9, 0xa, 0xb, // be u32
421             ],
422             (
423                 i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
424                 0x8005_0403,
425                 i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
426                 0x8009_0a0b,
427             ),
428         )
429     }
430 
431     cursor_read_n_bit_integers_unpacks_data_correctly!(32);
432     cursor_write_n_bit_integers_packs_data_correctly!(32);
433 
integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64))434     fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
435         (
436             vec![
437                 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
438                 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
439                 0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
440                 0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
441             ],
442             (
443                 i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
444                 0x800d_0c0b_0a09_0807,
445                 i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
446                 0x8017_1819_1a1b_1c1d,
447             ),
448         )
449     }
450 
451     cursor_read_n_bit_integers_unpacks_data_correctly!(64);
452     cursor_write_n_bit_integers_packs_data_correctly!(64);
453 
integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128))454     fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
455         #[rustfmt::skip]
456         let val = (
457             vec![
458                 // le i128
459                 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
460                 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
461                 // le u128
462                 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
463                 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
464                 // be i128
465                 0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
466                 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
467                 // be u128
468                 0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
469                 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
470             ],
471             (
472                 i128::from_le_bytes([
473                     0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
474                     0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
475                 ]),
476                 0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
477                 i128::from_be_bytes([
478                     0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
479                     0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
480                 ]),
481                 0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
482             ),
483         );
484         val
485     }
486 
487     cursor_read_n_bit_integers_unpacks_data_correctly!(128);
488     cursor_write_n_bit_integers_packs_data_correctly!(128);
489 
490     #[test]
read_varint_unpacks_data_correctly()491     pub fn read_varint_unpacks_data_correctly() {
492         let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
493         let value = cursor.read_varint().unwrap();
494         assert_eq!(value, 0xffff_fffe);
495 
496         let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
497         let value = cursor.read_varint().unwrap();
498         assert_eq!(value, 0xffff_ffff);
499     }
500 
501     #[test]
read_signed_varint_unpacks_data_correctly()502     pub fn read_signed_varint_unpacks_data_correctly() {
503         let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
504         let value = cursor.read_signed_varint().unwrap();
505         assert_eq!(value, i32::MAX.into());
506 
507         let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
508         let value = cursor.read_signed_varint().unwrap();
509         assert_eq!(value, i32::MIN.into());
510     }
511 
512     #[test]
write_varint_packs_data_correctly()513     pub fn write_varint_packs_data_correctly() {
514         let mut cursor = Cursor::new(vec![0u8; 8]);
515         cursor.write_varint(0xffff_fffe).unwrap();
516         let buf = cursor.into_inner();
517         assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
518 
519         let mut cursor = Cursor::new(vec![0u8; 8]);
520         cursor.write_varint(0xffff_ffff).unwrap();
521         let buf = cursor.into_inner();
522         assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
523     }
524 
525     #[test]
write_signed_varint_packs_data_correctly()526     pub fn write_signed_varint_packs_data_correctly() {
527         let mut cursor = Cursor::new(vec![0u8; 8]);
528         cursor.write_signed_varint(i32::MAX.into()).unwrap();
529         let buf = cursor.into_inner();
530         assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
531 
532         let mut cursor = Cursor::new(vec![0u8; 8]);
533         cursor.write_signed_varint(i32::MIN.into()).unwrap();
534         let buf = cursor.into_inner();
535         assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
536     }
537 }
538