1 use crate::{ 2 err::{packet::TransportChecksumError, ValueTooBigError}, 3 *, 4 }; 5 6 /// The possible headers on the transport layer 7 #[derive(Clone, Debug, Eq, PartialEq)] 8 pub enum TransportHeader { 9 Udp(UdpHeader), 10 Tcp(TcpHeader), 11 Icmpv4(Icmpv4Header), 12 Icmpv6(Icmpv6Header), 13 } 14 15 impl TransportHeader { 16 /// Returns Result::Some containing the udp header if self has the value Udp. 17 /// Otherwise None is returned. udp(self) -> Option<UdpHeader>18 pub fn udp(self) -> Option<UdpHeader> { 19 use crate::TransportHeader::*; 20 if let Udp(value) = self { 21 Some(value) 22 } else { 23 None 24 } 25 } 26 27 /// Returns Result::Some containing the udp header if self has the value Udp. 28 /// Otherwise None is returned. mut_udp(&mut self) -> Option<&mut UdpHeader>29 pub fn mut_udp(&mut self) -> Option<&mut UdpHeader> { 30 use crate::TransportHeader::*; 31 if let Udp(value) = self { 32 Some(value) 33 } else { 34 None 35 } 36 } 37 38 /// Returns Result::Some containing the tcp header if self has the value Tcp. 39 /// Otherwise None is returned. tcp(self) -> Option<TcpHeader>40 pub fn tcp(self) -> Option<TcpHeader> { 41 use crate::TransportHeader::*; 42 if let Tcp(value) = self { 43 Some(value) 44 } else { 45 None 46 } 47 } 48 49 /// Returns Result::Some containing a mutable reference to the tcp header if self has the value Tcp. 50 /// Otherwise None is returned. mut_tcp(&mut self) -> Option<&mut TcpHeader>51 pub fn mut_tcp(&mut self) -> Option<&mut TcpHeader> { 52 use crate::TransportHeader::*; 53 if let Tcp(value) = self { 54 Some(value) 55 } else { 56 None 57 } 58 } 59 60 /// Returns Result::Some containing the ICMPv4 header if self has the value Icmpv4. 61 /// Otherwise None is returned. icmpv4(self) -> Option<Icmpv4Header>62 pub fn icmpv4(self) -> Option<Icmpv4Header> { 63 use crate::TransportHeader::*; 64 if let Icmpv4(value) = self { 65 Some(value) 66 } else { 67 None 68 } 69 } 70 71 /// Returns Result::Some containing the ICMPv4 header if self has the value Icmpv4. 72 /// Otherwise None is returned. mut_icmpv4(&mut self) -> Option<&mut Icmpv4Header>73 pub fn mut_icmpv4(&mut self) -> Option<&mut Icmpv4Header> { 74 use crate::TransportHeader::*; 75 if let Icmpv4(value) = self { 76 Some(value) 77 } else { 78 None 79 } 80 } 81 82 /// Returns Result::Some containing the ICMPv6 header if self has the value Icmpv6. 83 /// Otherwise None is returned. icmpv6(self) -> Option<Icmpv6Header>84 pub fn icmpv6(self) -> Option<Icmpv6Header> { 85 use crate::TransportHeader::*; 86 if let Icmpv6(value) = self { 87 Some(value) 88 } else { 89 None 90 } 91 } 92 93 /// Returns Result::Some containing the ICMPv6 header if self has the value Icmpv6. 94 /// Otherwise None is returned. mut_icmpv6(&mut self) -> Option<&mut Icmpv6Header>95 pub fn mut_icmpv6(&mut self) -> Option<&mut Icmpv6Header> { 96 use crate::TransportHeader::*; 97 if let Icmpv6(value) = self { 98 Some(value) 99 } else { 100 None 101 } 102 } 103 104 /// Returns the size of the transport header (in case of UDP fixed, 105 /// in case of TCP cotanining the options). header_len(&self) -> usize106 pub fn header_len(&self) -> usize { 107 use crate::TransportHeader::*; 108 match self { 109 Udp(_) => UdpHeader::LEN, 110 Tcp(value) => value.header_len(), 111 Icmpv4(value) => value.header_len(), 112 Icmpv6(value) => value.header_len(), 113 } 114 } 115 116 /// Calculates the checksum for the transport header & sets it in the header for 117 /// an ipv4 header. update_checksum_ipv4( &mut self, ip_header: &Ipv4Header, payload: &[u8], ) -> Result<(), TransportChecksumError>118 pub fn update_checksum_ipv4( 119 &mut self, 120 ip_header: &Ipv4Header, 121 payload: &[u8], 122 ) -> Result<(), TransportChecksumError> { 123 use crate::{err::packet::TransportChecksumError::*, TransportHeader::*}; 124 match self { 125 Udp(header) => { 126 header.checksum = header 127 .calc_checksum_ipv4(ip_header, payload) 128 .map_err(PayloadLen)?; 129 } 130 Tcp(header) => { 131 header.checksum = header 132 .calc_checksum_ipv4(ip_header, payload) 133 .map_err(PayloadLen)?; 134 } 135 Icmpv4(header) => { 136 header.update_checksum(payload); 137 } 138 Icmpv6(_) => return Err(Icmpv6InIpv4), 139 } 140 Ok(()) 141 } 142 143 /// Calculates the checksum for the transport header & sets it in the header for 144 /// an ipv6 header. update_checksum_ipv6( &mut self, ip_header: &Ipv6Header, payload: &[u8], ) -> Result<(), ValueTooBigError<usize>>145 pub fn update_checksum_ipv6( 146 &mut self, 147 ip_header: &Ipv6Header, 148 payload: &[u8], 149 ) -> Result<(), ValueTooBigError<usize>> { 150 use crate::TransportHeader::*; 151 match self { 152 Icmpv4(header) => header.update_checksum(payload), 153 Icmpv6(header) => { 154 header.update_checksum(ip_header.source, ip_header.destination, payload)? 155 } 156 Udp(header) => { 157 header.checksum = header.calc_checksum_ipv6(ip_header, payload)?; 158 } 159 Tcp(header) => { 160 header.checksum = header.calc_checksum_ipv6(ip_header, payload)?; 161 } 162 } 163 Ok(()) 164 } 165 166 /// Write the transport header to the given writer. 167 #[cfg(feature = "std")] 168 #[cfg_attr(docsrs, doc(cfg(feature = "std")))] write<T: std::io::Write + Sized>(&self, writer: &mut T) -> Result<(), std::io::Error>169 pub fn write<T: std::io::Write + Sized>(&self, writer: &mut T) -> Result<(), std::io::Error> { 170 use crate::TransportHeader::*; 171 match self { 172 Icmpv4(value) => value.write(writer), 173 Icmpv6(value) => value.write(writer), 174 Udp(value) => value.write(writer), 175 Tcp(value) => value.write(writer), 176 } 177 } 178 } 179 180 #[cfg(test)] 181 mod test { 182 use crate::{test_gens::*, *}; 183 use alloc::{format, vec::Vec}; 184 use core::slice; 185 use proptest::prelude::*; 186 use std::io::Cursor; 187 188 proptest! { 189 #[test] 190 fn debug( 191 tcp in tcp_any(), 192 udp in udp_any(), 193 icmpv4 in icmpv4_header_any(), 194 icmpv6 in icmpv6_header_any(), 195 ) { 196 use TransportHeader::*; 197 assert_eq!( 198 format!("Udp({:?})", udp), 199 format!("{:?}", Udp(udp.clone())), 200 ); 201 assert_eq!( 202 format!("Tcp({:?})", tcp), 203 format!("{:?}", Tcp(tcp.clone())), 204 ); 205 assert_eq!( 206 format!("Icmpv4({:?})", icmpv4), 207 format!("{:?}", Icmpv4(icmpv4.clone())), 208 ); 209 assert_eq!( 210 format!("Icmpv6({:?})", icmpv6), 211 format!("{:?}", Icmpv6(icmpv6.clone())), 212 ); 213 } 214 } 215 216 proptest! { 217 #[test] 218 fn clone_eq( 219 tcp in tcp_any(), 220 udp in udp_any(), 221 icmpv4 in icmpv4_header_any(), 222 icmpv6 in icmpv6_header_any(), 223 ) { 224 use TransportHeader::*; 225 let values = [ 226 Udp(udp), 227 Tcp(tcp), 228 Icmpv4(icmpv4), 229 Icmpv6(icmpv6), 230 ]; 231 for value in values { 232 assert_eq!(value.clone(), value); 233 } 234 } 235 } 236 237 #[test] udp()238 fn udp() { 239 let udp: UdpHeader = Default::default(); 240 assert_eq!(Some(udp.clone()), TransportHeader::Udp(udp).udp()); 241 assert_eq!(None, TransportHeader::Tcp(Default::default()).udp()); 242 } 243 #[test] mut_udp()244 fn mut_udp() { 245 let udp: UdpHeader = Default::default(); 246 assert_eq!(Some(&mut udp.clone()), TransportHeader::Udp(udp).mut_udp()); 247 assert_eq!(None, TransportHeader::Tcp(Default::default()).mut_udp()); 248 } 249 #[test] tcp()250 fn tcp() { 251 let tcp: TcpHeader = Default::default(); 252 assert_eq!(Some(tcp.clone()), TransportHeader::Tcp(tcp).tcp()); 253 assert_eq!(None, TransportHeader::Udp(Default::default()).tcp()); 254 } 255 #[test] mut_tcp()256 fn mut_tcp() { 257 let tcp: TcpHeader = Default::default(); 258 assert_eq!(Some(&mut tcp.clone()), TransportHeader::Tcp(tcp).mut_tcp()); 259 assert_eq!(None, TransportHeader::Udp(Default::default()).mut_tcp()); 260 } 261 proptest! { 262 #[test] 263 fn icmpv4(icmpv4 in icmpv4_header_any()) { 264 assert_eq!(Some(icmpv4.clone()), TransportHeader::Icmpv4(icmpv4).icmpv4()); 265 assert_eq!(None, TransportHeader::Udp(Default::default()).icmpv4()); 266 } 267 } 268 proptest! { 269 #[test] 270 fn mut_icmpv4(icmpv4 in icmpv4_header_any()) { 271 assert_eq!(Some(&mut icmpv4.clone()), TransportHeader::Icmpv4(icmpv4).mut_icmpv4()); 272 assert_eq!(None, TransportHeader::Udp(Default::default()).mut_icmpv4()); 273 } 274 } 275 proptest! { 276 #[test] 277 fn icmpv6(icmpv6 in icmpv6_header_any()) { 278 assert_eq!(Some(icmpv6.clone()), TransportHeader::Icmpv6(icmpv6).icmpv6()); 279 assert_eq!(None, TransportHeader::Udp(Default::default()).icmpv6()); 280 } 281 } 282 proptest! { 283 #[test] 284 fn mut_icmpv6(icmpv6 in icmpv6_header_any()) { 285 assert_eq!(Some(&mut icmpv6.clone()), TransportHeader::Icmpv6(icmpv6).mut_icmpv6()); 286 assert_eq!(None, TransportHeader::Udp(Default::default()).mut_icmpv6()); 287 } 288 } 289 proptest! { 290 #[test] 291 fn header_size( 292 udp in udp_any(), 293 tcp in tcp_any(), 294 icmpv4 in icmpv4_header_any(), 295 icmpv6 in icmpv6_header_any(), 296 ) { 297 assert_eq!( 298 TransportHeader::Udp(udp).header_len(), 299 UdpHeader::LEN 300 ); 301 assert_eq!( 302 TransportHeader::Tcp(tcp.clone()).header_len(), 303 tcp.header_len() as usize 304 ); 305 assert_eq!( 306 TransportHeader::Icmpv4(icmpv4.clone()).header_len(), 307 icmpv4.header_len() 308 ); 309 assert_eq!( 310 TransportHeader::Icmpv6(icmpv6.clone()).header_len(), 311 icmpv6.header_len() 312 ); 313 } 314 } 315 proptest! { 316 #[test] 317 fn update_checksum_ipv4( 318 ipv4 in ipv4_any(), 319 udp in udp_any(), 320 tcp in tcp_any(), 321 icmpv4 in icmpv4_header_any(), 322 icmpv6 in icmpv6_header_any() 323 ) { 324 use TransportHeader::*; 325 use crate::err::{ValueTooBigError, ValueType, packet::TransportChecksumError::*}; 326 327 // udp 328 { 329 // ok case 330 { 331 let mut transport = Udp(udp.clone()); 332 let payload = Vec::new(); 333 transport.update_checksum_ipv4(&ipv4, &payload).unwrap(); 334 assert_eq!(transport.udp().unwrap().checksum, 335 udp.calc_checksum_ipv4(&ipv4, &payload).unwrap()); 336 } 337 // error case 338 { 339 let mut transport = Udp(udp.clone()); 340 let len = (core::u16::MAX as usize) - UdpHeader::LEN + 1; 341 let tcp_payload = unsafe { 342 //NOTE: The pointer must be initialized with a non null value 343 // otherwise a key constraint of slices is not fulfilled 344 // which can lead to crashes in release mode. 345 use core::ptr::NonNull; 346 slice::from_raw_parts( 347 NonNull::<u8>::dangling().as_ptr(), 348 len 349 ) 350 }; 351 assert_eq!( 352 transport.update_checksum_ipv4(&ipv4, &tcp_payload), 353 Err(PayloadLen(ValueTooBigError{ 354 actual: len, 355 max_allowed: (core::u16::MAX as usize) - UdpHeader::LEN, 356 value_type: ValueType::UdpPayloadLengthIpv4 357 })) 358 ); 359 } 360 } 361 // tcp 362 { 363 //ok case 364 { 365 let mut transport = Tcp(tcp.clone()); 366 let payload = Vec::new(); 367 transport.update_checksum_ipv4(&ipv4, &payload).unwrap(); 368 assert_eq!(transport.tcp().unwrap().checksum, 369 tcp.calc_checksum_ipv4(&ipv4, &payload).unwrap()); 370 } 371 //error case 372 { 373 let mut transport = Tcp(tcp.clone()); 374 let len = (core::u16::MAX - tcp.header_len_u16()) as usize + 1; 375 let tcp_payload = unsafe { 376 //NOTE: The pointer must be initialized with a non null value 377 // otherwise a key constraint of slices is not fulfilled 378 // which can lead to crashes in release mode. 379 use core::ptr::NonNull; 380 slice::from_raw_parts( 381 NonNull::<u8>::dangling().as_ptr(), 382 len 383 ) 384 }; 385 assert_eq!( 386 transport.update_checksum_ipv4(&ipv4, &tcp_payload), 387 Err(PayloadLen(ValueTooBigError{ 388 actual: len, 389 max_allowed: (core::u16::MAX as usize) - usize::from(tcp.header_len()), 390 value_type: ValueType::TcpPayloadLengthIpv4 391 })) 392 ); 393 } 394 } 395 396 // icmpv4 397 { 398 let mut transport = Icmpv4(icmpv4.clone()); 399 let payload = Vec::new(); 400 transport.update_checksum_ipv4(&ipv4, &payload).unwrap(); 401 assert_eq!( 402 transport.icmpv4().unwrap().checksum, 403 icmpv4.icmp_type.calc_checksum(&payload) 404 ); 405 } 406 407 // icmpv6 (error) 408 assert_eq!( 409 Icmpv6(icmpv6).update_checksum_ipv4(&ipv4, &[]), 410 Err(Icmpv6InIpv4) 411 ); 412 } 413 } 414 415 proptest! { 416 #[test] 417 #[cfg(target_pointer_width = "64")] 418 fn update_checksum_ipv6( 419 ipv6 in ipv6_any(), 420 udp in udp_any(), 421 tcp in tcp_any(), 422 icmpv4 in icmpv4_header_any(), 423 icmpv6 in icmpv6_header_any(), 424 ) { 425 use TransportHeader::*; 426 use crate::err::{ValueTooBigError, ValueType}; 427 428 // udp 429 { 430 //ok case 431 { 432 let mut transport = Udp(udp.clone()); 433 let payload = Vec::new(); 434 transport.update_checksum_ipv6(&ipv6, &payload).unwrap(); 435 assert_eq!(transport.udp().unwrap().checksum, 436 udp.calc_checksum_ipv6(&ipv6, &payload).unwrap()); 437 } 438 //error case 439 { 440 let mut transport = Udp(udp.clone()); 441 let len = (core::u32::MAX as usize) - UdpHeader::LEN + 1; 442 let payload = unsafe { 443 //NOTE: The pointer must be initialized with a non null value 444 // otherwise a key constraint of slices is not fulfilled 445 // which can lead to crashes in release mode. 446 use core::ptr::NonNull; 447 slice::from_raw_parts( 448 NonNull::<u8>::dangling().as_ptr(), 449 len 450 ) 451 }; 452 assert_eq!( 453 transport.update_checksum_ipv6(&ipv6, &payload), 454 Err(ValueTooBigError{ 455 actual: len, 456 max_allowed: (core::u32::MAX as usize) - UdpHeader::LEN, 457 value_type: ValueType::UdpPayloadLengthIpv6 458 }) 459 ); 460 } 461 } 462 463 // tcp 464 { 465 //ok case 466 { 467 let mut transport = Tcp(tcp.clone()); 468 let payload = Vec::new(); 469 transport.update_checksum_ipv6(&ipv6, &payload).unwrap(); 470 assert_eq!(transport.tcp().unwrap().checksum, 471 tcp.calc_checksum_ipv6(&ipv6, &payload).unwrap()); 472 } 473 //error case 474 { 475 let mut transport = Tcp(tcp.clone()); 476 let len = (core::u32::MAX - tcp.header_len() as u32) as usize + 1; 477 let tcp_payload = unsafe { 478 //NOTE: The pointer must be initialized with a non null value 479 // otherwise a key constraint of slices is not fulfilled 480 // which can lead to crashes in release mode. 481 use core::ptr::NonNull; 482 slice::from_raw_parts( 483 NonNull::<u8>::dangling().as_ptr(), 484 len 485 ) 486 }; 487 assert_eq!( 488 transport.update_checksum_ipv6(&ipv6, &tcp_payload), 489 Err(ValueTooBigError{ 490 actual: len, 491 max_allowed: (core::u32::MAX - tcp.header_len() as u32) as usize, 492 value_type: ValueType::TcpPayloadLengthIpv6 493 }) 494 ); 495 } 496 } 497 498 // icmpv4 499 { 500 let mut transport = Icmpv4(icmpv4.clone()); 501 let payload = Vec::new(); 502 transport.update_checksum_ipv6(&ipv6, &payload).unwrap(); 503 assert_eq!( 504 transport.icmpv4().unwrap().checksum, 505 icmpv4.icmp_type.calc_checksum(&payload) 506 ); 507 } 508 509 // icmpv6 510 { 511 // normal case 512 { 513 let mut transport = Icmpv6(icmpv6.clone()); 514 let payload = Vec::new(); 515 transport.update_checksum_ipv6(&ipv6, &payload).unwrap(); 516 assert_eq!( 517 transport.icmpv6().unwrap().checksum, 518 icmpv6.icmp_type.calc_checksum(ipv6.source, ipv6.destination, &payload).unwrap() 519 ); 520 } 521 522 // error case 523 { 524 let mut transport = Icmpv6(icmpv6.clone()); 525 // SAFETY: In case the error is not triggered 526 // a segmentation fault will be triggered. 527 let too_big_slice = unsafe { 528 //NOTE: The pointer must be initialized with a non null value 529 // otherwise a key constraint of slices is not fulfilled 530 // which can lead to crashes in release mode. 531 use core::ptr::NonNull; 532 core::slice::from_raw_parts( 533 NonNull::<u8>::dangling().as_ptr(), 534 (core::u32::MAX - 7) as usize 535 ) 536 }; 537 assert_eq!( 538 transport.update_checksum_ipv6(&ipv6, too_big_slice), 539 Err(ValueTooBigError{ 540 actual: too_big_slice.len(), 541 max_allowed: (core::u32::MAX - 8) as usize, 542 value_type: ValueType::Icmpv6PayloadLength, 543 }) 544 ); 545 } 546 } 547 } 548 } 549 550 proptest! { 551 #[test] 552 fn write( 553 udp in udp_any(), 554 tcp in tcp_any(), 555 icmpv4 in icmpv4_header_any(), 556 icmpv6 in icmpv6_header_any(), 557 ) { 558 // udp 559 { 560 //write 561 { 562 let result_input = { 563 let mut buffer = Vec::new(); 564 udp.write(&mut buffer).unwrap(); 565 buffer 566 }; 567 let result_transport = { 568 let mut buffer = Vec::new(); 569 TransportHeader::Udp(udp.clone()).write(&mut buffer).unwrap(); 570 buffer 571 }; 572 assert_eq!(result_input, result_transport); 573 } 574 //trigger an error 575 { 576 let mut a: [u8;0] = []; 577 assert!( 578 TransportHeader::Udp(udp.clone()) 579 .write(&mut Cursor::new(&mut a[..])) 580 .is_err() 581 ); 582 } 583 } 584 // tcp 585 { 586 //write 587 { 588 let result_input = { 589 let mut buffer = Vec::new(); 590 tcp.write(&mut buffer).unwrap(); 591 buffer 592 }; 593 let result_transport = { 594 let mut buffer = Vec::new(); 595 TransportHeader::Tcp(tcp.clone()).write(&mut buffer).unwrap(); 596 buffer 597 }; 598 assert_eq!(result_input, result_transport); 599 } 600 //trigger an error 601 { 602 let mut a: [u8;0] = []; 603 assert!( 604 TransportHeader::Tcp(tcp.clone()) 605 .write(&mut Cursor::new(&mut a[..])) 606 .is_err() 607 ); 608 } 609 } 610 611 // icmpv4 612 { 613 // normal write 614 { 615 let result_input = { 616 let mut buffer = Vec::new(); 617 icmpv4.write(&mut buffer).unwrap(); 618 buffer 619 }; 620 let result_transport = { 621 let mut buffer = Vec::new(); 622 TransportHeader::Icmpv4(icmpv4.clone()).write(&mut buffer).unwrap(); 623 buffer 624 }; 625 assert_eq!(result_input, result_transport); 626 } 627 628 // error during write 629 { 630 let mut a: [u8;0] = []; 631 assert!( 632 TransportHeader::Icmpv4(icmpv4.clone()) 633 .write(&mut Cursor::new(&mut a[..])) 634 .is_err() 635 ); 636 } 637 } 638 639 // icmpv6 640 { 641 // normal write 642 { 643 let result_input = { 644 let mut buffer = Vec::new(); 645 icmpv6.write(&mut buffer).unwrap(); 646 buffer 647 }; 648 let result_transport = { 649 let mut buffer = Vec::new(); 650 TransportHeader::Icmpv6(icmpv6.clone()).write(&mut buffer).unwrap(); 651 buffer 652 }; 653 assert_eq!(result_input, result_transport); 654 } 655 656 // error during write 657 { 658 let mut a: [u8;0] = []; 659 assert!( 660 TransportHeader::Icmpv6(icmpv6.clone()) 661 .write(&mut Cursor::new(&mut a[..])) 662 .is_err() 663 ); 664 } 665 } 666 } 667 } 668 } 669