1 use crate::{InsertError, MatchError, Params};
2 
3 use std::cell::UnsafeCell;
4 use std::cmp::min;
5 use std::mem;
6 
/// The types of nodes the tree can hold
///
/// The derived ordering follows the declaration order of the variants, so the variants
/// must not be reordered without checking every comparison that relies on it.
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Copy)]
pub(crate) enum NodeType {
    /// The root path
    Root,
    /// A route parameter, ex: `/:id`.
    Param,
    /// A catchall parameter, ex: `/*file`
    CatchAll,
    /// Anything else
    Static,
}
19 
/// A radix tree used for URL path matching.
///
/// See [the crate documentation](crate) for details.
pub struct Node<T> {
    // incremented by `insert` on every node along an inserted route; children are kept
    // ordered by it (see `update_child_priority`)
    priority: u32,
    // whether this node has a wildcard (param or catch-all) child; when set, that child is
    // the last entry of `children`
    wild_child: bool,
    // first prefix byte of each static child, in the same order as `children`
    indices: Vec<u8>,
    // see `at` for why an unsafe cell is needed
    value: Option<UnsafeCell<T>>,
    // original parameter names of the route whose value is stored at this node,
    // set by `insert`
    pub(crate) param_remapping: ParamRemapping,
    // the kind of segment this node represents
    pub(crate) node_type: NodeType,
    // the bytes of the route covered by this node
    pub(crate) prefix: Vec<u8>,
    // static children first, followed by the wildcard child, if any
    pub(crate) children: Vec<Self>,
}
34 
// SAFETY: we expose `value` per rust's usual borrowing rules, so we can just delegate these traits
// to `T`. The explicit impls are needed because `value` is wrapped in an `UnsafeCell`, which is
// never `Sync` on its own.
// NOTE(review): this relies on every caller of `at` only handing out `&mut T` when it holds
// `&mut Node<T>` — those callers live outside this file, confirm when touching them.
unsafe impl<T: Send> Send for Node<T> {}
unsafe impl<T: Sync> Sync for Node<T> {}
38 
39 impl<T> Node<T> {
insert(&mut self, route: impl Into<String>, val: T) -> Result<(), InsertError>40     pub fn insert(&mut self, route: impl Into<String>, val: T) -> Result<(), InsertError> {
41         let route = route.into().into_bytes();
42         let (route, param_remapping) = normalize_params(route)?;
43         let mut prefix = route.as_ref();
44 
45         self.priority += 1;
46 
47         // the tree is empty
48         if self.prefix.is_empty() && self.children.is_empty() {
49             let last = self.insert_child(prefix, &route, val)?;
50             last.param_remapping = param_remapping;
51             self.node_type = NodeType::Root;
52             return Ok(());
53         }
54 
55         let mut current = self;
56 
57         'walk: loop {
58             // find the longest common prefix
59             let len = min(prefix.len(), current.prefix.len());
60             let common_prefix = (0..len)
61                 .find(|&i| prefix[i] != current.prefix[i])
62                 .unwrap_or(len);
63 
64             // the common prefix is a substring of the current node's prefix, split the node
65             if common_prefix < current.prefix.len() {
66                 let child = Node {
67                     prefix: current.prefix[common_prefix..].to_owned(),
68                     children: mem::take(&mut current.children),
69                     wild_child: current.wild_child,
70                     indices: current.indices.clone(),
71                     value: current.value.take(),
72                     param_remapping: mem::take(&mut current.param_remapping),
73                     priority: current.priority - 1,
74                     ..Node::default()
75                 };
76 
77                 // the current node now holds only the common prefix
78                 current.children = vec![child];
79                 current.indices = vec![current.prefix[common_prefix]];
80                 current.prefix = prefix[..common_prefix].to_owned();
81                 current.wild_child = false;
82             }
83 
84             // the route has a common prefix, search deeper
85             if prefix.len() > common_prefix {
86                 prefix = &prefix[common_prefix..];
87 
88                 let next = prefix[0];
89 
90                 // `/` after param
91                 if current.node_type == NodeType::Param
92                     && next == b'/'
93                     && current.children.len() == 1
94                 {
95                     current = &mut current.children[0];
96                     current.priority += 1;
97 
98                     continue 'walk;
99                 }
100 
101                 // find a child that matches the next path byte
102                 for mut i in 0..current.indices.len() {
103                     // found a match
104                     if next == current.indices[i] {
105                         i = current.update_child_priority(i);
106                         current = &mut current.children[i];
107                         continue 'walk;
108                     }
109                 }
110 
111                 // not a wildcard and there is no matching child node, create a new one
112                 if !matches!(next, b':' | b'*') && current.node_type != NodeType::CatchAll {
113                     current.indices.push(next);
114                     let mut child = current.add_child(Node::default());
115                     child = current.update_child_priority(child);
116 
117                     // insert into the new node
118                     let last = current.children[child].insert_child(prefix, &route, val)?;
119                     last.param_remapping = param_remapping;
120                     return Ok(());
121                 }
122 
123                 // inserting a wildcard, and this node already has a wildcard child
124                 if current.wild_child {
125                     // wildcards are always at the end
126                     current = current.children.last_mut().unwrap();
127                     current.priority += 1;
128 
129                     // make sure the wildcard matches
130                     if prefix.len() < current.prefix.len()
131                         || current.prefix != prefix[..current.prefix.len()]
132                         // catch-alls cannot have children
133                         || current.node_type == NodeType::CatchAll
134                         // check for longer wildcard, e.g. :name and :names
135                         || (current.prefix.len() < prefix.len()
136                             && prefix[current.prefix.len()] != b'/')
137                     {
138                         return Err(InsertError::conflict(&route, prefix, current));
139                     }
140 
141                     continue 'walk;
142                 }
143 
144                 // otherwise, create the wildcard node
145                 let last = current.insert_child(prefix, &route, val)?;
146                 last.param_remapping = param_remapping;
147                 return Ok(());
148             }
149 
150             // exact match, this node should be empty
151             if current.value.is_some() {
152                 return Err(InsertError::conflict(&route, prefix, current));
153             }
154 
155             // add the value to current node
156             current.value = Some(UnsafeCell::new(val));
157             current.param_remapping = param_remapping;
158 
159             return Ok(());
160         }
161     }
162 
163     // add a child node, keeping wildcards at the end
add_child(&mut self, child: Node<T>) -> usize164     fn add_child(&mut self, child: Node<T>) -> usize {
165         let len = self.children.len();
166 
167         if self.wild_child && len > 0 {
168             self.children.insert(len - 1, child);
169             len - 1
170         } else {
171             self.children.push(child);
172             len
173         }
174     }
175 
176     // increments priority of the given child and reorders if necessary.
177     //
178     // returns the new index of the child
update_child_priority(&mut self, i: usize) -> usize179     fn update_child_priority(&mut self, i: usize) -> usize {
180         self.children[i].priority += 1;
181         let priority = self.children[i].priority;
182 
183         // adjust position (move to front)
184         let mut updated = i;
185         while updated > 0 && self.children[updated - 1].priority < priority {
186             // swap node positions
187             self.children.swap(updated - 1, updated);
188             updated -= 1;
189         }
190 
191         // build new index list
192         if updated != i {
193             self.indices = [
194                 &self.indices[..updated],  // unchanged prefix, might be empty
195                 &self.indices[i..=i],      // the index char we move
196                 &self.indices[updated..i], // rest without char at 'pos'
197                 &self.indices[i + 1..],
198             ]
199             .concat();
200         }
201 
202         updated
203     }
204 
    // insert a child node at this node
    //
    // `prefix` is the part of the route that still has to be stored, `route` is the full
    // (normalized) route and is only used to validate catch-alls. One node is created per
    // wildcard segment found in `prefix`, with the static text between wildcards stored as
    // node prefixes.
    //
    // returns the node that received `val`, or an error if `find_wildcard` rejects the
    // prefix or a catch-all is placed incorrectly
    fn insert_child(
        &mut self,
        mut prefix: &[u8],
        route: &[u8],
        val: T,
    ) -> Result<&mut Node<T>, InsertError> {
        let mut current = self;

        loop {
            // search for a wildcard segment
            let (wildcard, wildcard_index) = match find_wildcard(prefix)? {
                Some((w, i)) => (w, i),
                // no wildcard, simply use the current node
                None => {
                    current.value = Some(UnsafeCell::new(val));
                    current.prefix = prefix.to_owned();
                    return Ok(current);
                }
            };

            // regular route parameter
            if wildcard[0] == b':' {
                // insert prefix before the current wildcard
                if wildcard_index > 0 {
                    current.prefix = prefix[..wildcard_index].to_owned();
                    prefix = &prefix[wildcard_index..];
                }

                let child = Self {
                    node_type: NodeType::Param,
                    prefix: wildcard.to_owned(),
                    ..Self::default()
                };

                // descend into the new parameter node
                let child = current.add_child(child);
                current.wild_child = true;
                current = &mut current.children[child];
                current.priority += 1;

                // if the route doesn't end with the wildcard, then there
                // will be another non-wildcard subroute starting with '/'
                if wildcard.len() < prefix.len() {
                    prefix = &prefix[wildcard.len()..];
                    let child = Self {
                        priority: 1,
                        ..Self::default()
                    };

                    // the rest of the route is handled by the next loop iteration,
                    // starting at this fresh child
                    let child = current.add_child(child);
                    current = &mut current.children[child];
                    continue;
                }

                // otherwise we're done. Insert the value in the new leaf
                current.value = Some(UnsafeCell::new(val));
                return Ok(current);

            // catch-all route
            } else if wildcard[0] == b'*' {
                // "/foo/*x/bar"
                if wildcard_index + wildcard.len() != prefix.len() {
                    return Err(InsertError::InvalidCatchAll);
                }

                if let Some(i) = wildcard_index.checked_sub(1) {
                    // "/foo/bar*x"
                    if prefix[i] != b'/' {
                        return Err(InsertError::InvalidCatchAll);
                    }
                }

                // "*x" without leading `/`
                if prefix == route && route[0] != b'/' {
                    return Err(InsertError::InvalidCatchAll);
                }

                // insert prefix before the current wildcard
                if wildcard_index > 0 {
                    current.prefix = prefix[..wildcard_index].to_owned();
                    prefix = &prefix[wildcard_index..];
                }

                // the catch-all node holds the value itself, it never gets children
                let child = Self {
                    prefix: prefix.to_owned(),
                    node_type: NodeType::CatchAll,
                    value: Some(UnsafeCell::new(val)),
                    priority: 1,
                    ..Self::default()
                };

                let i = current.add_child(child);
                current.wild_child = true;

                return Ok(&mut current.children[i]);
            }
        }
    }
303 }
304 
// A node with a wildcard child that `at` passed over in favor of a static child, recorded so
// the search can come back to it if the static branch does not match.
struct Skipped<'n, 'p, T> {
    // the remaining path at the time the node was skipped, including the node's own prefix
    path: &'p [u8],
    // the node that was skipped
    node: &'n Node<T>,
    // number of parameters captured at that point, used to drop later captures
    params: usize,
}
310 
// Defines a local `try_backtrack!` macro bound to the caller's search state.
//
// `try_backtrack!()` pops skipped nodes until it finds one whose recorded path ends with
// the current `$path`; it then restores the path, node and parameter count recorded for that
// node, sets `$backtracking` and continues the `$walk` loop. When no skipped node qualifies
// it falls through, leaving the state untouched apart from the drained stack.
#[rustfmt::skip]
macro_rules! backtracker {
    ($skipped_nodes:ident, $path:ident, $current:ident, $params:ident, $backtracking:ident, $walk:lifetime) => {
        macro_rules! try_backtrack {
            () => {
                // try backtracking to any matching wildcard nodes we skipped while traversing
                // the tree
                while let Some(skipped) = $skipped_nodes.pop() {
                    if skipped.path.ends_with($path) {
                        $path = skipped.path;
                        // `node` already is a reference with the tree's lifetime, copy it as is
                        $current = skipped.node;
                        $params.truncate(skipped.params);
                        $backtracking = true;
                        continue $walk;
                    }
                }
            };
        }
    };
}
331 
332 impl<T> Node<T> {
    /// Looks up `full_path` in the tree, returning the cell holding the matched value along
    /// with the parameters captured on the way.
    ///
    /// Static children are tried before wildcard children. Whenever a static child is chosen
    /// on a node that also has a wildcard child, the node is remembered in `skipped_nodes` so
    /// the search can backtrack to the wildcard if the static branch fails.
    ///
    /// # Errors
    ///
    /// Returns `MatchError::ExtraTrailingSlash` or `MatchError::MissingTrailingSlash` when the
    /// checks below find that the path only differs from a route by a trailing `/`,
    /// `MatchError::unsure(full_path)` for the one case marked with a TODO, and
    /// `MatchError::NotFound` otherwise.
    // it's a bit sad that we have to introduce unsafe here but rust doesn't really have a way
    // to abstract over mutability, so `UnsafeCell` lets us avoid having to duplicate logic between
    // `at` and `at_mut`
    pub fn at<'n, 'p>(
        &'n self,
        full_path: &'p [u8],
    ) -> Result<(&'n UnsafeCell<T>, Params<'n, 'p>), MatchError> {
        let mut current = self;
        let mut path = full_path;
        // set once we returned to a skipped node, its static children were already tried
        let mut backtracking = false;
        let mut params = Params::new();
        let mut skipped_nodes = Vec::new();

        'walk: loop {
            // defines `try_backtrack!` for the state above
            backtracker!(skipped_nodes, path, current, params, backtracking, 'walk);

            // the path is longer than this node's prefix, we are expecting a child node
            if path.len() > current.prefix.len() {
                let (prefix, rest) = path.split_at(current.prefix.len());

                // the prefix matches
                if prefix == current.prefix {
                    let first = rest[0];
                    let consumed = path;
                    path = rest;

                    // try searching for a matching static child unless we are currently
                    // backtracking, which would mean we already traversed them
                    if !backtracking {
                        if let Some(i) = current.indices.iter().position(|&c| c == first) {
                            // keep track of wildcard routes we skipped to backtrack to later if
                            // we don't find a match
                            if current.wild_child {
                                skipped_nodes.push(Skipped {
                                    path: consumed,
                                    node: current,
                                    params: params.len(),
                                });
                            }

                            // child won't match because of an extra trailing slash
                            if path == b"/"
                                && current.children[i].prefix != b"/"
                                && current.value.is_some()
                            {
                                return Err(MatchError::ExtraTrailingSlash);
                            }

                            // continue with the child node
                            current = &current.children[i];
                            continue 'walk;
                        }
                    }

                    // we didn't find a match and there are no children with wildcards, there is no match
                    if !current.wild_child {
                        // extra trailing slash
                        if path == b"/" && current.value.is_some() {
                            return Err(MatchError::ExtraTrailingSlash);
                        }

                        // try backtracking
                        if path != b"/" {
                            try_backtrack!();
                        }

                        // nothing found
                        return Err(MatchError::NotFound);
                    }

                    // handle the wildcard child, which is always at the end of the list
                    current = current.children.last().unwrap();

                    match current.node_type {
                        NodeType::Param => {
                            // check if there are more segments in the path other than this parameter
                            match path.iter().position(|&c| c == b'/') {
                                Some(i) => {
                                    let (param, rest) = path.split_at(i);

                                    if let [child] = current.children.as_slice() {
                                        // child won't match because of an extra trailing slash
                                        if rest == b"/"
                                            && child.prefix != b"/"
                                            && current.value.is_some()
                                        {
                                            return Err(MatchError::ExtraTrailingSlash);
                                        }

                                        // store the parameter value
                                        params.push(&current.prefix[1..], param);

                                        // continue with the child node
                                        path = rest;
                                        current = child;
                                        backtracking = false;
                                        continue 'walk;
                                    }

                                    // this node has no children yet the path has more segments...
                                    // either the path has an extra trailing slash or there is no match
                                    if path.len() == i + 1 {
                                        return Err(MatchError::ExtraTrailingSlash);
                                    }

                                    // try backtracking
                                    if path != b"/" {
                                        try_backtrack!();
                                    }

                                    return Err(MatchError::NotFound);
                                }
                                // this is the last path segment
                                None => {
                                    // store the parameter value
                                    params.push(&current.prefix[1..], path);

                                    // found the matching value
                                    if let Some(ref value) = current.value {
                                        // remap parameter keys
                                        params.for_each_key_mut(|(i, key)| {
                                            *key = &current.param_remapping[i][1..]
                                        });

                                        return Ok((value, params));
                                    }

                                    // check the child node in case the path is missing a trailing slash
                                    if let [child] = current.children.as_slice() {
                                        current = child;

                                        if (current.prefix == b"/" && current.value.is_some())
                                            || (current.prefix.is_empty()
                                                && current.indices == b"/")
                                        {
                                            return Err(MatchError::MissingTrailingSlash);
                                        }

                                        // no match, try backtracking
                                        if path != b"/" {
                                            try_backtrack!();
                                        }
                                    }

                                    // this node doesn't have the value, no match
                                    return Err(MatchError::NotFound);
                                }
                            }
                        }
                        NodeType::CatchAll => {
                            // catch all segments are only allowed at the end of the route,
                            // either this node has the value or there is no match
                            return match current.value {
                                Some(ref value) => {
                                    // remap parameter keys
                                    params.for_each_key_mut(|(i, key)| {
                                        *key = &current.param_remapping[i][1..]
                                    });

                                    // store the final catch-all parameter
                                    params.push(&current.prefix[1..], path);

                                    Ok((value, params))
                                }
                                None => Err(MatchError::NotFound),
                            };
                        }
                        // the wildcard child is created by `insert_child` as either a param or
                        // a catch-all node
                        _ => unreachable!(),
                    }
                }
            }

            // this is it, we should have reached the node containing the value
            if path == current.prefix {
                if let Some(ref value) = current.value {
                    // remap parameter keys
                    params.for_each_key_mut(|(i, key)| *key = &current.param_remapping[i][1..]);
                    return Ok((value, params));
                }

                // nope, try backtracking
                if path != b"/" {
                    try_backtrack!();
                }

                // TODO: does this *always* means there is an extra trailing slash?
                if path == b"/" && current.wild_child && current.node_type != NodeType::Root {
                    return Err(MatchError::unsure(full_path));
                }

                if !backtracking {
                    // check if the path is missing a trailing slash
                    if let Some(i) = current.indices.iter().position(|&c| c == b'/') {
                        current = &current.children[i];

                        if current.prefix.len() == 1 && current.value.is_some() {
                            return Err(MatchError::MissingTrailingSlash);
                        }
                    }
                }

                return Err(MatchError::NotFound);
            }

            // nothing matches, check for a missing trailing slash
            if current.prefix.split_last() == Some((&b'/', path)) && current.value.is_some() {
                return Err(MatchError::MissingTrailingSlash);
            }

            // last chance, try backtracking
            if path != b"/" {
                try_backtrack!();
            }

            return Err(MatchError::NotFound);
        }
    }
550 
551     #[cfg(feature = "__test_helpers")]
check_priorities(&self) -> Result<u32, (u32, u32)>552     pub fn check_priorities(&self) -> Result<u32, (u32, u32)> {
553         let mut priority: u32 = 0;
554         for child in &self.children {
555             priority += child.check_priorities()?;
556         }
557 
558         if self.value.is_some() {
559             priority += 1;
560         }
561 
562         if self.priority != priority {
563             return Err((self.priority, priority));
564         }
565 
566         Ok(priority)
567     }
568 }
569 
/// An ordered list of route parameters keys for a specific route, stored at leaf nodes.
///
/// Each entry is the parameter exactly as it was written in the route, including the leading
/// `:` (lookups strip that first byte before handing the key out). Catch-all parameters are
/// not normalized and therefore have no entry.
type ParamRemapping = Vec<Vec<u8>>;
572 
573 /// Returns `path` with normalized route parameters, and a parameter remapping
574 /// to store at the leaf node for this route.
normalize_params(mut path: Vec<u8>) -> Result<(Vec<u8>, ParamRemapping), InsertError>575 fn normalize_params(mut path: Vec<u8>) -> Result<(Vec<u8>, ParamRemapping), InsertError> {
576     let mut start = 0;
577     let mut original = ParamRemapping::new();
578 
579     // parameter names are normalized alphabetically
580     let mut next = b'a';
581 
582     loop {
583         let (wildcard, mut wildcard_index) = match find_wildcard(&path[start..])? {
584             Some((w, i)) => (w, i),
585             None => return Ok((path, original)),
586         };
587 
588         // makes sure the param has a valid name
589         if wildcard.len() < 2 {
590             return Err(InsertError::UnnamedParam);
591         }
592 
593         // don't need to normalize catch-all parameters
594         if wildcard[0] == b'*' {
595             start += wildcard_index + wildcard.len();
596             continue;
597         }
598 
599         wildcard_index += start;
600 
601         // normalize the parameter
602         let removed = path.splice(
603             (wildcard_index)..(wildcard_index + wildcard.len()),
604             vec![b':', next],
605         );
606 
607         // remember the original name for remappings
608         original.push(removed.collect());
609 
610         // get the next key
611         next += 1;
612         if next > b'z' {
613             panic!("too many route parameters");
614         }
615 
616         start = wildcard_index + 2;
617     }
618 }
619 
620 /// Restores `route` to it's original, denormalized form.
denormalize_params(route: &mut Vec<u8>, params: &ParamRemapping)621 pub(crate) fn denormalize_params(route: &mut Vec<u8>, params: &ParamRemapping) {
622     let mut start = 0;
623     let mut i = 0;
624 
625     loop {
626         // find the next wildcard
627         let (wildcard, mut wildcard_index) = match find_wildcard(&route[start..]).unwrap() {
628             Some((w, i)) => (w, i),
629             None => return,
630         };
631 
632         wildcard_index += start;
633 
634         let next = match params.get(i) {
635             Some(param) => param.clone(),
636             None => return,
637         };
638 
639         // denormalize this parameter
640         route.splice(
641             (wildcard_index)..(wildcard_index + wildcard.len()),
642             next.clone(),
643         );
644 
645         i += 1;
646         start = wildcard_index + 2;
647     }
648 }
649 
650 // Searches for a wildcard segment and checks the path for invalid characters.
find_wildcard(path: &[u8]) -> Result<Option<(&[u8], usize)>, InsertError>651 fn find_wildcard(path: &[u8]) -> Result<Option<(&[u8], usize)>, InsertError> {
652     for (start, &c) in path.iter().enumerate() {
653         // a wildcard starts with ':' (param) or '*' (catch-all)
654         if c != b':' && c != b'*' {
655             continue;
656         }
657 
658         for (end, &c) in path[start + 1..].iter().enumerate() {
659             match c {
660                 b'/' => return Ok(Some((&path[start..start + 1 + end], start))),
661                 b':' | b'*' => return Err(InsertError::TooManyParams),
662                 _ => {}
663             }
664         }
665 
666         return Ok(Some((&path[start..], start)));
667     }
668 
669     Ok(None)
670 }
671 
672 impl<T> Clone for Node<T>
673 where
674     T: Clone,
675 {
clone(&self) -> Self676     fn clone(&self) -> Self {
677         let value = self.value.as_ref().map(|value| {
678             // safety: we only expose &mut T through &mut self
679             let value = unsafe { &*value.get() };
680             UnsafeCell::new(value.clone())
681         });
682 
683         Self {
684             value,
685             prefix: self.prefix.clone(),
686             wild_child: self.wild_child,
687             node_type: self.node_type.clone(),
688             indices: self.indices.clone(),
689             children: self.children.clone(),
690             param_remapping: self.param_remapping.clone(),
691             priority: self.priority,
692         }
693     }
694 }
695 
696 impl<T> Default for Node<T> {
default() -> Self697     fn default() -> Self {
698         Self {
699             param_remapping: ParamRemapping::new(),
700             prefix: Vec::new(),
701             wild_child: false,
702             node_type: NodeType::Static,
703             indices: Vec::new(),
704             children: Vec::new(),
705             value: None,
706             priority: 0,
707         }
708     }
709 }
710 
#[cfg(test)]
const _: () = {
    use std::fmt::{self, Debug, Formatter};

    // prints the tree in a readable form when debugging tests: prefixes and parameter names
    // are shown as text and indices as characters rather than raw bytes
    impl<T: Debug> Debug for Node<T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            // safety: we only expose &mut T through &mut self
            let value = self.value.as_ref().map(|cell| unsafe { &*cell.get() });

            let indices: Vec<_> = self
                .indices
                .iter()
                .map(|&byte| char::from_u32(byte as _))
                .collect();

            let param_names: Vec<_> = self
                .param_remapping
                .iter()
                .map(|name| std::str::from_utf8(name).unwrap())
                .collect();

            f.debug_struct("Node")
                .field("value", &value)
                .field("prefix", &std::str::from_utf8(&self.prefix))
                .field("node_type", &self.node_type)
                .field("children", &self.children)
                .field("param_names", &param_names)
                .field("indices", &indices)
                .finish()
        }
    }
};
744