1 use crate::{
2     err::{packet::TransportChecksumError, ValueTooBigError},
3     *,
4 };
5 
/// The possible headers on the transport layer.
///
/// Each variant wraps the fully decoded header of the corresponding protocol.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TransportHeader {
    /// A UDP header.
    Udp(UdpHeader),
    /// A TCP header (including its options).
    Tcp(TcpHeader),
    /// An ICMPv4 header.
    Icmpv4(Icmpv4Header),
    /// An ICMPv6 header.
    Icmpv6(Icmpv6Header),
}
14 
15 impl TransportHeader {
16     /// Returns Result::Some containing the udp header if self has the value Udp.
17     /// Otherwise None is returned.
udp(self) -> Option<UdpHeader>18     pub fn udp(self) -> Option<UdpHeader> {
19         use crate::TransportHeader::*;
20         if let Udp(value) = self {
21             Some(value)
22         } else {
23             None
24         }
25     }
26 
27     /// Returns Result::Some containing the udp header if self has the value Udp.
28     /// Otherwise None is returned.
mut_udp(&mut self) -> Option<&mut UdpHeader>29     pub fn mut_udp(&mut self) -> Option<&mut UdpHeader> {
30         use crate::TransportHeader::*;
31         if let Udp(value) = self {
32             Some(value)
33         } else {
34             None
35         }
36     }
37 
38     /// Returns Result::Some containing the tcp header if self has the value Tcp.
39     /// Otherwise None is returned.
tcp(self) -> Option<TcpHeader>40     pub fn tcp(self) -> Option<TcpHeader> {
41         use crate::TransportHeader::*;
42         if let Tcp(value) = self {
43             Some(value)
44         } else {
45             None
46         }
47     }
48 
49     /// Returns Result::Some containing a mutable reference to the tcp header if self has the value Tcp.
50     /// Otherwise None is returned.
mut_tcp(&mut self) -> Option<&mut TcpHeader>51     pub fn mut_tcp(&mut self) -> Option<&mut TcpHeader> {
52         use crate::TransportHeader::*;
53         if let Tcp(value) = self {
54             Some(value)
55         } else {
56             None
57         }
58     }
59 
60     /// Returns Result::Some containing the ICMPv4 header if self has the value Icmpv4.
61     /// Otherwise None is returned.
icmpv4(self) -> Option<Icmpv4Header>62     pub fn icmpv4(self) -> Option<Icmpv4Header> {
63         use crate::TransportHeader::*;
64         if let Icmpv4(value) = self {
65             Some(value)
66         } else {
67             None
68         }
69     }
70 
71     /// Returns Result::Some containing the ICMPv4 header if self has the value Icmpv4.
72     /// Otherwise None is returned.
mut_icmpv4(&mut self) -> Option<&mut Icmpv4Header>73     pub fn mut_icmpv4(&mut self) -> Option<&mut Icmpv4Header> {
74         use crate::TransportHeader::*;
75         if let Icmpv4(value) = self {
76             Some(value)
77         } else {
78             None
79         }
80     }
81 
82     /// Returns Result::Some containing the ICMPv6 header if self has the value Icmpv6.
83     /// Otherwise None is returned.
icmpv6(self) -> Option<Icmpv6Header>84     pub fn icmpv6(self) -> Option<Icmpv6Header> {
85         use crate::TransportHeader::*;
86         if let Icmpv6(value) = self {
87             Some(value)
88         } else {
89             None
90         }
91     }
92 
93     /// Returns Result::Some containing the ICMPv6 header if self has the value Icmpv6.
94     /// Otherwise None is returned.
mut_icmpv6(&mut self) -> Option<&mut Icmpv6Header>95     pub fn mut_icmpv6(&mut self) -> Option<&mut Icmpv6Header> {
96         use crate::TransportHeader::*;
97         if let Icmpv6(value) = self {
98             Some(value)
99         } else {
100             None
101         }
102     }
103 
104     /// Returns the size of the transport header (in case of UDP fixed,
105     /// in case of TCP cotanining the options).
header_len(&self) -> usize106     pub fn header_len(&self) -> usize {
107         use crate::TransportHeader::*;
108         match self {
109             Udp(_) => UdpHeader::LEN,
110             Tcp(value) => value.header_len(),
111             Icmpv4(value) => value.header_len(),
112             Icmpv6(value) => value.header_len(),
113         }
114     }
115 
116     /// Calculates the checksum for the transport header & sets it in the header for
117     /// an ipv4 header.
update_checksum_ipv4( &mut self, ip_header: &Ipv4Header, payload: &[u8], ) -> Result<(), TransportChecksumError>118     pub fn update_checksum_ipv4(
119         &mut self,
120         ip_header: &Ipv4Header,
121         payload: &[u8],
122     ) -> Result<(), TransportChecksumError> {
123         use crate::{err::packet::TransportChecksumError::*, TransportHeader::*};
124         match self {
125             Udp(header) => {
126                 header.checksum = header
127                     .calc_checksum_ipv4(ip_header, payload)
128                     .map_err(PayloadLen)?;
129             }
130             Tcp(header) => {
131                 header.checksum = header
132                     .calc_checksum_ipv4(ip_header, payload)
133                     .map_err(PayloadLen)?;
134             }
135             Icmpv4(header) => {
136                 header.update_checksum(payload);
137             }
138             Icmpv6(_) => return Err(Icmpv6InIpv4),
139         }
140         Ok(())
141     }
142 
143     /// Calculates the checksum for the transport header & sets it in the header for
144     /// an ipv6 header.
update_checksum_ipv6( &mut self, ip_header: &Ipv6Header, payload: &[u8], ) -> Result<(), ValueTooBigError<usize>>145     pub fn update_checksum_ipv6(
146         &mut self,
147         ip_header: &Ipv6Header,
148         payload: &[u8],
149     ) -> Result<(), ValueTooBigError<usize>> {
150         use crate::TransportHeader::*;
151         match self {
152             Icmpv4(header) => header.update_checksum(payload),
153             Icmpv6(header) => {
154                 header.update_checksum(ip_header.source, ip_header.destination, payload)?
155             }
156             Udp(header) => {
157                 header.checksum = header.calc_checksum_ipv6(ip_header, payload)?;
158             }
159             Tcp(header) => {
160                 header.checksum = header.calc_checksum_ipv6(ip_header, payload)?;
161             }
162         }
163         Ok(())
164     }
165 
166     /// Write the transport header to the given writer.
167     #[cfg(feature = "std")]
168     #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
write<T: std::io::Write + Sized>(&self, writer: &mut T) -> Result<(), std::io::Error>169     pub fn write<T: std::io::Write + Sized>(&self, writer: &mut T) -> Result<(), std::io::Error> {
170         use crate::TransportHeader::*;
171         match self {
172             Icmpv4(value) => value.write(writer),
173             Icmpv6(value) => value.write(writer),
174             Udp(value) => value.write(writer),
175             Tcp(value) => value.write(writer),
176         }
177     }
178 }
179 
#[cfg(test)]
mod test {
    use crate::{test_gens::*, *};
    use alloc::{format, vec::Vec};
    use proptest::prelude::*;
    use std::io::Cursor;

    /// Returns a zero-filled, heap-backed payload of `len` bytes.
    ///
    /// Used for "payload too long" cases whose limit is small enough (around
    /// 64 KiB for IPv4) that a real allocation is cheap. This avoids handing
    /// the code under test a slice that is not backed by memory.
    fn zeroed_payload(len: usize) -> Vec<u8> {
        let mut payload = Vec::with_capacity(len);
        payload.resize(len, 0u8);
        payload
    }

    /// Returns a slice of `len` bytes that is NOT backed by real memory.
    ///
    /// # Safety
    ///
    /// This does not satisfy the requirement of [`core::slice::from_raw_parts`]
    /// that the pointer be valid for reads of `len` bytes. Strictly speaking,
    /// creating the slice is therefore undefined behavior.
    ///
    /// It is only used for multi-GiB lengths that cannot reasonably be
    /// allocated in a test. The returned slice must only be passed to
    /// functions that reject its length before reading any byte. If such a
    /// function ever reads from it, expect a crash.
    #[cfg(target_pointer_width = "64")]
    unsafe fn unbacked_payload(len: usize) -> &'static [u8] {
        // The data pointer must still be non-null and aligned. A null pointer
        // breaks a core slice invariant and can miscompile in release builds.
        unsafe {
            core::slice::from_raw_parts(core::ptr::NonNull::<u8>::dangling().as_ptr(), len)
        }
    }

    proptest! {
        #[test]
        fn debug(
            tcp in tcp_any(),
            udp in udp_any(),
            icmpv4 in icmpv4_header_any(),
            icmpv6 in icmpv6_header_any(),
        ) {
            use TransportHeader::*;
            assert_eq!(
                format!("Udp({:?})", udp),
                format!("{:?}", Udp(udp.clone())),
            );
            assert_eq!(
                format!("Tcp({:?})", tcp),
                format!("{:?}", Tcp(tcp.clone())),
            );
            assert_eq!(
                format!("Icmpv4({:?})", icmpv4),
                format!("{:?}", Icmpv4(icmpv4.clone())),
            );
            assert_eq!(
                format!("Icmpv6({:?})", icmpv6),
                format!("{:?}", Icmpv6(icmpv6.clone())),
            );
        }
    }

    proptest! {
        #[test]
        fn clone_eq(
            tcp in tcp_any(),
            udp in udp_any(),
            icmpv4 in icmpv4_header_any(),
            icmpv6 in icmpv6_header_any(),
        ) {
            use TransportHeader::*;
            let values = [
                Udp(udp),
                Tcp(tcp),
                Icmpv4(icmpv4),
                Icmpv6(icmpv6),
            ];
            for value in values {
                assert_eq!(value.clone(), value);
            }
        }
    }

    #[test]
    fn udp() {
        let udp: UdpHeader = Default::default();
        assert_eq!(Some(udp.clone()), TransportHeader::Udp(udp).udp());
        assert_eq!(None, TransportHeader::Tcp(Default::default()).udp());
    }
    #[test]
    fn mut_udp() {
        let udp: UdpHeader = Default::default();
        assert_eq!(Some(&mut udp.clone()), TransportHeader::Udp(udp).mut_udp());
        assert_eq!(None, TransportHeader::Tcp(Default::default()).mut_udp());
    }
    #[test]
    fn tcp() {
        let tcp: TcpHeader = Default::default();
        assert_eq!(Some(tcp.clone()), TransportHeader::Tcp(tcp).tcp());
        assert_eq!(None, TransportHeader::Udp(Default::default()).tcp());
    }
    #[test]
    fn mut_tcp() {
        let tcp: TcpHeader = Default::default();
        assert_eq!(Some(&mut tcp.clone()), TransportHeader::Tcp(tcp).mut_tcp());
        assert_eq!(None, TransportHeader::Udp(Default::default()).mut_tcp());
    }
    proptest! {
        #[test]
        fn icmpv4(icmpv4 in icmpv4_header_any()) {
            assert_eq!(Some(icmpv4.clone()), TransportHeader::Icmpv4(icmpv4).icmpv4());
            assert_eq!(None, TransportHeader::Udp(Default::default()).icmpv4());
        }
    }
    proptest! {
        #[test]
        fn mut_icmpv4(icmpv4 in icmpv4_header_any()) {
            assert_eq!(Some(&mut icmpv4.clone()), TransportHeader::Icmpv4(icmpv4).mut_icmpv4());
            assert_eq!(None, TransportHeader::Udp(Default::default()).mut_icmpv4());
        }
    }
    proptest! {
        #[test]
        fn icmpv6(icmpv6 in icmpv6_header_any()) {
            assert_eq!(Some(icmpv6.clone()), TransportHeader::Icmpv6(icmpv6).icmpv6());
            assert_eq!(None, TransportHeader::Udp(Default::default()).icmpv6());
        }
    }
    proptest! {
        #[test]
        fn mut_icmpv6(icmpv6 in icmpv6_header_any()) {
            assert_eq!(Some(&mut icmpv6.clone()), TransportHeader::Icmpv6(icmpv6).mut_icmpv6());
            assert_eq!(None, TransportHeader::Udp(Default::default()).mut_icmpv6());
        }
    }
    proptest! {
        #[test]
        fn header_size(
            udp in udp_any(),
            tcp in tcp_any(),
            icmpv4 in icmpv4_header_any(),
            icmpv6 in icmpv6_header_any(),
        ) {
            assert_eq!(
                TransportHeader::Udp(udp).header_len(),
                UdpHeader::LEN
            );
            assert_eq!(
                TransportHeader::Tcp(tcp.clone()).header_len(),
                tcp.header_len()
            );
            assert_eq!(
                TransportHeader::Icmpv4(icmpv4.clone()).header_len(),
                icmpv4.header_len()
            );
            assert_eq!(
                TransportHeader::Icmpv6(icmpv6.clone()).header_len(),
                icmpv6.header_len()
            );
        }
    }
    proptest! {
        #[test]
        fn update_checksum_ipv4(
            ipv4 in ipv4_any(),
            udp in udp_any(),
            tcp in tcp_any(),
            icmpv4 in icmpv4_header_any(),
            icmpv6 in icmpv6_header_any()
        ) {
            use TransportHeader::*;
            use crate::err::{ValueTooBigError, ValueType, packet::TransportChecksumError::*};

            // udp
            {
                // ok case
                {
                    let mut transport = Udp(udp.clone());
                    let payload = Vec::new();
                    transport.update_checksum_ipv4(&ipv4, &payload).unwrap();
                    assert_eq!(transport.udp().unwrap().checksum,
                               udp.calc_checksum_ipv4(&ipv4, &payload).unwrap());
                }
                // error case: one byte more than the 16-bit UDP length allows
                {
                    let mut transport = Udp(udp.clone());
                    let len = (u16::MAX as usize) - UdpHeader::LEN + 1;
                    let payload = zeroed_payload(len);
                    assert_eq!(
                        transport.update_checksum_ipv4(&ipv4, &payload),
                        Err(PayloadLen(ValueTooBigError{
                            actual: len,
                            max_allowed: (u16::MAX as usize) - UdpHeader::LEN,
                            value_type: ValueType::UdpPayloadLengthIpv4
                        }))
                    );
                }
            }
            // tcp
            {
                //ok case
                {
                    let mut transport = Tcp(tcp.clone());
                    let payload = Vec::new();
                    transport.update_checksum_ipv4(&ipv4, &payload).unwrap();
                    assert_eq!(transport.tcp().unwrap().checksum,
                               tcp.calc_checksum_ipv4(&ipv4, &payload).unwrap());
                }
                // error case: one byte more than the 16-bit IPv4 length allows
                {
                    let mut transport = Tcp(tcp.clone());
                    let len = (u16::MAX - tcp.header_len_u16()) as usize + 1;
                    let tcp_payload = zeroed_payload(len);
                    assert_eq!(
                        transport.update_checksum_ipv4(&ipv4, &tcp_payload),
                        Err(PayloadLen(ValueTooBigError{
                            actual: len,
                            max_allowed: (u16::MAX as usize) - tcp.header_len(),
                            value_type: ValueType::TcpPayloadLengthIpv4
                        }))
                    );
                }
            }

            // icmpv4
            {
                let mut transport = Icmpv4(icmpv4.clone());
                let payload = Vec::new();
                transport.update_checksum_ipv4(&ipv4, &payload).unwrap();
                assert_eq!(
                    transport.icmpv4().unwrap().checksum,
                    icmpv4.icmp_type.calc_checksum(&payload)
                );
            }

            // icmpv6 (error)
            assert_eq!(
                Icmpv6(icmpv6).update_checksum_ipv4(&ipv4, &[]),
                Err(Icmpv6InIpv4)
            );
        }
    }

    proptest! {
        #[test]
        #[cfg(target_pointer_width = "64")]
        fn update_checksum_ipv6(
            ipv6 in ipv6_any(),
            udp in udp_any(),
            tcp in tcp_any(),
            icmpv4 in icmpv4_header_any(),
            icmpv6 in icmpv6_header_any(),
        ) {
            use TransportHeader::*;
            use crate::err::{ValueTooBigError, ValueType};

            // udp
            {
                //ok case
                {
                    let mut transport = Udp(udp.clone());
                    let payload = Vec::new();
                    transport.update_checksum_ipv6(&ipv6, &payload).unwrap();
                    assert_eq!(transport.udp().unwrap().checksum,
                               udp.calc_checksum_ipv6(&ipv6, &payload).unwrap());
                }
                //error case
                {
                    let mut transport = Udp(udp.clone());
                    let len = (u32::MAX as usize) - UdpHeader::LEN + 1;
                    // SAFETY(caveat): see `unbacked_payload`; the length is
                    // rejected before any byte of the slice is read.
                    let payload = unsafe { unbacked_payload(len) };
                    assert_eq!(
                        transport.update_checksum_ipv6(&ipv6, payload),
                        Err(ValueTooBigError{
                            actual: len,
                            max_allowed: (u32::MAX as usize) - UdpHeader::LEN,
                            value_type: ValueType::UdpPayloadLengthIpv6
                        })
                    );
                }
            }

            // tcp
            {
                //ok case
                {
                    let mut transport = Tcp(tcp.clone());
                    let payload = Vec::new();
                    transport.update_checksum_ipv6(&ipv6, &payload).unwrap();
                    assert_eq!(transport.tcp().unwrap().checksum,
                               tcp.calc_checksum_ipv6(&ipv6, &payload).unwrap());
                }
                //error case
                {
                    let mut transport = Tcp(tcp.clone());
                    let len = (u32::MAX - tcp.header_len() as u32) as usize + 1;
                    // SAFETY(caveat): see `unbacked_payload`; the length is
                    // rejected before any byte of the slice is read.
                    let tcp_payload = unsafe { unbacked_payload(len) };
                    assert_eq!(
                        transport.update_checksum_ipv6(&ipv6, tcp_payload),
                        Err(ValueTooBigError{
                            actual: len,
                            max_allowed: (u32::MAX - tcp.header_len() as u32) as usize,
                            value_type: ValueType::TcpPayloadLengthIpv6
                        })
                    );
                }
            }

            // icmpv4
            {
                let mut transport = Icmpv4(icmpv4.clone());
                let payload = Vec::new();
                transport.update_checksum_ipv6(&ipv6, &payload).unwrap();
                assert_eq!(
                    transport.icmpv4().unwrap().checksum,
                    icmpv4.icmp_type.calc_checksum(&payload)
                );
            }

            // icmpv6
            {
                // normal case
                {
                    let mut transport = Icmpv6(icmpv6.clone());
                    let payload = Vec::new();
                    transport.update_checksum_ipv6(&ipv6, &payload).unwrap();
                    assert_eq!(
                        transport.icmpv6().unwrap().checksum,
                        icmpv6.icmp_type.calc_checksum(ipv6.source, ipv6.destination, &payload).unwrap()
                    );
                }

                // error case
                {
                    let mut transport = Icmpv6(icmpv6.clone());
                    // SAFETY(caveat): see `unbacked_payload`; the length is
                    // rejected before any byte of the slice is read. If the
                    // error is not triggered, expect a segmentation fault.
                    let too_big_slice = unsafe { unbacked_payload((u32::MAX - 7) as usize) };
                    assert_eq!(
                        transport.update_checksum_ipv6(&ipv6, too_big_slice),
                        Err(ValueTooBigError{
                            actual: too_big_slice.len(),
                            max_allowed: (u32::MAX - 8) as usize,
                            value_type: ValueType::Icmpv6PayloadLength,
                        })
                    );
                }
            }
        }
    }

    proptest! {
        #[test]
        fn write(
            udp in udp_any(),
            tcp in tcp_any(),
            icmpv4 in icmpv4_header_any(),
            icmpv6 in icmpv6_header_any(),
        ) {
            // udp
            {
                //write
                {
                    let result_input = {
                        let mut buffer = Vec::new();
                        udp.write(&mut buffer).unwrap();
                        buffer
                    };
                    let result_transport = {
                        let mut buffer = Vec::new();
                        TransportHeader::Udp(udp.clone()).write(&mut buffer).unwrap();
                        buffer
                    };
                    assert_eq!(result_input, result_transport);
                }
                //trigger an error
                {
                    let mut a: [u8;0] = [];
                    assert!(
                        TransportHeader::Udp(udp.clone())
                        .write(&mut Cursor::new(&mut a[..]))
                        .is_err()
                    );
                }
            }
            // tcp
            {
                //write
                {
                    let result_input = {
                        let mut buffer = Vec::new();
                        tcp.write(&mut buffer).unwrap();
                        buffer
                    };
                    let result_transport = {
                        let mut buffer = Vec::new();
                        TransportHeader::Tcp(tcp.clone()).write(&mut buffer).unwrap();
                        buffer
                    };
                    assert_eq!(result_input, result_transport);
                }
                //trigger an error
                {
                    let mut a: [u8;0] = [];
                    assert!(
                        TransportHeader::Tcp(tcp.clone())
                        .write(&mut Cursor::new(&mut a[..]))
                        .is_err()
                    );
                }
            }

            // icmpv4
            {
                // normal write
                {
                    let result_input = {
                        let mut buffer = Vec::new();
                        icmpv4.write(&mut buffer).unwrap();
                        buffer
                    };
                    let result_transport = {
                        let mut buffer = Vec::new();
                        TransportHeader::Icmpv4(icmpv4.clone()).write(&mut buffer).unwrap();
                        buffer
                    };
                    assert_eq!(result_input, result_transport);
                }

                // error during write
                {
                    let mut a: [u8;0] = [];
                    assert!(
                        TransportHeader::Icmpv4(icmpv4.clone())
                        .write(&mut Cursor::new(&mut a[..]))
                        .is_err()
                    );
                }
            }

            // icmpv6
            {
                // normal write
                {
                    let result_input = {
                        let mut buffer = Vec::new();
                        icmpv6.write(&mut buffer).unwrap();
                        buffer
                    };
                    let result_transport = {
                        let mut buffer = Vec::new();
                        TransportHeader::Icmpv6(icmpv6.clone()).write(&mut buffer).unwrap();
                        buffer
                    };
                    assert_eq!(result_input, result_transport);
                }

                // error during write
                {
                    let mut a: [u8;0] = [];
                    assert!(
                        TransportHeader::Icmpv6(icmpv6.clone())
                        .write(&mut Cursor::new(&mut a[..]))
                        .is_err()
                    );
                }
            }
        }
    }
}
669