1 use crate::value::de::{MapDeserializer, MapRefDeserializer, SeqDeserializer, SeqRefDeserializer};
2 use crate::value::Value;
3 use crate::Error;
4 use serde::de::value::{BorrowedStrDeserializer, StrDeserializer};
5 use serde::de::{
6     Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as _, VariantAccess, Visitor,
7 };
8 use serde::forward_to_deserialize_any;
9 use serde::ser::{Serialize, SerializeMap, Serializer};
10 use std::cmp::Ordering;
11 use std::fmt::{self, Debug, Display};
12 use std::hash::{Hash, Hasher};
13 use std::mem;
14 
/// A YAML tag, written `!Tag` in YAML source, which is how enums are
/// represented.
///
/// See the example on [`TaggedValue`] for how tagged values are deserialized.
#[derive(Clone)]
pub struct Tag {
    // Tag text exactly as supplied by the creator; a leading '!' may or may
    // not be present, so comparisons and formatting go through `nobang`.
    pub(crate) string: String,
}
23 
24 /// A `Tag` + `Value` representing a tagged YAML scalar, sequence, or mapping.
25 ///
26 /// ```
27 /// use serde_yaml::value::TaggedValue;
28 /// use std::collections::BTreeMap;
29 ///
30 /// let yaml = "
31 ///     scalar: !Thing x
32 ///     sequence_flow: !Thing [first]
33 ///     sequence_block: !Thing
34 ///       - first
35 ///     mapping_flow: !Thing {k: v}
36 ///     mapping_block: !Thing
37 ///       k: v
38 /// ";
39 ///
40 /// let data: BTreeMap<String, TaggedValue> = serde_yaml::from_str(yaml).unwrap();
41 /// assert!(data["scalar"].tag == "Thing");
42 /// assert!(data["sequence_flow"].tag == "Thing");
43 /// assert!(data["sequence_block"].tag == "Thing");
44 /// assert!(data["mapping_flow"].tag == "Thing");
45 /// assert!(data["mapping_block"].tag == "Thing");
46 ///
47 /// // The leading '!' in tags are not significant. The following is also true.
48 /// assert!(data["scalar"].tag == "!Thing");
49 /// ```
50 #[derive(Clone, PartialEq, PartialOrd, Hash, Debug)]
51 pub struct TaggedValue {
52     #[allow(missing_docs)]
53     pub tag: Tag,
54     #[allow(missing_docs)]
55     pub value: Value,
56 }
57 
58 impl Tag {
59     /// Create tag.
60     ///
61     /// The leading '!' is not significant. It may be provided, but does not
62     /// have to be. The following are equivalent:
63     ///
64     /// ```
65     /// use serde_yaml::value::Tag;
66     ///
67     /// assert_eq!(Tag::new("!Thing"), Tag::new("Thing"));
68     ///
69     /// let tag = Tag::new("Thing");
70     /// assert!(tag == "Thing");
71     /// assert!(tag == "!Thing");
72     /// assert!(tag.to_string() == "!Thing");
73     ///
74     /// let tag = Tag::new("!Thing");
75     /// assert!(tag == "Thing");
76     /// assert!(tag == "!Thing");
77     /// assert!(tag.to_string() == "!Thing");
78     /// ```
79     ///
80     /// Such a tag would serialize to `!Thing` in YAML regardless of whether a
81     /// '!' was included in the call to `Tag::new`.
82     ///
83     /// # Panics
84     ///
85     /// Panics if `string.is_empty()`. There is no syntax in YAML for an empty
86     /// tag.
new(string: impl Into<String>) -> Self87     pub fn new(string: impl Into<String>) -> Self {
88         let tag: String = string.into();
89         assert!(!tag.is_empty(), "empty YAML tag is not allowed");
90         Tag { string: tag }
91     }
92 }
93 
94 impl Value {
untag(self) -> Self95     pub(crate) fn untag(self) -> Self {
96         let mut cur = self;
97         while let Value::Tagged(tagged) = cur {
98             cur = tagged.value;
99         }
100         cur
101     }
102 
untag_ref(&self) -> &Self103     pub(crate) fn untag_ref(&self) -> &Self {
104         let mut cur = self;
105         while let Value::Tagged(tagged) = cur {
106             cur = &tagged.value;
107         }
108         cur
109     }
110 
untag_mut(&mut self) -> &mut Self111     pub(crate) fn untag_mut(&mut self) -> &mut Self {
112         let mut cur = self;
113         while let Value::Tagged(tagged) = cur {
114             cur = &mut tagged.value;
115         }
116         cur
117     }
118 }
119 
/// Strips a single leading '!' from `maybe_banged`.
///
/// The input is returned unchanged when it has no leading '!' or when
/// removing the '!' would leave an empty string (so `"!"` stays `"!"`).
pub(crate) fn nobang(maybe_banged: &str) -> &str {
    maybe_banged
        .strip_prefix('!')
        .filter(|rest| !rest.is_empty())
        .unwrap_or(maybe_banged)
}
126 
127 impl Eq for Tag {}
128 
129 impl PartialEq for Tag {
eq(&self, other: &Tag) -> bool130     fn eq(&self, other: &Tag) -> bool {
131         PartialEq::eq(nobang(&self.string), nobang(&other.string))
132     }
133 }
134 
135 impl<T> PartialEq<T> for Tag
136 where
137     T: ?Sized + AsRef<str>,
138 {
eq(&self, other: &T) -> bool139     fn eq(&self, other: &T) -> bool {
140         PartialEq::eq(nobang(&self.string), nobang(other.as_ref()))
141     }
142 }
143 
144 impl Ord for Tag {
cmp(&self, other: &Self) -> Ordering145     fn cmp(&self, other: &Self) -> Ordering {
146         Ord::cmp(nobang(&self.string), nobang(&other.string))
147     }
148 }
149 
150 impl PartialOrd for Tag {
partial_cmp(&self, other: &Self) -> Option<Ordering>151     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
152         Some(self.cmp(other))
153     }
154 }
155 
156 impl Hash for Tag {
hash<H: Hasher>(&self, hasher: &mut H)157     fn hash<H: Hasher>(&self, hasher: &mut H) {
158         nobang(&self.string).hash(hasher);
159     }
160 }
161 
162 impl Display for Tag {
fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result163     fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
164         write!(formatter, "!{}", nobang(&self.string))
165     }
166 }
167 
168 impl Debug for Tag {
fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result169     fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
170         Display::fmt(self, formatter)
171     }
172 }
173 
174 impl Serialize for TaggedValue {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,175     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
176     where
177         S: Serializer,
178     {
179         struct SerializeTag<'a>(&'a Tag);
180 
181         impl<'a> Serialize for SerializeTag<'a> {
182             fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
183             where
184                 S: Serializer,
185             {
186                 serializer.collect_str(self.0)
187             }
188         }
189 
190         let mut map = serializer.serialize_map(Some(1))?;
191         map.serialize_entry(&SerializeTag(&self.tag), &self.value)?;
192         map.end()
193     }
194 }
195 
196 impl<'de> Deserialize<'de> for TaggedValue {
deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: Deserializer<'de>,197     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198     where
199         D: Deserializer<'de>,
200     {
201         struct TaggedValueVisitor;
202 
203         impl<'de> Visitor<'de> for TaggedValueVisitor {
204             type Value = TaggedValue;
205 
206             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
207                 formatter.write_str("a YAML value with a !Tag")
208             }
209 
210             fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
211             where
212                 A: EnumAccess<'de>,
213             {
214                 let (tag, contents) = data.variant_seed(TagStringVisitor)?;
215                 let value = contents.newtype_variant()?;
216                 Ok(TaggedValue { tag, value })
217             }
218         }
219 
220         deserializer.deserialize_any(TaggedValueVisitor)
221     }
222 }
223 
224 impl<'de> Deserializer<'de> for TaggedValue {
225     type Error = Error;
226 
deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error> where V: Visitor<'de>,227     fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
228     where
229         V: Visitor<'de>,
230     {
231         visitor.visit_enum(self)
232     }
233 
deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error> where V: Visitor<'de>,234     fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
235     where
236         V: Visitor<'de>,
237     {
238         drop(self);
239         visitor.visit_unit()
240     }
241 
242     forward_to_deserialize_any! {
243         bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
244         byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
245         map struct enum identifier
246     }
247 }
248 
249 impl<'de> EnumAccess<'de> for TaggedValue {
250     type Error = Error;
251     type Variant = Value;
252 
variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error> where V: DeserializeSeed<'de>,253     fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
254     where
255         V: DeserializeSeed<'de>,
256     {
257         let tag = StrDeserializer::<Error>::new(nobang(&self.tag.string));
258         let value = seed.deserialize(tag)?;
259         Ok((value, self.value))
260     }
261 }
262 
263 impl<'de> VariantAccess<'de> for Value {
264     type Error = Error;
265 
unit_variant(self) -> Result<(), Error>266     fn unit_variant(self) -> Result<(), Error> {
267         Deserialize::deserialize(self)
268     }
269 
newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error> where T: DeserializeSeed<'de>,270     fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
271     where
272         T: DeserializeSeed<'de>,
273     {
274         seed.deserialize(self)
275     }
276 
tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error> where V: Visitor<'de>,277     fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
278     where
279         V: Visitor<'de>,
280     {
281         if let Value::Sequence(v) = self {
282             Deserializer::deserialize_any(SeqDeserializer::new(v), visitor)
283         } else {
284             Err(Error::invalid_type(self.unexpected(), &"tuple variant"))
285         }
286     }
287 
struct_variant<V>( self, _fields: &'static [&'static str], visitor: V, ) -> Result<V::Value, Error> where V: Visitor<'de>,288     fn struct_variant<V>(
289         self,
290         _fields: &'static [&'static str],
291         visitor: V,
292     ) -> Result<V::Value, Error>
293     where
294         V: Visitor<'de>,
295     {
296         if let Value::Mapping(v) = self {
297             Deserializer::deserialize_any(MapDeserializer::new(v), visitor)
298         } else {
299             Err(Error::invalid_type(self.unexpected(), &"struct variant"))
300         }
301     }
302 }
303 
304 impl<'de> Deserializer<'de> for &'de TaggedValue {
305     type Error = Error;
306 
deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error> where V: Visitor<'de>,307     fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
308     where
309         V: Visitor<'de>,
310     {
311         visitor.visit_enum(self)
312     }
313 
deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error> where V: Visitor<'de>,314     fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
315     where
316         V: Visitor<'de>,
317     {
318         visitor.visit_unit()
319     }
320 
321     forward_to_deserialize_any! {
322         bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
323         byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
324         map struct enum identifier
325     }
326 }
327 
328 impl<'de> EnumAccess<'de> for &'de TaggedValue {
329     type Error = Error;
330     type Variant = &'de Value;
331 
variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error> where V: DeserializeSeed<'de>,332     fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
333     where
334         V: DeserializeSeed<'de>,
335     {
336         let tag = BorrowedStrDeserializer::<Error>::new(nobang(&self.tag.string));
337         let value = seed.deserialize(tag)?;
338         Ok((value, &self.value))
339     }
340 }
341 
342 impl<'de> VariantAccess<'de> for &'de Value {
343     type Error = Error;
344 
unit_variant(self) -> Result<(), Error>345     fn unit_variant(self) -> Result<(), Error> {
346         Deserialize::deserialize(self)
347     }
348 
newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error> where T: DeserializeSeed<'de>,349     fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Error>
350     where
351         T: DeserializeSeed<'de>,
352     {
353         seed.deserialize(self)
354     }
355 
tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error> where V: Visitor<'de>,356     fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
357     where
358         V: Visitor<'de>,
359     {
360         if let Value::Sequence(v) = self {
361             Deserializer::deserialize_any(SeqRefDeserializer::new(v), visitor)
362         } else {
363             Err(Error::invalid_type(self.unexpected(), &"tuple variant"))
364         }
365     }
366 
struct_variant<V>( self, _fields: &'static [&'static str], visitor: V, ) -> Result<V::Value, Error> where V: Visitor<'de>,367     fn struct_variant<V>(
368         self,
369         _fields: &'static [&'static str],
370         visitor: V,
371     ) -> Result<V::Value, Error>
372     where
373         V: Visitor<'de>,
374     {
375         if let Value::Mapping(v) = self {
376             Deserializer::deserialize_any(MapRefDeserializer::new(v), visitor)
377         } else {
378             Err(Error::invalid_type(self.unexpected(), &"struct variant"))
379         }
380     }
381 }
382 
383 pub(crate) struct TagStringVisitor;
384 
385 impl<'de> Visitor<'de> for TagStringVisitor {
386     type Value = Tag;
387 
expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result388     fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
389         formatter.write_str("a YAML tag string")
390     }
391 
visit_str<E>(self, string: &str) -> Result<Self::Value, E> where E: serde::de::Error,392     fn visit_str<E>(self, string: &str) -> Result<Self::Value, E>
393     where
394         E: serde::de::Error,
395     {
396         self.visit_string(string.to_owned())
397     }
398 
visit_string<E>(self, string: String) -> Result<Self::Value, E> where E: serde::de::Error,399     fn visit_string<E>(self, string: String) -> Result<Self::Value, E>
400     where
401         E: serde::de::Error,
402     {
403         if string.is_empty() {
404             return Err(E::custom("empty YAML tag is not allowed"));
405         }
406         Ok(Tag::new(string))
407     }
408 }
409 
410 impl<'de> DeserializeSeed<'de> for TagStringVisitor {
411     type Value = Tag;
412 
deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error> where D: Deserializer<'de>,413     fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
414     where
415         D: Deserializer<'de>,
416     {
417         deserializer.deserialize_string(self)
418     }
419 }
420 
/// Result of inspecting a value's `Display` output for the shape of a tag.
pub(crate) enum MaybeTag<T> {
    /// The output had the shape of a tag; holds the text written after the
    /// leading `!`.
    Tag(String),
    /// The output did not have the shape of a tag; holds the caller's payload
    /// (for `check_for_tag`, the complete formatted text).
    NotTag(T),
}

/// Formats `value` with `Display` and classifies the output.
///
/// The output counts as a tag only when it arrives as exactly two non-empty
/// writes: a lone `"!"` followed by the tag text. That is the pattern produced
/// by formatting with `"!{}"`. Empty writes are ignored. In every other case
/// the complete formatted text is returned as `MaybeTag::NotTag`.
///
/// # Panics
///
/// Panics if the value's `Display` implementation returns an error.
pub(crate) fn check_for_tag<T>(value: &T) -> MaybeTag<String>
where
    T: ?Sized + Display,
{
    // State machine driven by the sequence of `write_str` calls.
    enum CheckForTag {
        // Nothing non-empty written yet.
        Empty,
        // Exactly one write so far, and it was "!".
        Bang,
        // "!" followed by exactly one more write; holds that second write.
        Tag(String),
        // Anything else; holds the full text written so far.
        NotTag(String),
    }

    impl fmt::Write for CheckForTag {
        fn write_str(&mut self, s: &str) -> fmt::Result {
            if s.is_empty() {
                return Ok(());
            }
            match self {
                CheckForTag::Empty => {
                    *self = if s == "!" {
                        CheckForTag::Bang
                    } else {
                        CheckForTag::NotTag(s.to_owned())
                    };
                }
                CheckForTag::Bang => {
                    *self = CheckForTag::Tag(s.to_owned());
                }
                CheckForTag::Tag(string) => {
                    // A third write means this is not a tag after all. The
                    // `Tag` state stored the text without the "!" that came
                    // first, so restore it; otherwise the formatted output
                    // would silently lose its first character.
                    let mut text = mem::take(string);
                    text.insert(0, '!');
                    text.push_str(s);
                    *self = CheckForTag::NotTag(text);
                }
                CheckForTag::NotTag(string) => {
                    string.push_str(s);
                }
            }
            Ok(())
        }
    }

    let mut check_for_tag = CheckForTag::Empty;
    fmt::write(&mut check_for_tag, format_args!("{}", value)).unwrap();
    match check_for_tag {
        CheckForTag::Empty => MaybeTag::NotTag(String::new()),
        CheckForTag::Bang => MaybeTag::NotTag("!".to_owned()),
        CheckForTag::Tag(string) => MaybeTag::Tag(string),
        CheckForTag::NotTag(string) => MaybeTag::NotTag(string),
    }
}
475