xref: /aosp_15_r20/kernel/tests/net/test/multinetwork_base.py (revision 2f2c4c7ab4226c71756b9c31670392fdd6887c4f)
1#!/usr/bin/python3
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Base module for multinetwork tests."""
18
19import errno
20import fcntl
21import os
22import posix
23import random
24import re
25from socket import *  # pylint: disable=wildcard-import
26import struct
27import time
28
29from scapy import all as scapy
30
31import csocket
32import iproute
33import net_test
34
35
# Flags for the tun driver's TUNSETIFF ioctl (values from linux/if_tun.h).
IFF_TUN = 0x0001     # Layer-3 device (IP packets).
IFF_TAP = 0x0002     # Layer-2 device (Ethernet frames).
IFF_NO_PI = 1 << 12  # Do not prepend packet-information headers.
# TUNSETIFF is _IOW('T', 202, int): write direction (1 << 30), a 4-byte
# argument (4 << 16), ioctl type 'T' and command number 202, i.e. 0x400454ca.
TUNSETIFF = (1 << 30) | (4 << 16) | (ord("T") << 8) | 202

# Socket option used to bind a socket to a device by name.
SO_BINDTODEVICE = 25

# Setsockopt values.
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Cmsg values.
IP_TTL = 2
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
# Different from IPV6_UNICAST_HOPS, this is cmsg only.
IPV6_HOPLIMIT = 52
53
54
# Sysctl paths, each followed (where applicable) by a flag saying whether the
# running kernel supports the feature: true if the sysctl file exists, or if
# one of the net_test kernel checks passes. The flags are evaluated in this
# order at import time and use short-circuit `or`, as before.
ACCEPT_RA_MIN_LFT_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_min_lft"
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
# fwmark_reflect sysctls (written by SetMarkReflectSysctls); no feature flag.
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
RA_HONOR_PIO_LIFE_SYSCTL = "/proc/sys/net/ipv6/conf/default/ra_honor_pio_life"
# Note: unlike its siblings this name has no _SYSCTL suffix.
RA_HONOR_PIO_PFLAG = "/proc/sys/net/ipv6/conf/default/ra_honor_pio_pflag"

HAVE_ACCEPT_RA_MIN_LFT = (
    os.path.isfile(ACCEPT_RA_MIN_LFT_SYSCTL)
    or net_test.NonGXI(5, 10)
    or net_test.KernelAtLeast(
        [(5, 10, 199), (5, 15, 136), (6, 1, 57), (6, 6, 0)]))

# When false, SendRA zeroes the router lifetime and _RunSetupCommands
# configures IPv6 routes manually.
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)

HAVE_RA_HONOR_PIO_LIFE = (
    os.path.isfile(RA_HONOR_PIO_LIFE_SYSCTL)
    or net_test.KernelAtLeast([(6, 7, 0)]))

HAVE_RA_HONOR_PIO_PFLAG = (
    os.path.isfile(RA_HONOR_PIO_PFLAG)
    or net_test.KernelAtLeast([(6, 12, 0)]))

# Purely version-based; there is no sysctl to probe for this one.
HAVE_USEROPT_PIO_FIX = net_test.KernelAtLeast(
    [(4, 19, 320), (5, 4, 282), (5, 10, 224), (5, 15, 165), (6, 1, 104),
     (6, 6, 45), (6, 9, 13), (6, 10, 4), (6, 11, 0)])
77
78
class ConfigurationError(AssertionError):
  """Raised when a command that configures the test environment fails."""
81
82
class UnexpectedPacketError(AssertionError):
  """Raised when none of the received packets matches the expected one."""
85
86
def MakePktInfo(version, addr, ifindex):
  """Builds packed IP_PKTINFO / IPV6_PKTINFO cmsg data.

  Args:
    version: The IP version, 4 or 6.
    addr: The address to put in the structure, as a string. A falsy value
      (None or "") means the unspecified address ("0.0.0.0" or "::").
    ifindex: The interface index to put in the structure.

  Returns:
    The result of Pack() on a csocket.In6Pktinfo (IPv6) or csocket.InPktinfo
    (IPv4) structure.

  Raises:
    KeyError: if version is not 4 or 6.
  """
  family = {4: AF_INET, 6: AF_INET6}[version]
  if not addr:
    addr = {4: "0.0.0.0", 6: "::"}[version]
  # addr is always non-empty here, so it can be converted unconditionally.
  packed_addr = inet_pton(family, addr)
  if version == 6:
    return csocket.In6Pktinfo((packed_addr, ifindex)).Pack()
  # The third IPv4 field is filled with zeros.
  return csocket.InPktinfo((ifindex, packed_addr, b"\x00" * 4)).Pack()
97
98
class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tap interfaces (created
  through the tun driver). The environment is designed to be similar to a real
  Android device in terms of rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since we put them in MAC addresses and IIDs (as
  # two hex digits) and in IPv4 addresses (as a decimal octet).
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes, keyed by sysctl
  # path. This is a class attribute, so it is shared by all subclasses.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # UIDs in this inclusive range are split evenly between NETIDS.
  UID_RANGE_START = 2000
  UID_RANGE_END = 9999
  UID_RANGE_SIZE = UID_RANGE_END - UID_RANGE_START + 1

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  RA_VALIDITY = 600  # seconds

  @classmethod
  def UidRangeForNetid(cls, netid):
    """Returns the inclusive (start, end) UID range routed to netid.

    [UID_RANGE_START, UID_RANGE_END] is split into equal contiguous slices,
    one per entry of NETIDS, in NETIDS order.

    Raises:
      ValueError: if netid is not in NETIDS.
    """
    per_netid_range = cls.UID_RANGE_SIZE // len(cls.NETIDS)
    idx = cls.NETIDS.index(netid)
    start = cls.UID_RANGE_START + per_netid_range * idx
    return (start, start + per_netid_range - 1)

  @classmethod
  def UidForNetid(cls, netid):
    """Returns a random UID routed to netid, or 0 if netid is falsy."""
    if not netid:
      return 0
    return random.randint(*cls.UidRangeForNetid(netid))

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    If setUpClass enabled per-interface autoconf tables (AUTOCONF_TABLE_OFFSET
    set) and netid has an interface, the table is the interface index minus
    the (negative) offset, intended to match the table the kernel puts RA
    routes in. Otherwise the netid itself is the table number.
    """
    if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
      return cls.ifindices[netid] - cls.AUTOCONF_TABLE_OFFSET
    return netid

  @staticmethod
  def GetInterfaceName(netid):
    """Returns the name of the tap interface for netid."""
    return "nettest%d" % netid

  @staticmethod
  def RouterMacAddress(netid):
    """Returns the MAC address of the simulated router on netid."""
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def MyMacAddress(netid):
    """Returns the MAC address assigned to our tap interface on netid."""
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the router's address on netid (link-local for IPv6).

    Raises:
      ValueError: if version is not 4 or 6.
    """
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @classmethod
  def _MyIPv4Address(cls, netid):
    """Returns the IPv4 address configured on our interface on netid."""
    return "10.0.%d.2" % netid

  @classmethod
  def _MyIPv6Address(cls, netid):
    """Returns the non-link-local IPv6 address of our interface on netid."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)

  @classmethod
  def MyAddress(cls, version, netid):
    """Returns our own address on netid for IP version 4, 5 or 6.

    Version 5 means IPv4 used from an IPv6 socket; the address returned is the
    plain IPv4 address (see MySocketAddress for the mapped form). Only the
    lookup for the requested version is performed, so asking for an IPv4
    address does not also read the interface's IPv6 address.

    Raises:
      KeyError: if version is not 4, 5 or 6.
    """
    lookup = {4: cls._MyIPv4Address,
              5: cls._MyIPv4Address,
              6: cls._MyIPv6Address}[version]
    return lookup(netid)

  @classmethod
  def MySocketAddress(cls, version, netid):
    """Like MyAddress, but returns an IPv4-mapped IPv6 address for version 5."""
    addr = cls.MyAddress(version, netid)
    return "::ffff:" + addr if version == 5 else addr

  @classmethod
  def MyLinkLocalAddress(cls, netid):
    """Returns the IPv6 link-local address of our interface on netid."""
    return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)

  @staticmethod
  def OnlinkPrefixLen(version):
    """Returns the prefix length of the on-link subnet (24 or 64)."""
    return {4: 24, 6: 64}[version]

  @staticmethod
  def OnlinkPrefix(version, netid):
    """Returns the on-link subnet prefix for netid."""
    return {4: "10.0.%d.0" % netid,
            6: "2001:db8:%02x::" % netid}[version]

  @staticmethod
  def GetRandomDestination(prefix):
    """Appends two random components to prefix to form an address.

    prefix is treated as IPv4 if it contains a "." (e.g. "172.22.") and as
    IPv6 otherwise (e.g. "2001:db8:aaaa::").
    """
    if "." in prefix:
      return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    else:
      return prefix + "%x:%x" % (random.randint(0, 65535),
                                 random.randint(0, 65535))

  def GetProtocolFamily(self, version):
    """Returns AF_INET or AF_INET6 for IP version 4 or 6."""
    return {4: AF_INET, 6: AF_INET6}[version]

  @classmethod
  def CreateTunInterface(cls, netid):
    """Creates, configures and brings up the tap interface for netid.

    Returns:
      The unbuffered, non-blocking file object for the tap device. Writing to
      it injects frames into the interface (see ReceiveEtherPacketOn), and
      frames sent on the interface are read from it (see ReadAllPacketsOn).
      If any setup step fails, the file is closed before the error propagates.
    """
    iface = cls.GetInterfaceName(netid)
    try:
      f = open("/dev/net/tun", "r+b", buffering=0)
    except IOError:
      f = open("/dev/tun", "r+b", buffering=0)
    try:
      # struct ifreq: 16-byte name, then flags, zero-padded to 40 bytes.
      ifr = struct.pack("16sH", iface.encode(), IFF_TAP | IFF_NO_PI)
      ifr += b"\x00" * (40 - len(ifr))
      fcntl.ioctl(f, TUNSETIFF, ifr)
      # Give ourselves a predictable MAC address.
      net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
      # Disable DAD so we don't have to wait for it.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
      # Set accept_ra to 2, because that's what we use.
      cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
      net_test.SetInterfaceUp(iface)
      net_test.SetNonBlocking(f)
    except Exception:
      # Don't leak the tun fd if setup fails part-way.
      f.close()
      raise
    return f

  @classmethod
  def SendRA(cls, netid, retranstimer=None, reachabletime=0, routerlft=RA_VALIDITY,
             piolft=RA_VALIDITY, m=0, o=0, piopflag=0, options=()):
    """Injects a Router Advertisement from netid's router into its tap.

    The RA carries a source link-layer address option and a prefix
    information option for OnlinkPrefix(6, netid) with the L and A bits set.

    Args:
      netid: The netid whose interface receives the RA.
      retranstimer: Retrans timer in ms. If None, it is set to the caller's
        routerlft * 1000 (computed before routerlft may be zeroed below).
      reachabletime: Reachable time field of the RA.
      routerlft: Router lifetime in seconds. Forced to 0 if per-interface
        autoconf tables are unsupported (HAVE_AUTOCONF_TABLE is false).
      piolft: Valid and preferred lifetime of the prefix, in seconds.
      m: Value of the RA M flag.
      o: Value of the RA O flag.
      piopflag: If true, sets bit 0x10 of the PIO's res1 field (the PIO "P"
        flag, per the parameter name).
      options: Extra scapy layers appended to the RA.
    """
    macaddr = cls.RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)

    if retranstimer is None:
      # If no retrans timer was specified, pick one that's as long as the
      # router lifetime. This ensures that no spurious ND retransmits
      # will interfere with test expectations.
      retranstimer = routerlft * 1000  # Lifetime is in s, retrans timer in ms.

    # We don't want any routes in the main table. If the kernel doesn't support
    # putting RA routes into per-interface tables, configure routing manually.
    if not HAVE_AUTOCONF_TABLE:
      routerlft = 0

    res1 = 0x10 if piopflag else 0

    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(reachabletime=reachabletime,
                            retranstimer=retranstimer,
                            routerlifetime=routerlft,
                            M=m, O=o) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
                                      prefixlen=cls.OnlinkPrefixLen(6),
                                      L=1, A=1, res1=res1,
                                      validlifetime=piolft,
                                      preferredlifetime=piolft))
    for option in options:
      ra /= option
    posix.write(cls.tuns[netid].fileno(), bytes(ra))

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds (is_add=True) or deletes the rules, routes and addresses for netid.

    For both IPv4 and IPv6 this installs UID-range, output-interface and
    fwmark rules pointing at netid's table. It also configures IPv4
    addressing/neighbours, plus routes where the kernel won't add them.
    """
    # These don't depend on the IP version.
    iface = cls.GetInterfaceName(netid)
    ifindex = cls.ifindices[netid]
    macaddr = cls.RouterMacAddress(netid)
    table = cls._TableForNetid(netid)
    start, end = cls.UidRangeForNetid(netid)

    for version in [4, 6]:
      router = cls._RouterAddress(netid, version)

      # Set up routing rules.
      cls.iproute.UidRangeRule(version, is_add, start, end, table,
                               cls.PRIORITY_UID)
      cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
      cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
                             cls.PRIORITY_FWMARK)

      # Configure routing and addressing.
      #
      # IPv6 uses autoconf for everything, except if per-device autoconf routing
      # tables are not supported, in which case the on-link and default routes
      # are configured manually in the per-network table. For IPv4 we have to
      # manually configure addresses, routes, and neighbour cache entries
      # (since we don't reply to ARP or ND).
      #
      # Since deleting addresses also causes routes to be deleted, we need to
      # be careful with ordering or the delete commands will fail with ENOENT.
      #
      # A real Android system will have both IPv4 and IPv6 routes for
      # directly-connected subnets in the per-interface routing tables. Ensure
      # we create those as well.
      do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
      if is_add:
        if version == 4:
          cls.iproute.AddAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)
          cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
        if do_routing:
          cls.iproute.AddRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
          cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
      else:
        if do_routing:
          cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
          cls.iproute.DelRoute(version, table,
                               cls.OnlinkPrefix(version, netid),
                               cls.OnlinkPrefixLen(version), None, ifindex)
        if version == 4:
          cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
          cls.iproute.DelAddress(cls._MyIPv4Address(netid),
                                 cls.OnlinkPrefixLen(4), ifindex)

  @classmethod
  def SetMarkReflectSysctls(cls, value):
    """Makes kernel-generated replies use the mark of the original packet."""
    cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)

  @classmethod
  def _SetInboundMarking(cls, netid, iface, is_add):
    """Adds or deletes mangle INPUT rules marking packets from iface with netid.

    The same rule is applied with both iptables (IPv4) and ip6tables (IPv6).

    Raises:
      ConfigurationError: if net_test.RunIptablesCommand reports failure (a
        truthy return value).
    """
    add_del = "-A" if is_add else "-D"
    args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
        add_del, iface, netid)
    for version in [4, 6]:
      if net_test.RunIptablesCommand(version, args):
        raise ConfigurationError("Setup command failed: %s" % args)

  @classmethod
  def SetInboundMarks(cls, is_add):
    """Adds or deletes inbound marking rules for every tap interface."""
    for netid in cls.tuns:
      cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)

  @classmethod
  def SetDefaultNetwork(cls, netid):
    """Points the default rule at netid's table, or removes it if netid is falsy."""
    table = cls._TableForNetid(netid) if netid else None
    is_add = table is not None
    for version in [4, 6]:
      cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)

  @classmethod
  def ClearDefaultNetwork(cls):
    """Removes the default-network rule."""
    cls.SetDefaultNetwork(None)

  @classmethod
  def GetSysctl(cls, sysctl):
    """Returns the raw contents (including trailing newline) of a sysctl file."""
    with open(sysctl, "r") as sysctl_file:
      return sysctl_file.read()

  @classmethod
  def SetSysctl(cls, sysctl, value):
    """Writes value to a sysctl file, remembering the original for teardown."""
    # Only save each sysctl value the first time we set it. This is so we can
    # set it to arbitrary values multiple times and still write it back
    # correctly at the end.
    if sysctl not in cls.saved_sysctls:
      cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    with open(sysctl, "w") as sysctl_file:
      sysctl_file.write(str(value) + "\n")

  @classmethod
  def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    """Sets /proc/sys/net/ipv6/conf/<iface>/<sysctl> on every tap interface."""
    for netid in cls.tuns:
      iface = cls.GetInterfaceName(netid)
      name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
      cls.SetSysctl(name, value)

  @classmethod
  def _RestoreSysctls(cls):
    """Writes back every sysctl value saved by SetSysctl (best effort)."""
    for sysctl, value in cls.saved_sysctls.items():
      try:
        with open(sysctl, "w") as sysctl_file:
          sysctl_file.write(value)
      except IOError:
        # Deliberately ignored, e.g. per-interface sysctls whose interface has
        # already been closed by tearDownClass.
        pass

  @classmethod
  def _ICMPRatelimitFilename(cls, version):
    """Returns the ICMP rate limit sysctl path for IP version 4 or 6."""
    return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
                               6: "ipv6/icmp/ratelimit"}[version]

  @classmethod
  def _SetICMPRatelimit(cls, version, limit):
    """Sets the ICMP rate limit sysctl for the given IP version."""
    cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)

  @classmethod
  def setUpClass(cls):
    """Creates the tap interfaces and the routing configuration for NETIDS."""
    # This is per-class setup instead of per-testcase setup because shelling out
    # to ip and iptables is slow, and because routing configuration doesn't
    # change during the test.
    cls.iproute = iproute.IPRoute()
    cls.tuns = {}
    cls.ifindices = {}
    if HAVE_AUTOCONF_TABLE:
      cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
      cls.AUTOCONF_TABLE_OFFSET = -1000
    else:
      cls.AUTOCONF_TABLE_OFFSET = None

    # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    for version in [4, 6]:
      cls._SetICMPRatelimit(version, 0)

    for version in [4, 6]:
      cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)

    for netid in cls.NETIDS:
      cls.tuns[netid] = cls.CreateTunInterface(netid)
      iface = cls.GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls.SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Don't print lots of "device foo entered promiscuous mode" warnings.
    cls.loglevel = cls.GetConsoleLogLevel()
    cls.SetConsoleLogLevel(net_test.KERN_INFO)

    # When running on device, don't send connections through FwmarkServer.
    os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"

    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    # time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    """Undoes setUpClass: removes rules and routes, closes taps, restores sysctls."""
    del os.environ["ANDROID_NO_USE_FWMARK_CLIENT"]

    for version in [4, 6]:
      try:
        cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
      except IOError:
        pass

    for netid in cls.tuns:
      cls._RunSetupCommands(netid, False)
      cls.tuns[netid].close()

    cls.iproute.close()
    cls._RestoreSysctls()
    cls.SetConsoleLogLevel(cls.loglevel)

  def setUp(self):
    """Drains any packets left queued on the tap interfaces."""
    self.ClearTunQueues()

  def SetSocketMark(self, s, netid):
    """Sets SO_MARK on s to netid (None means 0)."""
    if netid is None:
      netid = 0
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetSocketMark(self, s):
    """Returns the SO_MARK of s."""
    return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)

  def ClearSocketMark(self, s):
    """Sets the SO_MARK of s to 0."""
    self.SetSocketMark(s, 0)

  def BindToDevice(self, s, iface):
    """Binds s to iface with SO_BINDTODEVICE; a falsy iface unbinds it."""
    if not iface:
      iface = ""
    s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface.encode())

  def SetUnicastInterface(self, s, ifindex):
    """Sets IP_UNICAST_IF (and IPV6_UNICAST_IF on IPv6 sockets) to ifindex."""
    # Otherwise, Python thinks it's a 1-byte option.
    ifindex = struct.pack("!I", ifindex)

    # Always set the IPv4 interface, because it will be used even on IPv6
    # sockets if the destination address is a mapped address.
    s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    if s.family == AF_INET6:
      s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)

  def GetRemoteAddress(self, version):
    """Returns the remote test address for IP version 4, 5 or 6."""
    return {4: self.IPV4_ADDR,
            5: self.IPV4_ADDR,  # see GetRemoteSocketAddress()
            6: self.IPV6_ADDR}[version]

  def GetRemoteSocketAddress(self, version):
    """Like GetRemoteAddress, but IPv4-mapped for version 5."""
    addr = self.GetRemoteAddress(version)
    return "::ffff:" + addr if version == 5 else addr

  def GetOtherRemoteSocketAddress(self, version):
    """Returns the second remote test address (IPv4-mapped for version 5)."""
    return {4: self.IPV4_ADDR2,
            5: "::ffff:" + self.IPV4_ADDR2,
            6: self.IPV6_ADDR2}[version]

  def SelectInterface(self, s, netid, mode):
    """Routes socket s via netid using one of several selection modes.

    Args:
      s: The socket to configure.
      netid: The netid to select; a falsy value selects no network.
      mode: "uid" (chown the socket), "mark" (SO_MARK), "oif"
        (SO_BINDTODEVICE) or "ucast_oif" (IP[V6]_UNICAST_IF).

    Raises:
      ValueError: if mode is unknown.
    """
    if mode == "uid":
      os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    elif mode == "mark":
      self.SetSocketMark(s, netid)
    elif mode == "oif":
      iface = self.GetInterfaceName(netid) if netid else ""
      self.BindToDevice(s, iface)
    elif mode == "ucast_oif":
      self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    else:
      raise ValueError("Unknown interface selection mode %s" % mode)

  def BuildSocket(self, version, constructor, netid, routing_mode):
    """Creates a socket with constructor and routes it via netid.

    Args:
      version: The IP version, 4, 5 (IPv4 over an IPv6 socket) or 6.
      constructor: A callable taking an address family and returning a socket.
      netid: The netid to select.
      routing_mode: A SelectInterface mode, or None to leave the socket as is.

    Returns:
      The new socket. If selecting the interface fails, the socket is closed
      before the error propagates.
    """
    if version == 5:
      version = 6
    s = constructor(self.GetProtocolFamily(version))

    # SelectInterface handles "uid" exactly like the other modes.
    if routing_mode is not None:
      try:
        self.SelectInterface(s, netid, routing_mode)
      except Exception:
        s.close()
        raise

    return s

  def RandomNetid(self, exclude=None):
    """Return a random netid from the list of netids

    Args:
      exclude: a netid or list of netids that should not be chosen
    """
    if exclude is None:
      exclude = []
    elif isinstance(exclude, int):
      exclude = [exclude]
    diff = [netid for netid in self.NETIDS if netid not in exclude]
    return random.choice(diff)

  def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    """Sends payload on s with sendmsg, optionally pinned to netid.

    Args:
      version: The IP version, 4 or 6, used to pick the PKTINFO cmsg.
      s: The socket to send on.
      dstaddr: The destination address.
      dstport: The destination port.
      netid: If not None, a PKTINFO cmsg selecting netid's interface is added.
      payload: The data to send.
      cmsgs: A sequence of (level, type, data) cmsg tuples to send. It is not
        modified.
    """
    if netid is not None:
      pktinfo = MakePktInfo(version, None, self.ifindices[netid])
      cmsg_level, cmsg_name = {
          4: (net_test.SOL_IP, csocket.IP_PKTINFO),
          6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
      # Build a new list rather than appending to the caller's, so reusing the
      # same cmsgs list across calls doesn't accumulate PKTINFO entries.
      cmsgs = list(cmsgs or []) + [(cmsg_level, cmsg_name, pktinfo)]
    csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)

  def ReceiveEtherPacketOn(self, netid, packet):
    """Injects an Ethernet frame into netid's tap interface."""
    posix.write(self.tuns[netid].fileno(), bytes(packet))

  def ReceivePacketOn(self, netid, ip_packet):
    """Injects an IP packet from the router into netid's tap interface."""
    routermac = self.RouterMacAddress(netid)
    mymac = self.MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    self.ReceiveEtherPacketOn(netid, packet)

  def ReadAllPacketsOn(self, netid, include_multicast=False):
    """Return all queued packets on a netid as a list.

    Reads until the tap fd would block. If nothing has been collected at the
    first EAGAIN, waits 10 ms and tries once more.

    Args:
      netid: The netid from which to read packets
      include_multicast: A boolean, whether to include multicast packets in
        the result. If False (the default), multicast frames are read and
        discarded.

    Returns:
      A list of the payloads (the layers above Ethernet) of the frames read.

    Raises:
      OSError: if reading fails with anything other than EAGAIN.
    """
    packets = []
    retries = 0
    max_retries = 1
    fd = self.tuns[netid].fileno()
    while True:
      try:
        packet = posix.read(fd, 4096)
      except OSError as e:
        # Anything other than EAGAIN is unexpected.
        if e.errno != errno.EAGAIN:
          raise
        # EAGAIN means there are no more packets waiting. If we didn't see any
        # packets, try again for good luck.
        if not packets and retries < max_retries:
          time.sleep(0.01)
          retries += 1
          continue
        break
      if not packet:
        break
      ether = scapy.Ether(packet)
      # Multicast frames are frames where the first byte of the destination
      # MAC address has 1 in the least-significant bit.
      if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
        packets.append(ether.payload)
    return packets

  def InvalidateDstCache(self, version, netid):
    """Invalidates destination cache entries of sockets on the specified table.

    Creates and then deletes a low-priority throw route in the table for the
    given netid, which invalidates the destination cache entries of any sockets
    that refer to routes in that table.

    The fact that this method actually invalidates destination cache entries is
    tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    does not re-route sockets when they are remarked, but does re-route them if
    this method is called.

    Args:
      version: The IP version, 4 or 6.
      netid: The netid to invalidate dst caches on.
    """
    table = self._TableForNetid(netid)
    for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
      self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
                          "default", 0, nexthop=None, dev=None, mark=None,
                          uid=None, route_type=iproute.RTN_THROW,
                          priority=100000)

  def ClearTunQueues(self):
    """Discards queued packets on all netids until none are left."""
    # Keep reading packets on all netids until we get no packets on any of them.
    waiting = None
    while waiting != 0:
      waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)

  def assertPacketMatches(self, expected, actual):
    """Asserts that actual matches expected, ignoring unpredictable fields.

    Raises:
      AssertionError: with a repr() diff of both packets if they differ.
    """
    # The expected packet is just a rough sketch of the packet we expect to
    # receive. For example, it doesn't contain fields we can't predict, such as
    # initial TCP sequence numbers, or that depend on the host implementation
    # and settings, such as TCP options. To check whether the packet matches
    # what we expect, instead of just checking all the known fields one by one,
    # we blank out fields in the actual packet and then compare the whole
    # packets to each other as strings. Because we modify the actual packet,
    # make a copy here.
    actual = actual.copy()

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5  # Clears the DF bit (value 2).
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out the flow label, since new kernels randomize it by default.
    actualipv6 = actual.getlayer("IPv6")
    expectedipv6 = expected.getlayer("IPv6")
    if actualipv6 and expectedipv6:
      actualipv6.fl = expectedipv6.fl

    # Blank out UDP fields that we can't predict (e.g., the source port for
    # kernel-originated packets).
    actualudp = actual.getlayer("UDP")
    expectedudp = expected.getlayer("UDP")
    if actualudp and expectedudp:
      if expectedudp.sport is None:
        actualudp.sport = None
        actualudp.chksum = None
      elif actualudp.chksum == 0xffff and expectedudp.chksum == 0:
        # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
        # It is possible that scapy has been upgraded and this no longer triggers.
        actualudp.chksum = 0

    # Since the TCP code below messes with options, recalculate the length.
    if actualip:
      actualip.len = None
    if actualipv6:
      actualipv6.plen = None

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.sport is None:
        actualtcp.sport = None
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so that expected packet fields that are only set when
    # a packet is serialized (e.g., the checksum) are filled in.
    expected_real = expected.__class__(bytes(expected))
    actual_real = actual.__class__(bytes(actual))
    # repr() can be expensive. Call it only if the test is going to fail and we
    # want to see the error.
    if expected_real != actual_real:
      self.assertEqual(repr(expected_real), repr(actual_real))

  def PacketMatches(self, expected, actual):
    """Returns True if assertPacketMatches(expected, actual) passes."""
    try:
      self.assertPacketMatches(expected, actual)
      return True
    except AssertionError:
      return False

  def ExpectNoPacketsOn(self, netid, msg):
    """Fails the test if any non-multicast packet is queued on netid."""
    packets = self.ReadAllPacketsOn(netid)
    if packets:
      firstpacket = repr(packets[0])
    else:
      firstpacket = ""
    self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)

  def ExpectPacketOn(self, netid, msg, expected):
    """Returns the first queued packet on netid that matches expected.

    Raises:
      AssertionError: if no packets are queued.
      UnexpectedPacketError: if packets are queued but none match; the message
        contains a diff against the last packet read.
    """
    # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    # multicast packets unless the packet we expect to see is a multicast
    # packet. For now the only tests that use this are IPv6.
    ipv6 = expected.getlayer("IPv6")
    include_multicast = bool(ipv6 and ipv6.dst.startswith("ff"))

    packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    self.assertTrue(packets, msg + ": received no packets")

    # If we receive a packet that matches what we expected, return it.
    for packet in packets:
      if self.PacketMatches(expected, packet):
        return packet

    # None of the packets matched. Call assertPacketMatches to output a diff
    # between the expected packet and the last packet we received. In theory,
    # we'd output a diff to the packet that's the best match for what we
    # expected, but this is good enough for now.
    try:
      self.assertPacketMatches(expected, packets[-1])
    except Exception as e:
      raise UnexpectedPacketError(
          "%s: diff with last packet:\n%s" % (msg, str(e))) from e

  def Combinations(self, version):
    """Produces a list of combinations to test.

    Returns:
      A list of (netid, iif, ip_if, myaddr, remoteaddr) tuples, one for each
      (destination interface, incoming interface) pair of tap interfaces.
    """
    combinations = []

    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      ip_if = self.GetInterfaceName(dest_ip_netid)
      myaddr = self.MyAddress(version, dest_ip_netid)
      prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
      remoteaddr = self.GetRandomDestination(prefix)

      # ... coming in on all our interfaces.
      for netid in self.tuns:
        iif = self.GetInterfaceName(netid)
        combinations.append((netid, iif, ip_if, myaddr, remoteaddr))

    return combinations

  def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    """Builds the assertion message used by receive/expect-reply checks."""
    msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    if reply_desc:
      msg += ": Expecting %s on %s" % (reply_desc, iif)
    else:
      msg += ": Expecting no packets on %s" % iif
    return msg

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Injects packet on netid, then expects reply (or no packets if falsy).

    Returns:
      The matching reply packet, or None when no reply was expected.
    """
    self.ReceivePacketOn(netid, packet)
    if reply:
      return self.ExpectPacketOn(netid, msg, reply)
    else:
      self.ExpectNoPacketsOn(netid, msg)
      return None
783
784
class InboundMarkingTest(MultiNetworkBaseTest):
  """Class that automatically sets up inbound marking.

  On top of the base environment, installs (and later removes) the iptables
  mangle rules that mark packets received on each tap interface with that
  interface's netid.
  """

  @classmethod
  def setUpClass(cls):
    super(InboundMarkingTest, cls).setUpClass()
    try:
      cls.SetInboundMarks(True)
    except Exception:
      # unittest does not run tearDownClass when setUpClass raises, so undo the
      # base environment here rather than leaking interfaces, rules and sysctls.
      super(InboundMarkingTest, cls).tearDownClass()
      raise

  @classmethod
  def tearDownClass(cls):
    try:
      cls.SetInboundMarks(False)
    finally:
      # Tear down the base environment even if removing a marking rule failed.
      super(InboundMarkingTest, cls).tearDownClass()
797