1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 use core::cmp::min;
16
17 use paste::paste;
18 use pw_status::{Error, Result};
19 use pw_varint::{VarintDecode, VarintEncode};
20
21 use super::{Read, Seek, SeekFrom, Write};
22
/// Wraps an <code>[AsRef]<[u8]></code> in a container implementing
/// [`Read`], [`Write`], and [`Seek`].
///
/// The cursor tracks a byte offset into the wrapped buffer. Reads and writes
/// begin at that offset and advance it by the number of bytes transferred.
///
/// [`Write`] support requires the inner type also implement
/// <code>[AsMut]<[u8]></code>.
pub struct Cursor<T>
where
    T: AsRef<[u8]>,
{
    /// The wrapped buffer.
    inner: T,
    /// Current IO position, as a byte offset into `inner`.
    pos: usize,
}

impl<T: AsRef<[u8]>> Cursor<T> {
    /// Create a new Cursor wrapping `inner` with an initial position of 0.
    ///
    /// Semantics match [`std::io::Cursor::new()`].
    pub fn new(inner: T) -> Self {
        Self { inner, pos: 0 }
    }

    /// Consumes the cursor and returns the inner wrapped data.
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Returns the number of remaining bytes in the Cursor.
    ///
    /// This is the distance from the current position to the end of the
    /// wrapped buffer.
    pub fn remaining(&self) -> usize {
        self.len() - self.pos
    }

    /// Returns the total length of the Cursor.
    ///
    /// This is the length of the wrapped buffer and does not depend on the
    /// current position.
    // `is_empty()` is deliberately not provided: it would be ambiguous whether
    // it refers to `len()` or `remaining()`.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.inner.as_ref().len()
    }

    /// Returns current IO position of the Cursor.
    pub fn position(&self) -> usize {
        self.pos
    }

    /// Returns the unconsumed part of the wrapped buffer, from the current
    /// position to the end.
    ///
    /// Only a shared borrow is required because nothing is mutated.
    fn remaining_slice(&self) -> &[u8] {
        &self.inner.as_ref()[self.pos..]
    }
}
71
72 impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
remaining_mut(&mut self) -> &mut [u8]73 fn remaining_mut(&mut self) -> &mut [u8] {
74 &mut self.inner.as_mut()[self.pos..]
75 }
76 }
77
78 // Implement `read()` as a concrete function to avoid extra monomorphization
79 // overhead.
read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize>80 fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
81 let remaining = inner.len() - *pos;
82 let read_len = min(remaining, buf.len());
83 buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
84 *pos += read_len;
85 Ok(read_len)
86 }
87
88 impl<T: AsRef<[u8]>> Read for Cursor<T> {
read(&mut self, buf: &mut [u8]) -> Result<usize>89 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
90 read_impl(self.inner.as_ref(), &mut self.pos, buf)
91 }
92 }
93
94 // Implement `write()` as a concrete function to avoid extra monomorphization
95 // overhead.
write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize>96 fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
97 let remaining = inner.len() - *pos;
98 let write_len = min(remaining, buf.len());
99 inner[*pos..(*pos + write_len)].copy_from_slice(&buf[0..write_len]);
100 *pos += write_len;
101 Ok(write_len)
102 }
103
104 impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
write(&mut self, buf: &[u8]) -> Result<usize>105 fn write(&mut self, buf: &[u8]) -> Result<usize> {
106 write_impl(self.inner.as_mut(), &mut self.pos, buf)
107 }
108
flush(&mut self) -> Result<()>109 fn flush(&mut self) -> Result<()> {
110 // Cursor does not provide any buffering so flush() is a noop.
111 Ok(())
112 }
113 }
114
115 impl<T: AsRef<[u8]>> Seek for Cursor<T> {
seek(&mut self, pos: SeekFrom) -> Result<u64>116 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
117 let new_pos = match pos {
118 SeekFrom::Start(pos) => pos,
119 SeekFrom::Current(pos) => (self.pos as u64)
120 .checked_add_signed(pos)
121 .ok_or(Error::OutOfRange)?,
122 SeekFrom::End(pos) => (self.len() as u64)
123 .checked_add_signed(-pos)
124 .ok_or(Error::OutOfRange)?,
125 };
126
127 // Since Cursor operates on in memory buffers, it's limited by usize.
128 // Return an error if we are asked to seek beyond that limit.
129 let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
130
131 if new_pos > self.len() {
132 Err(Error::OutOfRange)
133 } else {
134 self.pos = new_pos;
135 Ok(new_pos as u64)
136 }
137 }
138
139 // Implement more efficient versions of rewind, stream_len, stream_position.
rewind(&mut self) -> Result<()>140 fn rewind(&mut self) -> Result<()> {
141 self.pos = 0;
142 Ok(())
143 }
144
stream_len(&mut self) -> Result<u64>145 fn stream_len(&mut self) -> Result<u64> {
146 Ok(self.len() as u64)
147 }
148
stream_position(&mut self) -> Result<u64>149 fn stream_position(&mut self) -> Result<u64> {
150 Ok(self.pos as u64)
151 }
152 }
153
// Generates a `read_<ty>_<endian>` method that decodes one `$ty` from the
// cursor with `<ty>::from_<endian>_bytes`.
//
// The generated method returns `Error::OutOfRange` without moving the cursor
// when fewer than `$ty::BITS / 8` bytes remain; otherwise it advances the
// cursor past the bytes it consumed.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                let sub_slice = self
                    .inner
                    .as_ref()
                    .get(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;
                // Because we are code size conscious we want an infallible way to
                // turn `sub_slice` into a fixed sized array as opposed to using
                // something like `.try_into()?`.
                //
                // SAFETY: `sub_slice` comes from `get(pos..pos + NUM_BYTES)`, so
                // it is exactly `NUM_BYTES` bytes long. `[u8; NUM_BYTES]` has the
                // same alignment (1) as `u8`, so reinterpreting the pointer is
                // valid for the lifetime of `sub_slice`.
                let sub_array: &[u8; NUM_BYTES] =
                    unsafe { &*(sub_slice.as_ptr() as *const [u8; NUM_BYTES]) };
                let value = $ty::[<from_ $endian _bytes>](*sub_array);

                self.pos += NUM_BYTES;
                Ok(value)
            }
        }
    };
}
182
// Expands to the four `read_{i,u}<bits>_{le,be}` methods for one bit width.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
            cursor_read_type_impl!([<i $bits>], le);
            cursor_read_type_impl!([<i $bits>], be);
            cursor_read_type_impl!([<u $bits>], le);
            cursor_read_type_impl!([<u $bits>], be);
        }
    };
}
193
// Generates a `write_<ty>_<endian>` method that encodes one `$ty` with
// `<ty>::to_<endian>_bytes` and stores it at the cursor position.
//
// The generated method returns `Error::OutOfRange` without writing anything
// or moving the cursor when fewer than `$ty::BITS / 8` bytes remain;
// otherwise it advances the cursor past the bytes it wrote.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                let value_bytes = $ty::[<to_ $endian _bytes>](*value);
                let sub_slice = self
                    .inner
                    .as_mut()
                    .get_mut(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;

                sub_slice.copy_from_slice(&value_bytes[..]);

                self.pos += NUM_BYTES;
                Ok(())
            }
        }
    };
}
217
// Expands to the four `write_{i,u}<bits>_{le,be}` methods for one bit width.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
            cursor_write_type_impl!([<i $bits>], le);
            cursor_write_type_impl!([<i $bits>], be);
            cursor_write_type_impl!([<u $bits>], le);
            cursor_write_type_impl!([<u $bits>], be);
        }
    };
}
228
229 impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
230 cursor_read_bits_impl!(8);
231 cursor_read_bits_impl!(16);
232 cursor_read_bits_impl!(32);
233 cursor_read_bits_impl!(64);
234 cursor_read_bits_impl!(128);
235 }
236
237 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
238 cursor_write_bits_impl!(8);
239 cursor_write_bits_impl!(16);
240 cursor_write_bits_impl!(32);
241 cursor_write_bits_impl!(64);
242 cursor_write_bits_impl!(128);
243 }
244
245 impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
read_varint(&mut self) -> Result<u64>246 fn read_varint(&mut self) -> Result<u64> {
247 let (len, value) = u64::varint_decode(self.remaining_slice())?;
248 self.pos += len;
249 Ok(value)
250 }
251
read_signed_varint(&mut self) -> Result<i64>252 fn read_signed_varint(&mut self) -> Result<i64> {
253 let (len, value) = i64::varint_decode(self.remaining_slice())?;
254 self.pos += len;
255 Ok(value)
256 }
257 }
258
259 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
write_varint(&mut self, value: u64) -> Result<()>260 fn write_varint(&mut self, value: u64) -> Result<()> {
261 let encoded_len = value.varint_encode(self.remaining_mut())?;
262 self.pos += encoded_len;
263 Ok(())
264 }
265
write_signed_varint(&mut self, value: i64) -> Result<()>266 fn write_signed_varint(&mut self, value: i64) -> Result<()> {
267 let encoded_len = value.varint_encode(self.remaining_mut())?;
268 self.pos += encoded_len;
269 Ok(())
270 }
271 }
272
#[cfg(test)]
mod tests {
    // Unit tests for `Cursor`: the inherent accessors, the `Read`/`Write`/`Seek`
    // impls, and the fixed-width integer and varint helpers.
    use super::*;
    use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    #[test]
    fn cursor_len_returns_total_bytes() {
        // `len()` reports the whole buffer regardless of the current position.
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        // 64 total bytes minus the 31 before the current position.
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        // Only 4 bytes remain, so an 8-byte read is short and the tail of `buf`
        // keeps its initial zeros.
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        let mut buf = [0u8; 8];
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        // Only 4 bytes of space remain, so just the first 4 bytes of `buf` are
        // written and the write reports 4.
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    // The next three tests delegate to the shared `Seek` helpers imported from
    // `crate::test_utils`.
    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    // Generates a test that reads the le-signed, le-unsigned, be-signed and
    // be-unsigned integers of width `$bits` back-to-back from the bytes
    // returned by `integer_<bits>_bit_test_cases()` and checks each value.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                    let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(&bytes);

                    assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                    assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                    assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                    assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
                }
            }
        };
    }

    // Generates a test that writes the four values from
    // `integer_<bits>_bit_test_cases()` into a zeroed buffer of the same length
    // (same le/be, signed/unsigned order) and compares the resulting bytes.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                    let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                    cursor.[<write_i $bits _le>](&values.0).unwrap();
                    cursor.[<write_u $bits _le>](&values.1).unwrap();
                    cursor.[<write_i $bits _be>](&values.2).unwrap();
                    cursor.[<write_u $bits _be>](&values.3).unwrap();

                    let result_bytes: Vec<u8> = cursor.into_inner().into();

                    assert_eq!(result_bytes, expected_bytes);
                }
            }
        };
    }

    // Each `integer_<bits>_bit_test_cases()` returns `(bytes, values)`, where
    // `bytes` concatenates the le signed, le unsigned, be signed and be
    // unsigned encodings (in that order) and `values` holds the matching
    // decoded values in the same order. In the 16-bit and wider cases the most
    // significant byte is >= 0x80, so the signed values are negative.
    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // le i8
                0x1, // le u8
                0x2, // be i8
                0x3, // be u8
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // le i16
                0x1, 0x80, // le u16
                0x80, 0x2, // be i16
                0x80, 0x3, // be u16
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // le i32
                0x3, 0x4, 0x5, 0x80, // le u32
                0x80, 0x6, 0x7, 0x8, // be i32
                0x80, 0x9, 0xa, 0xb, // be u32
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        // Formatting is skipped to keep each 16-byte value laid out as two
        // rows of 8 bytes.
        #[rustfmt::skip]
        let val = (
            vec![
                // le i128
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // le u128
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // be i128
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // be u128
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    // Varint round-trip tests. The 5-byte sequences below decode (unsigned) to
    // 0xffff_fffe and 0xffff_ffff; the trailing zero bytes are left unread.
    #[test]
    pub fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
    }

    // The same byte sequences read as signed varints give i32::MAX and
    // i32::MIN, which is consistent with zig-zag encoding.
    #[test]
    pub fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
    }

    // Writing into an 8-byte zeroed buffer leaves the unused tail as zeros.
    #[test]
    pub fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    // Inverse of `read_signed_varint_unpacks_data_correctly`.
    #[test]
    pub fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}
538