1#!/usr/bin/python3 2# 3# Copyright 2014 The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Base module for multinetwork tests.""" 18 19import errno 20import fcntl 21import os 22import posix 23import random 24import re 25from socket import * # pylint: disable=wildcard-import 26import struct 27import time 28 29from scapy import all as scapy 30 31import csocket 32import iproute 33import net_test 34 35 36IFF_TUN = 1 37IFF_TAP = 2 38IFF_NO_PI = 0x1000 39TUNSETIFF = 0x400454ca 40 41SO_BINDTODEVICE = 25 42 43# Setsockopt values. 44IP_UNICAST_IF = 50 45IPV6_MULTICAST_IF = 17 46IPV6_UNICAST_IF = 76 47 48# Cmsg values. 49IP_TTL = 2 50IPV6_2292PKTOPTIONS = 6 51IPV6_FLOWINFO = 11 52IPV6_HOPLIMIT = 52 # Different from IPV6_UNICAST_HOPS, this is cmsg only. 
# Sysctl files used to detect kernel features and to configure the kernel.
ACCEPT_RA_MIN_LFT_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_min_lft"
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
RA_HONOR_PIO_LIFE_SYSCTL = "/proc/sys/net/ipv6/conf/default/ra_honor_pio_life"
RA_HONOR_PIO_PFLAG = "/proc/sys/net/ipv6/conf/default/ra_honor_pio_pflag"

# Kernel feature detection, based on the presence of the relevant sysctl file
# and/or the running kernel version.
HAVE_ACCEPT_RA_MIN_LFT = (os.path.isfile(ACCEPT_RA_MIN_LFT_SYSCTL) or
                          net_test.NonGXI(5, 10) or
                          net_test.KernelAtLeast([(5, 10, 199), (5, 15, 136),
                                                  (6, 1, 57), (6, 6, 0)]))
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
HAVE_RA_HONOR_PIO_LIFE = (os.path.isfile(RA_HONOR_PIO_LIFE_SYSCTL) or
                          net_test.KernelAtLeast([(6, 7, 0)]))
HAVE_RA_HONOR_PIO_PFLAG = (os.path.isfile(RA_HONOR_PIO_PFLAG) or
                           net_test.KernelAtLeast([(6, 12, 0)]))

HAVE_USEROPT_PIO_FIX = net_test.KernelAtLeast([(4, 19, 320), (5, 4, 282),
                                               (5, 10, 224), (5, 15, 165),
                                               (6, 1, 104), (6, 6, 45),
                                               (6, 9, 13), (6, 10, 4),
                                               (6, 11, 0)])


class ConfigurationError(AssertionError):
  """Raised when a command that configures the test environment fails."""


class UnexpectedPacketError(AssertionError):
  """Raised when none of the packets received matches the expected packet."""


def MakePktInfo(version, addr, ifindex):
  """Builds the payload of an IP_PKTINFO / IPV6_PKTINFO cmsg.

  Args:
    version: the IP version, 4 or 6.
    addr: the source address as a string, or a false value to use the
      unspecified address ("0.0.0.0" or "::").
    ifindex: the index of the interface to send on.

  Returns:
    The packed bytes of a csocket.InPktinfo (IPv4) or csocket.In6Pktinfo
    (IPv6) structure.

  Raises:
    KeyError: version is not 4 or 6.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  # After defaulting, addr is always a non-empty string.
  packed_addr = inet_pton(family, addr)
  if version == 6:
    return csocket.In6Pktinfo((packed_addr, ifindex)).Pack()
  return csocket.InPktinfo((ifindex, packed_addr, b"\x00" * 4)).Pack()


class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tap interfaces created
  through the tun driver. The environment is designed to be similar to a real
  Android device in terms of rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since we put them in MAC addresses ("%02x"),
  # IIDs and IPv4 addresses ("10.0.%d.x").
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes. This dict is
  # mutated in place, so it is shared by all subclasses.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  UID_RANGE_START = 2000
  UID_RANGE_END = 9999
  UID_RANGE_SIZE = UID_RANGE_END - UID_RANGE_START + 1

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  RA_VALIDITY = 600  # seconds

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (start, end) UID range that is routed to netid.

    The range [UID_RANGE_START, UID_RANGE_END] is split into equal slices, one
    per entry of NETIDS, in NETIDS order.

    Raises:
      ValueError: netid is not in NETIDS.
    """
    per_netid_range = cls.UID_RANGE_SIZE // len(cls.NETIDS)
    idx = cls.NETIDS.index(netid)
    start = cls.UID_RANGE_START + per_netid_range * idx
    return start, start + per_netid_range - 1

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID routed to netid, or 0 if netid is falsy."""
    if not netid:
      return 0
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    If the kernel supports per-interface autoconf tables (AUTOCONF_TABLE_OFFSET
    is set by setUpClass) and netid has an interface, the table is the
    interface index minus AUTOCONF_TABLE_OFFSET; otherwise it is the netid.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      return cls.ifindices[netid] - cls.AUTOCONF_TABLE_OFFSET
    return netid

  @staticmethod
  def GetInterfaceName(netid):
    """Returns the name of the tap interface for netid."""
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    """Returns the MAC address of the simulated router on netid."""
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    """Returns the MAC address assigned to our tap interface on netid."""
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the router's address on netid (link-local for IPv6).

    Raises:
      ValueError: version is not 4 or 6.
    """
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    """Returns our statically-configured IPv4 address on netid."""
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    """Returns our non-link-local IPv6 address on netid's interface.

    The address is looked up on the interface via net_test.GetLinkAddress.
    """
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our address on netid.

    Args:
      version: 4 or 6; 5 means IPv4 used through an IPv6 socket and returns
        the plain IPv4 address (see MySocketAddress for the mapped form).
      netid: the network to look up.

    Raises:
      KeyError: version is not 4, 5 or 6.
    """
    # Only compute the requested address: the IPv6 lookup queries the
    # interface and should not run (or fail) for IPv4 requests.
    getter = {4: cls._MyIPv4Address,
              5: cls._MyIPv4Address,
              6: cls._MyIPv6Address}[version]
    return getter(netid)

  @classmethod
  def MySocketAddress(cls, version, netid):
    """Like MyAddress, but returns an IPv4-mapped IPv6 address for version 5."""
    addr = cls.MyAddress(version, netid)
    return "::ffff:" + addr if version == 5 else addr

  @classmethod
  def MyLinkLocalAddress(cls, netid):
    """Returns our IPv6 link-local address on netid's interface."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)

  @staticmethod
  def OnlinkPrefixLen(version):
    """Returns the prefix length of the directly-connected subnet."""
    return {4: 24, 6: 64}[version]

  @staticmethod
  def OnlinkPrefix(version, netid):
    """Returns the directly-connected subnet prefix for netid."""
    return {4: "10.0.%d.0" % netid,
            6: "2001:db8:%02x::" % netid}[version]

  @staticmethod
  def GetRandomDestination(prefix):
    """Appends two random components to prefix.

    If prefix contains "." two random decimal octets are appended; otherwise
    two random hexadecimal 16-bit groups are appended.
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    """Returns AF_INET or AF_INET6 for IP version 4 or 6."""
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates and configures the tap interface for netid.

    The interface gets a predictable MAC address, DAD disabled, accept_ra set
    to 2, and is brought up.

    Returns:
      The non-blocking, unbuffered file object of the tap device.
    """
    iface = cls.GetInterfaceName(netid)
    try:
      f = open("/dev/net/tun", "r+b", buffering=0)
    except IOError:
      f = open("/dev/tun", "r+b", buffering=0)
    try:
      ifr = struct.pack("16sH", iface.encode(), IFF_TAP | IFF_NO_PI)
      ifr += b"\x00" * (40 - len(ifr))
      fcntl.ioctl(f, TUNSETIFF, ifr)
      # Give ourselves a predictable MAC address.
      net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
      # Disable DAD so we don't have to wait for it.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
      # Set accept_ra to 2, because that's what we use.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
      net_test.SetInterfaceUp(iface)
      net_test.SetNonBlocking(f)
    except BaseException:
      # Don't leak the tun file descriptor if configuration fails.
      f.close()
      raise
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None, reachabletime=0,
             routerlft=RA_VALIDITY, piolft=RA_VALIDITY, m=0, o=0,
             piopflag=0, options=()):
    """Injects a Router Advertisement from netid's router into its interface.

    The RA carries a source link-layer address option and a prefix
    information option for OnlinkPrefix(6, netid) with the L and A flags set.

    Args:
      netid: the network whose router sends the RA.
      retranstimer: RA retrans timer in ms. Defaults to routerlft * 1000.
      reachabletime: RA reachable time.
      routerlft: router lifetime in seconds. Sent as 0 if the kernel does not
        support per-interface autoconf routing tables.
      piolft: valid and preferred lifetime of the prefix information option.
      m: value of the RA M flag.
      o: value of the RA O flag.
      piopflag: if true, sets bit 0x10 (the P flag) of the PIO res1 field.
      options: additional scapy layers appended to the RA.
    """
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations.
      retranstimer = routerlft * 1000  # Lifetime is in s, retrans timer in ms.

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    if not HAVE_AUTOCONF_TABLE:
      routerlft = 0

    res1 = 0x10 if piopflag else 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(reachabletime=reachabletime,
                            retranstimer=retranstimer,
                            routerlifetime=routerlft,
                            M=m, O=o) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
                                      prefixlen=cls.OnlinkPrefixLen(6),
                                      L=1, A=1, res1=res1,
                                      validlifetime=piolft,
                                      preferredlifetime=piolft))
    for option in options:
      ra /= option
    posix.write(cls.tuns[netid].fileno(), bytes(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds or deletes the rules, routes, addresses and neighbours for netid.

    Args:
      netid: the network to configure.
      is_add: True to create the configuration, False to delete it.
    """
    # These do not depend on the IP version, so compute them once.
    iface = cls.GetInterfaceName(netid)
    ifindex = cls.ifindices[netid]
    macaddr = cls.RouterMacAddress(netid)
    table = cls._TableForNetid(netid)
    start, end = cls.UidRangeForNetid(netid)

    for version in [4, 6]:
      router = cls._RouterAddress(netid, version)

      # Set up routing rules.
      cls.iproute.UidRangeRule(version, is_add, start, end, table,
                               cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the default route (only) is
      # configured manually. For IPv4 we have to manually configure addresses,
      # routes, and neighbour cache entries (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      #
      # A real Android system will have both IPv4 and IPv6 routes for
      # directly-connected subnets in the per-interface routing tables. Ensure
      # we create those as well.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          cls.iproute.DelRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Makes kernel-generated replies use the mark of the original packet."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)

  @classmethod
  def _SetInboundMarking(cls, netid, iface, is_add):
    """Adds or deletes mangle INPUT rules marking packets from iface with netid.

    Raises:
      ConfigurationError: the iptables/ip6tables command failed.
    """
    add_del = "-A" if is_add else "-D"
    args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
        add_del, iface, netid)
    for version in [4, 6]:
      # Run iptables to set up incoming packet marking.
      if net_test.RunIptablesCommand(version, args):
        raise ConfigurationError("Setup command failed: %s" % args)

  @classmethod
  def SetInboundMarks(cls, is_add):
    """Adds or deletes inbound marking rules for every tap interface."""
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Points the default ip rule at netid's table, or deletes it.

    Args:
      netid: the network to use by default. If falsy, DefaultRule is called
        with is_add=False for both IP versions.
    """
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    """Equivalent to SetDefaultNetwork(None)."""
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    """Returns the raw contents (including trailing newline) of a sysctl."""
    with open(sysctl, "r") as sysctl_file:
      return sysctl_file.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to a sysctl, remembering its original contents."""
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as sysctl_file:
      sysctl_file.write(str(value) + "\n")

  @classmethod
  def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    """Sets /proc/sys/net/ipv6/conf/<iface>/<sysctl> on every tap interface."""
    for netid in cls.tuns:
      iface = cls.GetInterfaceName(netid)
      name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
      cls.SetSysctl(name, value)

  @classmethod
  def _RestoreSysctls(cls):
    """Writes every saved sysctl value back, best-effort."""
    for sysctl, value in cls.saved_sysctls.items():
      try:
        with open(sysctl, "w") as sysctl_file:
          sysctl_file.write(value)
      except IOError:
        # Deliberately ignored: presumably per-interface sysctls can vanish
        # once the tap devices are closed in tearDownClass.
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    """Returns the path of the ICMP rate limit sysctl for IP version 4 or 6."""
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    """Sets the ICMP rate limit sysctl for IP version 4 or 6."""
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    """Creates the tap interfaces and the routing configuration for all netids.

    Also disables ICMP rate limits, adds unreachable rules, lowers the console
    log level and sets ANDROID_NO_USE_FWMARK_CLIENT in the environment.
    """
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Don't print lots of "device foo entered promiscuous mode" warnings.
    cls.loglevel = cls.GetConsoleLogLevel()
    cls.SetConsoleLogLevel(net_test.KERN_INFO)

    # When running on device, don't send connections through FwmarkServer.
    os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    """Undoes everything done by setUpClass."""
    # Use pop so a missing variable doesn't abort the rest of the cleanup.
    os.environ.pop("ANDROID_NO_USE_FWMARK_CLIENT", None)

    for version in [4, 6]:
      try:
        cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
      except IOError:
        pass

    for netid in cls.tuns:
      cls._RunSetupCommands(netid, False)
      cls.tuns[netid].close()

    cls.iproute.close()
    cls._RestoreSysctls()
    cls.SetConsoleLogLevel(cls.loglevel)

  def setUp(self):
    """Discards any packets left over from previous tests."""
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets the SO_MARK of socket s to netid (0 if netid is None)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    """Returns the SO_MARK of socket s."""
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    """Sets the SO_MARK of socket s to 0."""
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Binds s to iface with SO_BINDTODEVICE; a falsy iface passes ""."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface.encode())

  def SetUnicastInterface(self, s, ifindex):
    """Sets IP_UNICAST_IF (and IPV6_UNICAST_IF on IPv6 sockets) on s."""
    # Otherwise, Python thinks it's a 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    """Returns the main remote address; version 5 returns the IPv4 one."""
    return {4: self.IPV4_ADDR,
            5: self.IPV4_ADDR,  # see GetRemoteSocketAddress()
            6: self.IPV6_ADDR}[version]

  def GetRemoteSocketAddress(self, version):
    """Like GetRemoteAddress, but IPv4-mapped for version 5."""
    addr = self.GetRemoteAddress(version)
    return "::ffff:" + addr if version == 5 else addr

  def GetOtherRemoteSocketAddress(self, version):
    """Returns the second remote address, IPv4-mapped for version 5."""
    return {4: self.IPV4_ADDR2,
            5: "::ffff:" + self.IPV4_ADDR2,
            6: self.IPV6_ADDR2}[version]

  def SelectInterface(self, s, netid, mode):
    """Makes socket s route via netid using the given selection mode.

    Args:
      s: the socket.
      netid: the network to select; falsy values clear the selection where
        the mode supports it.
      mode: one of "uid", "mark", "oif" or "ucast_oif".

    Raises:
      ValueError: mode is unknown.
    """
    if mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket with constructor and applies routing_mode for netid.

    Version 5 creates an IPv6 socket. A routing_mode of None leaves the socket
    unconfigured.
    """
    if version == 5:
      version = 6
    s = constructor(self.GetProtocolFamily(version))

    if routing_mode not in [None, "uid"]:
      self.SelectInterface(s, netid, routing_mode)
    elif routing_mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)

    return s

  def RandomNetid(self, exclude=None):
    """Return a random netid from the list of netids

    Args:
      exclude: a netid or list of netids that should not be chosen
    """
    if exclude is None:
      exclude = []
    elif isinstance(exclude, int):
      exclude = [exclude]
    diff = [netid for netid in self.NETIDS if netid not in exclude]
    return random.choice(diff)

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload on s, adding a PKTINFO cmsg selecting netid's interface.

    Args:
      version: the IP version, 4 or 6, of the PKTINFO cmsg.
      s: the socket to send on.
      dstaddr: destination address.
      dstport: destination port.
      netid: network to send on, or None to not add a PKTINFO cmsg.
      payload: the data to send.
      cmsgs: a sequence of extra (level, type, data) cmsgs. It is not
        modified.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, csocket.IP_PKTINFO),
          6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
      # Build a new list instead of appending to the caller's: reusing the
      # same list across calls would otherwise accumulate PKTINFO cmsgs.
      cmsgs = list(cmsgs or []) + [(cmsg_level, cmsg_name, pktinfo)]
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    """Writes an Ethernet frame into netid's tap device."""
    posix.write(self.tuns[netid].fileno(), bytes(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Wraps ip_packet in an Ethernet frame from the router and injects it."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Return all queued packets on a netid as a list.

    Reads frames from netid's tap device until it reports EAGAIN. If nothing
    has been read when EAGAIN is first seen, waits 10ms and retries once.

    Args:
      netid: The netid from which to read packets
      include_multicast: A boolean, whether to include multicast packets
        (default=False). When False, frames whose destination MAC address has
        the multicast bit set are read but discarded.

    Returns:
      A list of the payloads of the Ethernet frames read.
    """
    packets = []
    retries = 0
    max_retries = 1
    while True:
      try:
        packet = posix.read(self.tuns[netid].fileno(), 4096)
        if not packet:
          break
        ether = scapy.Ether(packet)
        # Multicast frames are frames where the first byte of the destination
        # MAC address has 1 in the least-significant bit.
        if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
          packets.append(ether.payload)
      except OSError as e:
        # EAGAIN means there are no more packets waiting.
        if e.errno == errno.EAGAIN:
          # If we didn't see any packets, try again for good luck.
          if not packets and retries < max_retries:
            time.sleep(0.01)
            retries += 1
            continue
          else:
            break
        # Anything else is unexpected.
        else:
          raise
    return packets

  def InvalidateDstCache(self, version, netid):
    """Invalidates destination cache entries of sockets on the specified table.

    Creates and then deletes a low-priority throw route in the table for the
    given netid, which invalidates the destination cache entries of any sockets
    that refer to routes in that table.

    The fact that this method actually invalidates destination cache entries is
    tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    does not re-route sockets when they are remarked, but does re-route them if
    this method is called.

    Args:
      version: The IP version, 4 or 6.
      netid: The netid to invalidate dst caches on.
    """
    table = self._TableForNetid(netid)
    for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
      self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
                          "default", 0, nexthop=None, dev=None, mark=None,
                          uid=None, route_type=iproute.RTN_THROW,
                          priority=100000)

  def ClearTunQueues(self):
    """Drains all tap devices until a pass reads no non-multicast packets."""
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    """Asserts that actual matches expected, ignoring unpredictable fields.

    Fields that cannot be predicted (IPv4 ID and DF bit, IPv6 flow label, TCP
    options/window, and ports/sequence numbers left as None in expected) are
    copied from expected or reset before comparing.
    """
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out the flow label, since new kernels randomize it by default.
    actualipv6 = actual.getlayer("IPv6")
    expectedipv6 = expected.getlayer("IPv6")
    if actualipv6 and expectedipv6:
      actualipv6.fl = expectedipv6.fl

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None
      elif actualudp.chksum == 0xffff and expectedudp.chksum == 0:
        # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
        # It is possible that scapy has been upgraded and this no longer triggers.
        actualudp.chksum = 0

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(bytes(expected))
    actual_real = actual.__class__(bytes(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Returns whether assertPacketMatches(expected, actual) passes."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    """Asserts that no non-multicast packets are queued on netid."""
    packets = self.ReadAllPacketsOn(netid)
    if packets:
      firstpacket = repr(packets[0])
    else:
      firstpacket = ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Asserts that a packet matching expected is queued on netid.

    Returns:
      The first queued packet that matches expected.

    Raises:
      AssertionError: no packets were queued.
      UnexpectedPacketError: packets were queued but none matched.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    if ipv6 and ipv6.dst.startswith("ff"):
      include_multicast = True
    else:
      include_multicast = False

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except Exception as e:
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, str(e))) from e

  def Combinations(self, version):
    """Produces a list of combinations to test.

    Returns:
      A list of (netid, iif, ip_if, myaddr, remoteaddr) tuples: for every
      interface ip_if whose address myaddr is the destination, one tuple per
      receiving interface iif (with netid). remoteaddr is a random remote
      address chosen once per destination interface.
    """
    combinations = []

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)
      prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
      remoteaddr = self.GetRandomDestination(prefix)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    """Builds a human-readable description of a receive/expect step."""
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid and checks for reply, or for no packets.

    Returns:
      The matching packet if reply is truthy, otherwise None.
    """
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None


class InboundMarkingTest(MultiNetworkBaseTest):
  """Class that automatically sets up inbound marking."""

  @classmethod
  def setUpClass(cls):
    super(InboundMarkingTest, cls).setUpClass()
    cls.SetInboundMarks(True)

  @classmethod
  def tearDownClass(cls):
    cls.SetInboundMarks(False)
    super(InboundMarkingTest, cls).tearDownClass()