1 use std::{error::Error as StdError, fmt, iter, num, str};
2 
3 use serde::{
4     de::value::BorrowedBytesDeserializer,
5     de::{
6         Deserialize, DeserializeSeed, Deserializer, EnumAccess,
7         Error as SerdeError, IntoDeserializer, MapAccess, SeqAccess,
8         Unexpected, VariantAccess, Visitor,
9     },
10     serde_if_integer128,
11 };
12 
13 use crate::{
14     byte_record::{ByteRecord, ByteRecordIter},
15     error::{Error, ErrorKind},
16     string_record::{StringRecord, StringRecordIter},
17 };
18 
19 use self::DeserializeErrorKind as DEK;
20 
deserialize_string_record<'de, D: Deserialize<'de>>( record: &'de StringRecord, headers: Option<&'de StringRecord>, ) -> Result<D, Error>21 pub fn deserialize_string_record<'de, D: Deserialize<'de>>(
22     record: &'de StringRecord,
23     headers: Option<&'de StringRecord>,
24 ) -> Result<D, Error> {
25     let mut deser = DeRecordWrap(DeStringRecord {
26         it: record.iter().peekable(),
27         headers: headers.map(|r| r.iter()),
28         field: 0,
29     });
30     D::deserialize(&mut deser).map_err(|err| {
31         Error::new(ErrorKind::Deserialize {
32             pos: record.position().map(Clone::clone),
33             err,
34         })
35     })
36 }
37 
deserialize_byte_record<'de, D: Deserialize<'de>>( record: &'de ByteRecord, headers: Option<&'de ByteRecord>, ) -> Result<D, Error>38 pub fn deserialize_byte_record<'de, D: Deserialize<'de>>(
39     record: &'de ByteRecord,
40     headers: Option<&'de ByteRecord>,
41 ) -> Result<D, Error> {
42     let mut deser = DeRecordWrap(DeByteRecord {
43         it: record.iter().peekable(),
44         headers: headers.map(|r| r.iter()),
45         field: 0,
46     });
47     D::deserialize(&mut deser).map_err(|err| {
48         Error::new(ErrorKind::Deserialize {
49             pos: record.position().map(Clone::clone),
50             err,
51         })
52     })
53 }
54 
/// An over-engineered internal trait that permits writing a single Serde
/// deserializer that works on both ByteRecord and StringRecord.
///
/// We *could* implement a single deserializer on `ByteRecord` and simply
/// convert `StringRecord`s to `ByteRecord`s, but then the implementation
/// would be required to redo UTF-8 validation checks in certain places.
///
/// How does this work? `DeRecord` is implemented for `DeStringRecord`
/// (which reads from a `StringRecord`) and for `DeByteRecord` (which reads
/// from a `ByteRecord`). The `DeRecordWrap<T>` newtype forwards every
/// `DeRecord` method to its inner `T`. Finally, `serde::Deserializer` (along
/// with `SeqAccess`, `MapAccess`, `EnumAccess` and `VariantAccess`) is
/// implemented once, for `&mut DeRecordWrap<T>` where `T: DeRecord`. That
/// is, the `DeRecord` trait captures exactly the differences between
/// deserializing from a `ByteRecord` and deserializing from a
/// `StringRecord`.
///
/// The lifetime `'r` refers to the lifetime of the underlying record.
trait DeRecord<'r> {
    /// Returns true if and only if this deserializer has access to headers.
    fn has_headers(&self) -> bool;

    /// Extracts the next string header value from the underlying record.
    ///
    /// Yields `Ok(None)` when there are no headers or they are exhausted.
    fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError>;

    /// Extracts the next raw byte header value from the underlying record.
    ///
    /// Yields `Ok(None)` when there are no headers or they are exhausted.
    fn next_header_bytes(
        &mut self,
    ) -> Result<Option<&'r [u8]>, DeserializeError>;

    /// Extracts the next string field from the underlying record.
    ///
    /// Fails with `UnexpectedEndOfRow` when no fields remain; byte-backed
    /// implementations may also fail with `InvalidUtf8`.
    fn next_field(&mut self) -> Result<&'r str, DeserializeError>;

    /// Extracts the next raw byte field from the underlying record.
    ///
    /// Fails with `UnexpectedEndOfRow` when no fields remain.
    fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError>;

    /// Peeks at the next field from the underlying record without consuming
    /// it. Returns `None` at the end of the row.
    fn peek_field(&mut self) -> Option<&'r [u8]>;

    /// Returns an error corresponding to the most recently extracted field.
    fn error(&self, kind: DeserializeErrorKind) -> DeserializeError;

    /// Infer the type of the next field and deserialize it.
    ///
    /// The implementations in this file try, in order: `true`/`false`,
    /// `u64`, `i64`, (when 128-bit integers are available) `u128`, `i128`,
    /// then `f64`, and finally fall back to a string (or to raw bytes for a
    /// byte field that is not valid UTF-8).
    fn infer_deserialize<'de, V: Visitor<'de>>(
        &mut self,
        visitor: V,
    ) -> Result<V::Value, DeserializeError>;
}
101 
/// Newtype around a `DeRecord` implementation (`DeStringRecord` or
/// `DeByteRecord`). The Serde traits (`Deserializer`, `SeqAccess`, ...) are
/// implemented for `&mut DeRecordWrap<T>`.
struct DeRecordWrap<T>(T);
103 
104 impl<'r, T: DeRecord<'r>> DeRecord<'r> for DeRecordWrap<T> {
105     #[inline]
has_headers(&self) -> bool106     fn has_headers(&self) -> bool {
107         self.0.has_headers()
108     }
109 
110     #[inline]
next_header(&mut self) -> Result<Option<&'r str>, DeserializeError>111     fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
112         self.0.next_header()
113     }
114 
115     #[inline]
next_header_bytes( &mut self, ) -> Result<Option<&'r [u8]>, DeserializeError>116     fn next_header_bytes(
117         &mut self,
118     ) -> Result<Option<&'r [u8]>, DeserializeError> {
119         self.0.next_header_bytes()
120     }
121 
122     #[inline]
next_field(&mut self) -> Result<&'r str, DeserializeError>123     fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
124         self.0.next_field()
125     }
126 
127     #[inline]
next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError>128     fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
129         self.0.next_field_bytes()
130     }
131 
132     #[inline]
peek_field(&mut self) -> Option<&'r [u8]>133     fn peek_field(&mut self) -> Option<&'r [u8]> {
134         self.0.peek_field()
135     }
136 
137     #[inline]
error(&self, kind: DeserializeErrorKind) -> DeserializeError138     fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
139         self.0.error(kind)
140     }
141 
142     #[inline]
infer_deserialize<'de, V: Visitor<'de>>( &mut self, visitor: V, ) -> Result<V::Value, DeserializeError>143     fn infer_deserialize<'de, V: Visitor<'de>>(
144         &mut self,
145         visitor: V,
146     ) -> Result<V::Value, DeserializeError> {
147         self.0.infer_deserialize(visitor)
148     }
149 }
150 
/// `DeRecord` implementation backed by a `StringRecord`. Its iterators yield
/// `&str` fields directly, so no UTF-8 validation is performed here.
struct DeStringRecord<'r> {
    /// Peekable iterator over the record's fields.
    it: iter::Peekable<StringRecordIter<'r>>,
    /// Iterator over the header row, if headers were supplied.
    headers: Option<StringRecordIter<'r>>,
    /// Number of fields consumed so far; `error` reports `field - 1`.
    field: u64,
}
156 
157 impl<'r> DeRecord<'r> for DeStringRecord<'r> {
158     #[inline]
has_headers(&self) -> bool159     fn has_headers(&self) -> bool {
160         self.headers.is_some()
161     }
162 
163     #[inline]
next_header(&mut self) -> Result<Option<&'r str>, DeserializeError>164     fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
165         Ok(self.headers.as_mut().and_then(|it| it.next()))
166     }
167 
168     #[inline]
next_header_bytes( &mut self, ) -> Result<Option<&'r [u8]>, DeserializeError>169     fn next_header_bytes(
170         &mut self,
171     ) -> Result<Option<&'r [u8]>, DeserializeError> {
172         Ok(self.next_header()?.map(|s| s.as_bytes()))
173     }
174 
175     #[inline]
next_field(&mut self) -> Result<&'r str, DeserializeError>176     fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
177         match self.it.next() {
178             Some(field) => {
179                 self.field += 1;
180                 Ok(field)
181             }
182             None => Err(DeserializeError {
183                 field: None,
184                 kind: DEK::UnexpectedEndOfRow,
185             }),
186         }
187     }
188 
189     #[inline]
next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError>190     fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
191         self.next_field().map(|s| s.as_bytes())
192     }
193 
194     #[inline]
peek_field(&mut self) -> Option<&'r [u8]>195     fn peek_field(&mut self) -> Option<&'r [u8]> {
196         self.it.peek().map(|s| s.as_bytes())
197     }
198 
error(&self, kind: DeserializeErrorKind) -> DeserializeError199     fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
200         DeserializeError { field: Some(self.field.saturating_sub(1)), kind }
201     }
202 
infer_deserialize<'de, V: Visitor<'de>>( &mut self, visitor: V, ) -> Result<V::Value, DeserializeError>203     fn infer_deserialize<'de, V: Visitor<'de>>(
204         &mut self,
205         visitor: V,
206     ) -> Result<V::Value, DeserializeError> {
207         let x = self.next_field()?;
208         if x == "true" {
209             return visitor.visit_bool(true);
210         } else if x == "false" {
211             return visitor.visit_bool(false);
212         } else if let Some(n) = try_positive_integer64(x) {
213             return visitor.visit_u64(n);
214         } else if let Some(n) = try_negative_integer64(x) {
215             return visitor.visit_i64(n);
216         }
217         serde_if_integer128! {
218             if let Some(n) = try_positive_integer128(x) {
219                 return visitor.visit_u128(n);
220             } else if let Some(n) = try_negative_integer128(x) {
221                 return visitor.visit_i128(n);
222             }
223         }
224         if let Some(n) = try_float(x) {
225             visitor.visit_f64(n)
226         } else {
227             visitor.visit_str(x)
228         }
229     }
230 }
231 
/// `DeRecord` implementation backed by a `ByteRecord`. Fields are raw bytes;
/// UTF-8 validation happens only when a string is requested.
struct DeByteRecord<'r> {
    /// Peekable iterator over the record's fields.
    it: iter::Peekable<ByteRecordIter<'r>>,
    /// Iterator over the header row, if headers were supplied.
    headers: Option<ByteRecordIter<'r>>,
    /// Number of fields consumed so far; `error` reports `field - 1`.
    field: u64,
}
237 
238 impl<'r> DeRecord<'r> for DeByteRecord<'r> {
239     #[inline]
has_headers(&self) -> bool240     fn has_headers(&self) -> bool {
241         self.headers.is_some()
242     }
243 
244     #[inline]
next_header(&mut self) -> Result<Option<&'r str>, DeserializeError>245     fn next_header(&mut self) -> Result<Option<&'r str>, DeserializeError> {
246         match self.next_header_bytes() {
247             Ok(Some(field)) => Ok(Some(
248                 str::from_utf8(field)
249                     .map_err(|err| self.error(DEK::InvalidUtf8(err)))?,
250             )),
251             Ok(None) => Ok(None),
252             Err(err) => Err(err),
253         }
254     }
255 
256     #[inline]
next_header_bytes( &mut self, ) -> Result<Option<&'r [u8]>, DeserializeError>257     fn next_header_bytes(
258         &mut self,
259     ) -> Result<Option<&'r [u8]>, DeserializeError> {
260         Ok(self.headers.as_mut().and_then(|it| it.next()))
261     }
262 
263     #[inline]
next_field(&mut self) -> Result<&'r str, DeserializeError>264     fn next_field(&mut self) -> Result<&'r str, DeserializeError> {
265         self.next_field_bytes().and_then(|field| {
266             str::from_utf8(field)
267                 .map_err(|err| self.error(DEK::InvalidUtf8(err)))
268         })
269     }
270 
271     #[inline]
next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError>272     fn next_field_bytes(&mut self) -> Result<&'r [u8], DeserializeError> {
273         match self.it.next() {
274             Some(field) => {
275                 self.field += 1;
276                 Ok(field)
277             }
278             None => Err(DeserializeError {
279                 field: None,
280                 kind: DEK::UnexpectedEndOfRow,
281             }),
282         }
283     }
284 
285     #[inline]
peek_field(&mut self) -> Option<&'r [u8]>286     fn peek_field(&mut self) -> Option<&'r [u8]> {
287         self.it.peek().map(|s| *s)
288     }
289 
error(&self, kind: DeserializeErrorKind) -> DeserializeError290     fn error(&self, kind: DeserializeErrorKind) -> DeserializeError {
291         DeserializeError { field: Some(self.field.saturating_sub(1)), kind }
292     }
293 
infer_deserialize<'de, V: Visitor<'de>>( &mut self, visitor: V, ) -> Result<V::Value, DeserializeError>294     fn infer_deserialize<'de, V: Visitor<'de>>(
295         &mut self,
296         visitor: V,
297     ) -> Result<V::Value, DeserializeError> {
298         let x = self.next_field_bytes()?;
299         if x == b"true" {
300             return visitor.visit_bool(true);
301         } else if x == b"false" {
302             return visitor.visit_bool(false);
303         } else if let Some(n) = try_positive_integer64_bytes(x) {
304             return visitor.visit_u64(n);
305         } else if let Some(n) = try_negative_integer64_bytes(x) {
306             return visitor.visit_i64(n);
307         }
308         serde_if_integer128! {
309             if let Some(n) = try_positive_integer128_bytes(x) {
310                 return visitor.visit_u128(n);
311             } else if let Some(n) = try_negative_integer128_bytes(x) {
312                 return visitor.visit_i128(n);
313             }
314         }
315         if let Some(n) = try_float_bytes(x) {
316             visitor.visit_f64(n)
317         } else if let Ok(s) = str::from_utf8(x) {
318             visitor.visit_str(s)
319         } else {
320             visitor.visit_bytes(x)
321         }
322     }
323 }
324 
/// Generates a `Deserializer` method named `$method`. The method parses the
/// next field as `$inttype` and passes the result to `visitor.$visit`.
///
/// A field starting with `0x` is parsed as hexadecimal; anything else goes
/// through the integer type's `FromStr` implementation.
macro_rules! deserialize_int {
    ($method:ident, $visit:ident, $inttype:ty) => {
        fn $method<V: Visitor<'de>>(
            self,
            visitor: V,
        ) -> Result<V::Value, Self::Error> {
            let field = self.next_field()?;
            let parsed = if field.starts_with("0x") {
                // `0x` is ASCII, so byte offset 2 is a char boundary.
                <$inttype>::from_str_radix(&field[2..], 16)
            } else {
                field.parse()
            };
            let value =
                parsed.map_err(|err| self.error(DEK::ParseInt(err)))?;
            visitor.$visit(value)
        }
    };
}
341 
342 impl<'a, 'de: 'a, T: DeRecord<'de>> Deserializer<'de>
343     for &'a mut DeRecordWrap<T>
344 {
345     type Error = DeserializeError;
346 
deserialize_any<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>347     fn deserialize_any<V: Visitor<'de>>(
348         self,
349         visitor: V,
350     ) -> Result<V::Value, Self::Error> {
351         self.infer_deserialize(visitor)
352     }
353 
deserialize_bool<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>354     fn deserialize_bool<V: Visitor<'de>>(
355         self,
356         visitor: V,
357     ) -> Result<V::Value, Self::Error> {
358         visitor.visit_bool(
359             self.next_field()?
360                 .parse()
361                 .map_err(|err| self.error(DEK::ParseBool(err)))?,
362         )
363     }
364 
365     deserialize_int!(deserialize_u8, visit_u8, u8);
366     deserialize_int!(deserialize_u16, visit_u16, u16);
367     deserialize_int!(deserialize_u32, visit_u32, u32);
368     deserialize_int!(deserialize_u64, visit_u64, u64);
369     serde_if_integer128! {
370         deserialize_int!(deserialize_u128, visit_u128, u128);
371     }
372     deserialize_int!(deserialize_i8, visit_i8, i8);
373     deserialize_int!(deserialize_i16, visit_i16, i16);
374     deserialize_int!(deserialize_i32, visit_i32, i32);
375     deserialize_int!(deserialize_i64, visit_i64, i64);
376     serde_if_integer128! {
377         deserialize_int!(deserialize_i128, visit_i128, i128);
378     }
379 
deserialize_f32<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>380     fn deserialize_f32<V: Visitor<'de>>(
381         self,
382         visitor: V,
383     ) -> Result<V::Value, Self::Error> {
384         visitor.visit_f32(
385             self.next_field()?
386                 .parse()
387                 .map_err(|err| self.error(DEK::ParseFloat(err)))?,
388         )
389     }
390 
deserialize_f64<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>391     fn deserialize_f64<V: Visitor<'de>>(
392         self,
393         visitor: V,
394     ) -> Result<V::Value, Self::Error> {
395         visitor.visit_f64(
396             self.next_field()?
397                 .parse()
398                 .map_err(|err| self.error(DEK::ParseFloat(err)))?,
399         )
400     }
401 
deserialize_char<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>402     fn deserialize_char<V: Visitor<'de>>(
403         self,
404         visitor: V,
405     ) -> Result<V::Value, Self::Error> {
406         let field = self.next_field()?;
407         let len = field.chars().count();
408         if len != 1 {
409             return Err(self.error(DEK::Message(format!(
410                 "expected single character but got {} characters in '{}'",
411                 len, field
412             ))));
413         }
414         visitor.visit_char(field.chars().next().unwrap())
415     }
416 
deserialize_str<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>417     fn deserialize_str<V: Visitor<'de>>(
418         self,
419         visitor: V,
420     ) -> Result<V::Value, Self::Error> {
421         self.next_field().and_then(|f| visitor.visit_borrowed_str(f))
422     }
423 
deserialize_string<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>424     fn deserialize_string<V: Visitor<'de>>(
425         self,
426         visitor: V,
427     ) -> Result<V::Value, Self::Error> {
428         self.next_field().and_then(|f| visitor.visit_str(f.into()))
429     }
430 
deserialize_bytes<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>431     fn deserialize_bytes<V: Visitor<'de>>(
432         self,
433         visitor: V,
434     ) -> Result<V::Value, Self::Error> {
435         self.next_field_bytes().and_then(|f| visitor.visit_borrowed_bytes(f))
436     }
437 
deserialize_byte_buf<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>438     fn deserialize_byte_buf<V: Visitor<'de>>(
439         self,
440         visitor: V,
441     ) -> Result<V::Value, Self::Error> {
442         self.next_field_bytes()
443             .and_then(|f| visitor.visit_byte_buf(f.to_vec()))
444     }
445 
deserialize_option<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>446     fn deserialize_option<V: Visitor<'de>>(
447         self,
448         visitor: V,
449     ) -> Result<V::Value, Self::Error> {
450         match self.peek_field() {
451             None => visitor.visit_none(),
452             Some(f) if f.is_empty() => {
453                 self.next_field().expect("empty field");
454                 visitor.visit_none()
455             }
456             Some(_) => visitor.visit_some(self),
457         }
458     }
459 
deserialize_unit<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>460     fn deserialize_unit<V: Visitor<'de>>(
461         self,
462         visitor: V,
463     ) -> Result<V::Value, Self::Error> {
464         visitor.visit_unit()
465     }
466 
deserialize_unit_struct<V: Visitor<'de>>( self, _name: &'static str, visitor: V, ) -> Result<V::Value, Self::Error>467     fn deserialize_unit_struct<V: Visitor<'de>>(
468         self,
469         _name: &'static str,
470         visitor: V,
471     ) -> Result<V::Value, Self::Error> {
472         visitor.visit_unit()
473     }
474 
deserialize_newtype_struct<V: Visitor<'de>>( self, _name: &'static str, visitor: V, ) -> Result<V::Value, Self::Error>475     fn deserialize_newtype_struct<V: Visitor<'de>>(
476         self,
477         _name: &'static str,
478         visitor: V,
479     ) -> Result<V::Value, Self::Error> {
480         visitor.visit_newtype_struct(self)
481     }
482 
deserialize_seq<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>483     fn deserialize_seq<V: Visitor<'de>>(
484         self,
485         visitor: V,
486     ) -> Result<V::Value, Self::Error> {
487         visitor.visit_seq(self)
488     }
489 
deserialize_tuple<V: Visitor<'de>>( self, _len: usize, visitor: V, ) -> Result<V::Value, Self::Error>490     fn deserialize_tuple<V: Visitor<'de>>(
491         self,
492         _len: usize,
493         visitor: V,
494     ) -> Result<V::Value, Self::Error> {
495         visitor.visit_seq(self)
496     }
497 
deserialize_tuple_struct<V: Visitor<'de>>( self, _name: &'static str, _len: usize, visitor: V, ) -> Result<V::Value, Self::Error>498     fn deserialize_tuple_struct<V: Visitor<'de>>(
499         self,
500         _name: &'static str,
501         _len: usize,
502         visitor: V,
503     ) -> Result<V::Value, Self::Error> {
504         visitor.visit_seq(self)
505     }
506 
deserialize_map<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>507     fn deserialize_map<V: Visitor<'de>>(
508         self,
509         visitor: V,
510     ) -> Result<V::Value, Self::Error> {
511         if !self.has_headers() {
512             visitor.visit_seq(self)
513         } else {
514             visitor.visit_map(self)
515         }
516     }
517 
deserialize_struct<V: Visitor<'de>>( self, _name: &'static str, _fields: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error>518     fn deserialize_struct<V: Visitor<'de>>(
519         self,
520         _name: &'static str,
521         _fields: &'static [&'static str],
522         visitor: V,
523     ) -> Result<V::Value, Self::Error> {
524         if !self.has_headers() {
525             visitor.visit_seq(self)
526         } else {
527             visitor.visit_map(self)
528         }
529     }
530 
deserialize_identifier<V: Visitor<'de>>( self, _visitor: V, ) -> Result<V::Value, Self::Error>531     fn deserialize_identifier<V: Visitor<'de>>(
532         self,
533         _visitor: V,
534     ) -> Result<V::Value, Self::Error> {
535         Err(self.error(DEK::Unsupported("deserialize_identifier".into())))
536     }
537 
deserialize_enum<V: Visitor<'de>>( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error>538     fn deserialize_enum<V: Visitor<'de>>(
539         self,
540         _name: &'static str,
541         _variants: &'static [&'static str],
542         visitor: V,
543     ) -> Result<V::Value, Self::Error> {
544         visitor.visit_enum(self)
545     }
546 
deserialize_ignored_any<V: Visitor<'de>>( self, visitor: V, ) -> Result<V::Value, Self::Error>547     fn deserialize_ignored_any<V: Visitor<'de>>(
548         self,
549         visitor: V,
550     ) -> Result<V::Value, Self::Error> {
551         // Read and drop the next field.
552         // This code is reached, e.g., when trying to deserialize a header
553         // that doesn't exist in the destination struct.
554         let _ = self.next_field_bytes()?;
555         visitor.visit_unit()
556     }
557 }
558 
559 impl<'a, 'de: 'a, T: DeRecord<'de>> EnumAccess<'de>
560     for &'a mut DeRecordWrap<T>
561 {
562     type Error = DeserializeError;
563     type Variant = Self;
564 
variant_seed<V: DeserializeSeed<'de>>( self, seed: V, ) -> Result<(V::Value, Self::Variant), Self::Error>565     fn variant_seed<V: DeserializeSeed<'de>>(
566         self,
567         seed: V,
568     ) -> Result<(V::Value, Self::Variant), Self::Error> {
569         let variant_name = self.next_field()?;
570         seed.deserialize(variant_name.into_deserializer()).map(|v| (v, self))
571     }
572 }
573 
574 impl<'a, 'de: 'a, T: DeRecord<'de>> VariantAccess<'de>
575     for &'a mut DeRecordWrap<T>
576 {
577     type Error = DeserializeError;
578 
unit_variant(self) -> Result<(), Self::Error>579     fn unit_variant(self) -> Result<(), Self::Error> {
580         Ok(())
581     }
582 
newtype_variant_seed<U: DeserializeSeed<'de>>( self, _seed: U, ) -> Result<U::Value, Self::Error>583     fn newtype_variant_seed<U: DeserializeSeed<'de>>(
584         self,
585         _seed: U,
586     ) -> Result<U::Value, Self::Error> {
587         let unexp = Unexpected::UnitVariant;
588         Err(DeserializeError::invalid_type(unexp, &"newtype variant"))
589     }
590 
tuple_variant<V: Visitor<'de>>( self, _len: usize, _visitor: V, ) -> Result<V::Value, Self::Error>591     fn tuple_variant<V: Visitor<'de>>(
592         self,
593         _len: usize,
594         _visitor: V,
595     ) -> Result<V::Value, Self::Error> {
596         let unexp = Unexpected::UnitVariant;
597         Err(DeserializeError::invalid_type(unexp, &"tuple variant"))
598     }
599 
struct_variant<V: Visitor<'de>>( self, _fields: &'static [&'static str], _visitor: V, ) -> Result<V::Value, Self::Error>600     fn struct_variant<V: Visitor<'de>>(
601         self,
602         _fields: &'static [&'static str],
603         _visitor: V,
604     ) -> Result<V::Value, Self::Error> {
605         let unexp = Unexpected::UnitVariant;
606         Err(DeserializeError::invalid_type(unexp, &"struct variant"))
607     }
608 }
609 
610 impl<'a, 'de: 'a, T: DeRecord<'de>> SeqAccess<'de>
611     for &'a mut DeRecordWrap<T>
612 {
613     type Error = DeserializeError;
614 
next_element_seed<U: DeserializeSeed<'de>>( &mut self, seed: U, ) -> Result<Option<U::Value>, Self::Error>615     fn next_element_seed<U: DeserializeSeed<'de>>(
616         &mut self,
617         seed: U,
618     ) -> Result<Option<U::Value>, Self::Error> {
619         if self.peek_field().is_none() {
620             Ok(None)
621         } else {
622             seed.deserialize(&mut **self).map(Some)
623         }
624     }
625 }
626 
627 impl<'a, 'de: 'a, T: DeRecord<'de>> MapAccess<'de>
628     for &'a mut DeRecordWrap<T>
629 {
630     type Error = DeserializeError;
631 
next_key_seed<K: DeserializeSeed<'de>>( &mut self, seed: K, ) -> Result<Option<K::Value>, Self::Error>632     fn next_key_seed<K: DeserializeSeed<'de>>(
633         &mut self,
634         seed: K,
635     ) -> Result<Option<K::Value>, Self::Error> {
636         assert!(self.has_headers());
637         let field = match self.next_header_bytes()? {
638             None => return Ok(None),
639             Some(field) => field,
640         };
641         seed.deserialize(BorrowedBytesDeserializer::new(field)).map(Some)
642     }
643 
next_value_seed<K: DeserializeSeed<'de>>( &mut self, seed: K, ) -> Result<K::Value, Self::Error>644     fn next_value_seed<K: DeserializeSeed<'de>>(
645         &mut self,
646         seed: K,
647     ) -> Result<K::Value, Self::Error> {
648         seed.deserialize(&mut **self)
649     }
650 }
651 
/// A Serde deserialization error.
///
/// Pairs the kind of failure with the index of the field that caused it,
/// when that index is known.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeserializeError {
    /// Zero-based index of the offending field, if known.
    field: Option<u64>,
    /// What went wrong.
    kind: DeserializeErrorKind,
}
658 
/// The type of a Serde deserialization error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DeserializeErrorKind {
    /// A generic Serde deserialization error (e.g. one raised through
    /// `serde::de::Error::custom`).
    Message(String),
    /// A generic Serde unsupported error; the payload names the deserializer
    /// method that is not supported (e.g. `deserialize_identifier`).
    Unsupported(String),
    /// This error occurs when a Rust type expects to decode another field
    /// from a row, but no more fields exist.
    UnexpectedEndOfRow,
    /// This error occurs when UTF-8 validation on a field fails. UTF-8
    /// validation is only performed when the Rust type requires it (e.g.,
    /// a `String` or `&str` type).
    InvalidUtf8(str::Utf8Error),
    /// This error occurs when a boolean value fails to parse.
    ParseBool(str::ParseBoolError),
    /// This error occurs when an integer value fails to parse.
    ParseInt(num::ParseIntError),
    /// This error occurs when a float value fails to parse.
    ParseFloat(num::ParseFloatError),
}
680 
681 impl SerdeError for DeserializeError {
custom<T: fmt::Display>(msg: T) -> DeserializeError682     fn custom<T: fmt::Display>(msg: T) -> DeserializeError {
683         DeserializeError { field: None, kind: DEK::Message(msg.to_string()) }
684     }
685 }
686 
687 impl StdError for DeserializeError {
description(&self) -> &str688     fn description(&self) -> &str {
689         self.kind.description()
690     }
691 }
692 
693 impl fmt::Display for DeserializeError {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result694     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
695         if let Some(field) = self.field {
696             write!(f, "field {}: {}", field, self.kind)
697         } else {
698             write!(f, "{}", self.kind)
699         }
700     }
701 }
702 
703 impl fmt::Display for DeserializeErrorKind {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result704     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
705         use self::DeserializeErrorKind::*;
706 
707         match *self {
708             Message(ref msg) => write!(f, "{}", msg),
709             Unsupported(ref which) => {
710                 write!(f, "unsupported deserializer method: {}", which)
711             }
712             UnexpectedEndOfRow => write!(f, "{}", self.description()),
713             InvalidUtf8(ref err) => err.fmt(f),
714             ParseBool(ref err) => err.fmt(f),
715             ParseInt(ref err) => err.fmt(f),
716             ParseFloat(ref err) => err.fmt(f),
717         }
718     }
719 }
720 
impl DeserializeError {
    /// Return the field index (starting at 0) of this error, if available.
    ///
    /// This is `None` for errors not tied to a particular field, such as
    /// running out of fields in a row or messages created through
    /// `serde::de::Error::custom`.
    pub fn field(&self) -> Option<u64> {
        self.field
    }

    /// Return the underlying error kind.
    pub fn kind(&self) -> &DeserializeErrorKind {
        &self.kind
    }
}
732 
733 impl DeserializeErrorKind {
734     #[allow(deprecated)]
description(&self) -> &str735     fn description(&self) -> &str {
736         use self::DeserializeErrorKind::*;
737 
738         match *self {
739             Message(_) => "deserialization error",
740             Unsupported(_) => "unsupported deserializer method",
741             UnexpectedEndOfRow => "expected field, but got end of row",
742             InvalidUtf8(ref err) => err.description(),
743             ParseBool(ref err) => err.description(),
744             ParseInt(ref err) => err.description(),
745             ParseFloat(ref err) => err.description(),
746         }
747     }
748 }
749 
750 serde_if_integer128! {
751     fn try_positive_integer128(s: &str) -> Option<u128> {
752         s.parse().ok()
753     }
754 
755     fn try_negative_integer128(s: &str) -> Option<i128> {
756         s.parse().ok()
757     }
758 }
759 
/// Parses `s` as a `u64` (zero or positive), yielding `None` otherwise.
fn try_positive_integer64(s: &str) -> Option<u64> {
    s.parse::<u64>().ok()
}
763 
/// Parses `s` as an `i64`, yielding `None` otherwise.
fn try_negative_integer64(s: &str) -> Option<i64> {
    s.parse::<i64>().ok()
}
767 
/// Parses `s` as an `f64`, yielding `None` otherwise.
fn try_float(s: &str) -> Option<f64> {
    s.parse::<f64>().ok()
}
771 
/// Parses raw bytes as a `u64`; non-UTF-8 input or a failed parse gives
/// `None`.
fn try_positive_integer64_bytes(s: &[u8]) -> Option<u64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
775 
/// Parses raw bytes as an `i64`; non-UTF-8 input or a failed parse gives
/// `None`.
fn try_negative_integer64_bytes(s: &[u8]) -> Option<i64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
779 
780 serde_if_integer128! {
781     fn try_positive_integer128_bytes(s: &[u8]) -> Option<u128> {
782         str::from_utf8(s).ok().and_then(|s| s.parse().ok())
783     }
784 
785     fn try_negative_integer128_bytes(s: &[u8]) -> Option<i128> {
786         str::from_utf8(s).ok().and_then(|s| s.parse().ok())
787     }
788 }
789 
/// Parses raw bytes as an `f64`; non-UTF-8 input or a failed parse gives
/// `None`.
fn try_float_bytes(s: &[u8]) -> Option<f64> {
    let text = str::from_utf8(s).ok()?;
    text.parse().ok()
}
793 
#[cfg(test)]
mod tests {
    // Tests for deserializing `StringRecord`s and `ByteRecord`s through
    // Serde, both with and without a header row.
    use std::collections::HashMap;

    use {
        bstr::BString,
        serde::{de::DeserializeOwned, serde_if_integer128, Deserialize},
    };

    use crate::{
        byte_record::ByteRecord, error::Error, string_record::StringRecord,
    };

    use super::{deserialize_byte_record, deserialize_string_record};

    /// Deserializes `fields` as a single headerless `StringRecord`, so
    /// values are consumed positionally.
    fn de<D: DeserializeOwned>(fields: &[&str]) -> Result<D, Error> {
        let record = StringRecord::from(fields);
        deserialize_string_record(&record, None)
    }

    /// Deserializes `fields` as a `StringRecord` with `headers` as the
    /// header row, so struct fields and map keys are matched by name.
    fn de_headers<D: DeserializeOwned>(
        headers: &[&str],
        fields: &[&str],
    ) -> Result<D, Error> {
        let headers = StringRecord::from(headers);
        let record = StringRecord::from(fields);
        deserialize_string_record(&record, Some(&headers))
    }

    /// Views any byte-like value as `&[u8]`. This lets byte-string literals
    /// of different lengths (`&[u8; N]`) share one `Vec` element type.
    fn b<'a, T: AsRef<[u8]> + ?Sized>(bytes: &'a T) -> &'a [u8] {
        bytes.as_ref()
    }

    // Header names, not column order, decide which field gets which value.
    #[test]
    fn with_header() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }

        let got: Foo =
            de_headers(&["x", "y", "z"], &["hi", "42", "1.3"]).unwrap();
        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
    }

    // `deny_unknown_fields` rejects the extra `a` column.
    #[test]
    fn with_header_unknown() {
        #[derive(Deserialize, Debug, PartialEq)]
        #[serde(deny_unknown_fields)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }
        assert!(de_headers::<Foo>(
            &["a", "x", "y", "z"],
            &["foo", "hi", "42", "1.3"],
        )
        .is_err());
    }

    // A required field (`x`) with no matching header is an error.
    #[test]
    fn with_header_missing() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }
        assert!(de_headers::<Foo>(&["y", "z"], &["42", "1.3"]).is_err());
    }

    // ...but an `Option` field with no matching header becomes `None`.
    #[test]
    fn with_header_missing_ok() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: Option<String>,
        }

        let got: Foo = de_headers(&["y", "z"], &["42", "1.3"]).unwrap();
        assert_eq!(got, Foo { x: None, y: 42, z: 1.3 });
    }

    // Headers with an empty record cannot fill the required fields.
    #[test]
    fn with_header_no_fields() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: Option<String>,
        }

        let got = de_headers::<Foo>(&["y", "z"], &[]);
        assert!(got.is_err());
    }

    // Empty headers and an empty record cannot fill required fields either.
    #[test]
    fn with_header_empty() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: Option<String>,
        }

        let got = de_headers::<Foo>(&[], &[]);
        assert!(got.is_err());
    }

    // Types without fields (unit struct, empty struct, `()`) accept an empty
    // header row and an empty record.
    #[test]
    fn with_header_empty_ok() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo;

        #[derive(Deserialize, Debug, PartialEq)]
        struct Bar {}

        let got = de_headers::<Foo>(&[], &[]);
        assert_eq!(got.unwrap(), Foo);

        let got = de_headers::<Bar>(&[], &[]);
        assert_eq!(got.unwrap(), Bar {});

        // `()` has a single value, so a successful parse is the entire
        // check (comparing against `()` trips clippy's `unit_cmp` lint).
        de_headers::<()>(&[], &[]).unwrap();
    }

    // Without headers, struct fields are filled in declaration order.
    #[test]
    fn without_header() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            z: f64,
            y: i32,
            x: String,
        }

        let got: Foo = de(&["1.3", "42", "hi"]).unwrap();
        assert_eq!(got, Foo { x: "hi".into(), y: 42, z: 1.3 });
    }

    // A scalar needs one field; an empty record is an error.
    #[test]
    fn no_fields() {
        assert!(de::<String>(&[]).is_err());
    }

    #[test]
    fn one_field() {
        let got: i32 = de(&["42"]).unwrap();
        assert_eq!(got, 42);
    }

    serde_if_integer128! {
        // Values beyond the 64-bit range still fit in `i128`.
        #[test]
        fn one_field_128() {
            let got: i128 = de(&["2010223372036854775808"]).unwrap();
            assert_eq!(got, 2010223372036854775808);
        }
    }

    // Tuples and tuple structs consume fields positionally.
    #[test]
    fn two_fields() {
        let got: (i32, bool) = de(&["42", "true"]).unwrap();
        assert_eq!(got, (42, true));

        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo(i32, bool);

        let got: Foo = de(&["42", "true"]).unwrap();
        assert_eq!(got, Foo(42, true));
    }

    // Trailing fields beyond what the tuple needs are ignored.
    #[test]
    fn two_fields_too_many() {
        let got: (i32, bool) = de(&["42", "true", "z", "z"]).unwrap();
        assert_eq!(got, (42, true));
    }

    // Too few fields for the tuple is an error.
    #[test]
    fn two_fields_too_few() {
        assert!(de::<(i32, bool)>(&["42"]).is_err());
    }

    // A `char` field must contain exactly one character.
    #[test]
    fn one_char() {
        let got: char = de(&["a"]).unwrap();
        assert_eq!(got, 'a');
    }

    #[test]
    fn no_chars() {
        assert!(de::<char>(&[""]).is_err());
    }

    #[test]
    fn too_many_chars() {
        assert!(de::<char>(&["ab"]).is_err());
    }

    #[test]
    fn simple_seq() {
        let got: Vec<i32> = de(&["1", "5", "10"]).unwrap();
        assert_eq!(got, vec![1, 5, 10]);
    }

    // Integer fields with a `0x` prefix are parsed as hexadecimal.
    #[test]
    fn simple_hex_seq() {
        let got: Vec<i32> = de(&["0x7F", "0xA9", "0x10"]).unwrap();
        assert_eq!(got, vec![0x7F, 0xA9, 0x10]);
    }

    // Hex and decimal fields can be mixed within one record.
    #[test]
    fn mixed_hex_seq() {
        let got: Vec<i32> = de(&["0x7F", "0xA9", "10"]).unwrap();
        assert_eq!(got, vec![0x7F, 0xA9, 10]);
    }

    // `7F` lacks the `0x` prefix, so the record fails to deserialize.
    #[test]
    fn bad_hex_seq() {
        assert!(de::<Vec<u8>>(&["7F", "0xA9", "10"]).is_err());
    }

    // Without headers, a sequence field consumes the remaining fields.
    #[test]
    fn seq_in_struct() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            xs: Vec<i32>,
        }
        let got: Foo = de(&["1", "5", "10"]).unwrap();
        assert_eq!(got, Foo { xs: vec![1, 5, 10] });
    }

    #[test]
    fn seq_in_struct_tail() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            label: String,
            xs: Vec<i32>,
        }
        let got: Foo = de(&["foo", "1", "5", "10"]).unwrap();
        assert_eq!(got, Foo { label: "foo".into(), xs: vec![1, 5, 10] });
    }

    // With headers, a map uses header names as keys.
    #[test]
    fn map_headers() {
        let got: HashMap<String, i32> =
            de_headers(&["a", "b", "c"], &["1", "5", "10"]).unwrap();
        assert_eq!(got.len(), 3);
        assert_eq!(got["a"], 1);
        assert_eq!(got["b"], 5);
        assert_eq!(got["c"], 10);
    }

    // Without headers there are no keys, so a map is an error.
    #[test]
    fn map_no_headers() {
        let got = de::<HashMap<String, i32>>(&["1", "5", "10"]);
        assert!(got.is_err());
    }

    #[test]
    fn bytes() {
        let got: Vec<u8> = de::<BString>(&["foobar"]).unwrap().into();
        assert_eq!(got, b"foobar".to_vec());
    }

    // Fixed-size arrays consume exactly their length, so two can sit side
    // by side in one record.
    #[test]
    fn adjacent_fixed_arrays() {
        let got: ([u32; 2], [u32; 2]) = de(&["1", "5", "10", "15"]).unwrap();
        assert_eq!(got, ([1, 5], [10, 15]));
    }

    // A unit-only enum field is selected by the (renamed) variant name.
    #[test]
    fn enum_label_simple_tagged() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Row {
            label: Label,
            x: f64,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        #[serde(rename_all = "snake_case")]
        enum Label {
            Foo,
            Bar,
            Baz,
        }

        let got: Row = de_headers(&["label", "x"], &["bar", "5"]).unwrap();
        assert_eq!(got, Row { label: Label::Bar, x: 5.0 });
    }

    // Untagged enums try variants in order. `null` is neither a bool nor an
    // integer, so it falls through to `String`.
    #[test]
    fn enum_untagged() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Row {
            x: Boolish,
            y: Boolish,
            z: Boolish,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        #[serde(rename_all = "snake_case")]
        #[serde(untagged)]
        enum Boolish {
            Bool(bool),
            Number(i64),
            String(String),
        }

        let got: Row =
            de_headers(&["x", "y", "z"], &["true", "null", "1"]).unwrap();
        assert_eq!(
            got,
            Row {
                x: Boolish::Bool(true),
                y: Boolish::String("null".into()),
                z: Boolish::Number(1),
            }
        );
    }

    // An empty field deserializes into an `Option` as `None`.
    #[test]
    fn option_empty_field() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            a: Option<i32>,
            b: String,
            c: Option<i32>,
        }

        let got: Foo =
            de_headers(&["a", "b", "c"], &["", "foo", "5"]).unwrap();
        assert_eq!(got, Foo { a: None, b: "foo".into(), c: Some(5) });
    }

    // With `crate::invalid_option`, unparseable and empty fields both
    // become `None` instead of failing.
    #[test]
    fn option_invalid_field() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo {
            #[serde(deserialize_with = "crate::invalid_option")]
            a: Option<i32>,
            #[serde(deserialize_with = "crate::invalid_option")]
            b: Option<i32>,
            #[serde(deserialize_with = "crate::invalid_option")]
            c: Option<i32>,
        }

        let got: Foo =
            de_headers(&["a", "b", "c"], &["xyz", "", "5"]).unwrap();
        assert_eq!(got, Foo { a: None, b: None, c: Some(5) });
    }

    // Fields may borrow `&str` directly from the record.
    #[test]
    fn borrowed() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Foo<'a, 'c> {
            a: &'a str,
            b: i32,
            c: &'c str,
        }

        let headers = StringRecord::from(vec!["a", "b", "c"]);
        let record = StringRecord::from(vec!["foo", "5", "bar"]);
        let got: Foo =
            deserialize_string_record(&record, Some(&headers)).unwrap();
        assert_eq!(got, Foo { a: "foo", b: 5, c: "bar" });
    }

    // Map keys and values may borrow from the headers and the record.
    #[test]
    fn borrowed_map() {
        let headers = StringRecord::from(vec!["a", "b", "c"]);
        let record = StringRecord::from(vec!["aardvark", "bee", "cat"]);
        let got: HashMap<&str, &str> =
            deserialize_string_record(&record, Some(&headers)).unwrap();

        let expected: HashMap<&str, &str> =
            headers.iter().zip(&record).collect();
        assert_eq!(got, expected);
    }

    // Byte records allow borrowed keys that are not valid UTF-8.
    #[test]
    fn borrowed_map_bytes() {
        let headers = ByteRecord::from(vec![b"a", b"\xFF", b"c"]);
        let record = ByteRecord::from(vec!["aardvark", "bee", "cat"]);
        let got: HashMap<&[u8], &[u8]> =
            deserialize_byte_record(&record, Some(&headers)).unwrap();

        let expected: HashMap<&[u8], &[u8]> =
            headers.iter().zip(&record).collect();
        assert_eq!(got, expected);
    }

    // `#[serde(flatten)]` structs pick up their columns by header name.
    #[test]
    fn flatten() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Input {
            x: f64,
            y: f64,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        struct Properties {
            prop1: f64,
            prop2: f64,
        }

        #[derive(Deserialize, Debug, PartialEq)]
        struct Row {
            #[serde(flatten)]
            input: Input,
            #[serde(flatten)]
            properties: Properties,
        }

        let header = StringRecord::from(vec!["x", "y", "prop1", "prop2"]);
        let record = StringRecord::from(vec!["1", "2", "3", "4"]);
        let got: Row = record.deserialize(Some(&header)).unwrap();
        assert_eq!(
            got,
            Row {
                input: Input { x: 1.0, y: 2.0 },
                properties: Properties { prop1: 3.0, prop2: 4.0 },
            }
        );
    }

    // A record still deserializes when its non-UTF-8 field goes into a
    // `BString`; the valid UTF-8 fields go into `String`s.
    #[test]
    fn partially_invalid_utf8() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Row {
            h1: String,
            h2: BString,
            h3: String,
        }

        let headers = ByteRecord::from(vec![b"h1", b"h2", b"h3"]);
        let record =
            ByteRecord::from(vec![b(b"baz"), b(b"foo\xFFbar"), b(b"quux")]);
        let got: Row =
            deserialize_byte_record(&record, Some(&headers)).unwrap();
        assert_eq!(
            got,
            Row {
                h1: "baz".to_string(),
                h2: BString::from(b"foo\xFFbar".to_vec()),
                h3: "quux".to_string(),
            }
        );
    }
}
1251