1 // Copyright 2013 The rust-url developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 
9 //! Punycode ([RFC 3492](http://tools.ietf.org/html/rfc3492)) implementation.
10 //!
11 //! Since Punycode fundamentally works on unicode code points,
12 //! `encode` and `decode` take and return slices and vectors of `char`.
13 //! `encode_str` and `decode_to_string` provide convenience wrappers
14 //! that convert from and to Rust’s UTF-8 based `str` and `String` types.
15 
16 use alloc::{string::String, vec::Vec};
17 use core::char;
18 use core::u32;
19 
// Bootstring parameters for Punycode (RFC 3492, section 5).
//
// These are plain compile-time values that never need a fixed address,
// so they are declared as `const` rather than `static`.

/// Number of distinct digit values: `a`-`z` encode 0-25 and `0`-`9` encode 26-35.
const BASE: u32 = 36;
/// Smallest threshold that can apply to a digit position.
const T_MIN: u32 = 1;
/// Largest threshold that can apply to a digit position.
const T_MAX: u32 = 26;
/// Skew term used by the bias adaptation function.
const SKEW: u32 = 38;
/// Divisor applied to the delta the first time the bias is adapted.
const DAMP: u32 = 700;
/// Bias in effect before any delta has been processed.
const INITIAL_BIAS: u32 = 72;
/// Code point at which encoding/decoding starts: the first non-ASCII one.
const INITIAL_N: u32 = 0x80;
/// Separates the literal basic code points from the encoded deltas.
const DELIMITER: char = '-';
29 
30 #[inline]
adapt(mut delta: u32, num_points: u32, first_time: bool) -> u3231 fn adapt(mut delta: u32, num_points: u32, first_time: bool) -> u32 {
32     delta /= if first_time { DAMP } else { 2 };
33     delta += delta / num_points;
34     let mut k = 0;
35     while delta > ((BASE - T_MIN) * T_MAX) / 2 {
36         delta /= BASE - T_MIN;
37         k += BASE;
38     }
39     k + (((BASE - T_MIN + 1) * delta) / (delta + SKEW))
40 }
41 
42 /// Convert Punycode to an Unicode `String`.
43 ///
44 /// This is a convenience wrapper around `decode`.
45 #[inline]
decode_to_string(input: &str) -> Option<String>46 pub fn decode_to_string(input: &str) -> Option<String> {
47     decode(input).map(|chars| chars.into_iter().collect())
48 }
49 
50 /// Convert Punycode to Unicode.
51 ///
52 /// Return None on malformed input or overflow.
53 /// Overflow can only happen on inputs that take more than
54 /// 63 encoded bytes, the DNS limit on domain name labels.
decode(input: &str) -> Option<Vec<char>>55 pub fn decode(input: &str) -> Option<Vec<char>> {
56     Some(Decoder::default().decode(input).ok()?.collect())
57 }
58 
59 #[derive(Default)]
60 pub(crate) struct Decoder {
61     insertions: Vec<(usize, char)>,
62 }
63 
64 impl Decoder {
65     /// Split the input iterator and return a Vec with insertions of encoded characters
decode<'a>(&'a mut self, input: &'a str) -> Result<Decode<'a>, ()>66     pub(crate) fn decode<'a>(&'a mut self, input: &'a str) -> Result<Decode<'a>, ()> {
67         self.insertions.clear();
68         // Handle "basic" (ASCII) code points.
69         // They are encoded as-is before the last delimiter, if any.
70         let (base, input) = match input.rfind(DELIMITER) {
71             None => ("", input),
72             Some(position) => (
73                 &input[..position],
74                 if position > 0 {
75                     &input[position + 1..]
76                 } else {
77                     input
78                 },
79             ),
80         };
81 
82         if !base.is_ascii() {
83             return Err(());
84         }
85 
86         let base_len = base.len();
87         let mut length = base_len as u32;
88         let mut code_point = INITIAL_N;
89         let mut bias = INITIAL_BIAS;
90         let mut i = 0;
91         let mut iter = input.bytes();
92         loop {
93             let previous_i = i;
94             let mut weight = 1;
95             let mut k = BASE;
96             let mut byte = match iter.next() {
97                 None => break,
98                 Some(byte) => byte,
99             };
100 
101             // Decode a generalized variable-length integer into delta,
102             // which gets added to i.
103             loop {
104                 let digit = match byte {
105                     byte @ b'0'..=b'9' => byte - b'0' + 26,
106                     byte @ b'A'..=b'Z' => byte - b'A',
107                     byte @ b'a'..=b'z' => byte - b'a',
108                     _ => return Err(()),
109                 } as u32;
110                 if digit > (u32::MAX - i) / weight {
111                     return Err(()); // Overflow
112                 }
113                 i += digit * weight;
114                 let t = if k <= bias {
115                     T_MIN
116                 } else if k >= bias + T_MAX {
117                     T_MAX
118                 } else {
119                     k - bias
120                 };
121                 if digit < t {
122                     break;
123                 }
124                 if weight > u32::MAX / (BASE - t) {
125                     return Err(()); // Overflow
126                 }
127                 weight *= BASE - t;
128                 k += BASE;
129                 byte = match iter.next() {
130                     None => return Err(()), // End of input before the end of this delta
131                     Some(byte) => byte,
132                 };
133             }
134 
135             bias = adapt(i - previous_i, length + 1, previous_i == 0);
136             if i / (length + 1) > u32::MAX - code_point {
137                 return Err(()); // Overflow
138             }
139 
140             // i was supposed to wrap around from length+1 to 0,
141             // incrementing code_point each time.
142             code_point += i / (length + 1);
143             i %= length + 1;
144             let c = match char::from_u32(code_point) {
145                 Some(c) => c,
146                 None => return Err(()),
147             };
148 
149             // Move earlier insertions farther out in the string
150             for (idx, _) in &mut self.insertions {
151                 if *idx >= i as usize {
152                     *idx += 1;
153                 }
154             }
155             self.insertions.push((i as usize, c));
156             length += 1;
157             i += 1;
158         }
159 
160         self.insertions.sort_by_key(|(i, _)| *i);
161         Ok(Decode {
162             base: base.chars(),
163             insertions: &self.insertions,
164             inserted: 0,
165             position: 0,
166             len: base_len + self.insertions.len(),
167         })
168     }
169 }
170 
/// Iterator over the code points of a decoded Punycode string.
///
/// It interleaves the basic code points from `base` with the decoded
/// `insertions`, emitting each insertion when the output position it was
/// recorded for is reached.
pub(crate) struct Decode<'a> {
    /// Basic (ASCII) code points, in output order.
    base: core::str::Chars<'a>,
    /// `(output position, code point)` pairs, consumed front to back.
    /// The decoder sorts them by position before building this iterator.
    pub(crate) insertions: &'a [(usize, char)],
    /// Number of entries of `insertions` yielded so far.
    inserted: usize,
    /// Number of code points yielded so far.
    position: usize,
    /// Total number of code points this iterator yields.
    len: usize,
}

impl<'a> Iterator for Decode<'a> {
    type Item = char;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // An insertion recorded for the current output position takes
            // precedence over the next basic code point.
            if let Some(&(at, c)) = self.insertions.get(self.inserted) {
                if at == self.position {
                    self.inserted += 1;
                    self.position += 1;
                    return Some(c);
                }
            }
            match self.base.next() {
                Some(c) => {
                    self.position += 1;
                    return Some(c);
                }
                None if self.inserted >= self.insertions.len() => return None,
                // Basic code points are exhausted but insertions remain.
                None => {}
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len();
        (remaining, Some(remaining))
    }
}

impl<'a> ExactSizeIterator for Decode<'a> {
    /// Number of code points that have not been yielded yet.
    fn len(&self) -> usize {
        self.len - self.position
    }
}
212 
213 /// Convert an Unicode `str` to Punycode.
214 ///
215 /// This is a convenience wrapper around `encode`.
216 #[inline]
encode_str(input: &str) -> Option<String>217 pub fn encode_str(input: &str) -> Option<String> {
218     if input.len() > u32::MAX as usize {
219         return None;
220     }
221     let mut buf = String::with_capacity(input.len());
222     encode_into(input.chars(), &mut buf).ok().map(|()| buf)
223 }
224 
225 /// Convert Unicode to Punycode.
226 ///
227 /// Return None on overflow, which can only happen on inputs that would take more than
228 /// 63 encoded bytes, the DNS limit on domain name labels.
encode(input: &[char]) -> Option<String>229 pub fn encode(input: &[char]) -> Option<String> {
230     if input.len() > u32::MAX as usize {
231         return None;
232     }
233     let mut buf = String::with_capacity(input.len());
234     encode_into(input.iter().copied(), &mut buf)
235         .ok()
236         .map(|()| buf)
237 }
238 
encode_into<I>(input: I, output: &mut String) -> Result<(), ()> where I: Iterator<Item = char> + Clone,239 pub(crate) fn encode_into<I>(input: I, output: &mut String) -> Result<(), ()>
240 where
241     I: Iterator<Item = char> + Clone,
242 {
243     // Handle "basic" (ASCII) code points. They are encoded as-is.
244     let (mut input_length, mut basic_length) = (0u32, 0);
245     for c in input.clone() {
246         input_length = input_length.checked_add(1).ok_or(())?;
247         if c.is_ascii() {
248             output.push(c);
249             basic_length += 1;
250         }
251     }
252 
253     if basic_length > 0 {
254         output.push('-')
255     }
256     let mut code_point = INITIAL_N;
257     let mut delta = 0;
258     let mut bias = INITIAL_BIAS;
259     let mut processed = basic_length;
260     while processed < input_length {
261         // All code points < code_point have been handled already.
262         // Find the next larger one.
263         let min_code_point = input
264             .clone()
265             .map(|c| c as u32)
266             .filter(|&c| c >= code_point)
267             .min()
268             .unwrap();
269         if min_code_point - code_point > (u32::MAX - delta) / (processed + 1) {
270             return Err(()); // Overflow
271         }
272         // Increase delta to advance the decoder’s <code_point,i> state to <min_code_point,0>
273         delta += (min_code_point - code_point) * (processed + 1);
274         code_point = min_code_point;
275         for c in input.clone() {
276             let c = c as u32;
277             if c < code_point {
278                 delta = delta.checked_add(1).ok_or(())?;
279             }
280             if c == code_point {
281                 // Represent delta as a generalized variable-length integer:
282                 let mut q = delta;
283                 let mut k = BASE;
284                 loop {
285                     let t = if k <= bias {
286                         T_MIN
287                     } else if k >= bias + T_MAX {
288                         T_MAX
289                     } else {
290                         k - bias
291                     };
292                     if q < t {
293                         break;
294                     }
295                     let value = t + ((q - t) % (BASE - t));
296                     output.push(value_to_digit(value));
297                     q = (q - t) / (BASE - t);
298                     k += BASE;
299                 }
300                 output.push(value_to_digit(q));
301                 bias = adapt(delta, processed + 1, processed == basic_length);
302                 delta = 0;
303                 processed += 1;
304             }
305         }
306         delta += 1;
307         code_point += 1;
308     }
309     Ok(())
310 }
311 
/// Map a digit value to its ASCII representation: 0-25 become `a`-`z`
/// and 26-35 become `0`-`9`.
///
/// Panics if `value` is 36 or greater.
#[inline]
fn value_to_digit(value: u32) -> char {
    const DIGITS: &[u8; 36] = b"abcdefghijklmnopqrstuvwxyz0123456789";
    match DIGITS.get(value as usize) {
        Some(&digit) => digit as char,
        None => panic!(),
    }
}
320 
/// Encoding one code point more than fits in a `u32` counter must fail
/// without writing anything, because the input contains no basic code points.
#[test]
#[ignore = "slow"]
#[cfg(target_pointer_width = "64")]
fn huge_encode() {
    let too_many = u32::MAX as usize + 1;
    let input = std::iter::repeat('ß').take(too_many);
    let mut buf = String::new();
    let outcome = encode_into(input, &mut buf);
    assert!(outcome.is_err());
    assert_eq!(buf.len(), 0);
}
329