xref: /aosp_15_r20/kernel/tests/net/test/sock_diag_test.py (revision 2f2c4c7ab4226c71756b9c31670392fdd6887c4f)
1#!/usr/bin/python3
2#
3# Copyright 2015 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19import binascii
20import os
21import random
22import select
23from socket import *  # pylint: disable=wildcard-import
24import struct
25import threading
26import time
27import unittest
28
29import cstruct
30import multinetwork_base
31import net_test
32import packets
33import sock_diag
34import tcp_test
35
# Mostly empty structure definition containing only the fields we currently use.
# Format "64xI": the first 64 bytes are skipped, then one "I" field is decoded
# as tcpi_rcv_ssthresh. NOTE(review): assumes cstruct uses Python struct format
# codes and that tcpi_rcv_ssthresh sits at byte offset 64 of the kernel's
# struct tcp_info — confirm against the kernel headers.
TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")
38
# How many socket pairs _CreateLotsOfSockets() opens.
NUM_SOCKETS = 30
# Empty filter bytecode, for dumps that should not be filtered.
NO_BYTECODE = bytes()

# IP protocol number for SCTP, defined locally.
IPPROTO_SCTP = 132

def HaveSctp():
  """Reports whether an AF_INET SCTP stream socket can be created.

  Returns:
    True if socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP) succeeds (the probe
    socket is closed again immediately), False if it raises IOError.
  """
  try:
    with socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP):
      pass
  except IOError:
    return False
  return True

HAVE_SCTP = HaveSctp()
53
54
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Base class providing SOCK_DIAG helpers and assertions for tests.

  setUp() creates self.sock_diag (a sock_diag.SockDiag) and an empty
  self.socketpairs dict. tearDown() closes every socket stored in
  self.socketpairs, so tests can park socket pairs there and have them closed
  even when an assertion fails.

    Relevant kernel commits:
      android-3.4:
        ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
        99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-3.10:
        3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
        f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-3.18:
        e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()

      android-4.4:
        525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
  """
  @staticmethod
  def _CreateLotsOfSockets(socktype):
    """Creates NUM_SOCKETS connected socket pairs on loopback addresses.

    Each pair uses IPv4 (127.0.0.1), IPv6 (::1) or IPv4-mapped IPv6
    (::ffff:127.0.0.1), chosen at random.

    Args:
      socktype: socket type passed to net_test.CreateSocketPair.

    Returns:
      A dict mapping (addr, sport, dport) tuples to socketpairs, where sport
      and dport are the local ports of the first and second socket.
    """
    socketpairs = {}
    for _ in range(NUM_SOCKETS):
      family, addr = random.choice([
          (AF_INET, "127.0.0.1"),
          (AF_INET6, "::1"),
          (AF_INET6, "::ffff:127.0.0.1")])
      socketpair = net_test.CreateSocketPair(family, socktype, addr)
      sport, dport = (socketpair[0].getsockname()[1],
                      socketpair[1].getsockname()[1])
      socketpairs[(addr, sport, dport)] = socketpair
    return socketpairs

  def assertSocketClosed(self, sock):
    """Asserts that sock has no peer (getpeername fails with ENOTCONN)."""
    self.assertRaisesErrno(ENOTCONN, sock.getpeername)

  def assertSocketConnected(self, sock):
    """Asserts that sock still has a peer."""
    sock.getpeername()  # No errors? Socket is alive and connected.

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in socketpair is closed."""
    for sock in socketpair:
      self.assertSocketClosed(sock)

  def assertMarkIs(self, mark, attrs):
    """Asserts that the INET_DIAG_MARK attribute equals mark (None if absent)."""
    self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))

  def assertSockInfoMatchesSocket(self, s, info):
    """Checks that a sock_diag (diag_msg, attrs) pair describes socket s.

    Compares the address family, local address and port, and socket mark. It
    also compares the peer address and port, or, when the diag message's
    destination is the wildcard address, checks that s is not connected.
    """
    diag_msg, attrs = info
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    self.assertEqual(diag_msg.family, family)

    src, sport = s.getsockname()[0:2]
    self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    self.assertEqual(diag_msg.id.sport, sport)

    if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
      dst, dport = s.getpeername()[0:2]
      self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
      self.assertEqual(diag_msg.id.dport, dport)
    else:
      self.assertRaisesErrno(ENOTCONN, s.getpeername)

    mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    self.assertMarkIs(mark, attrs)

  def PackAndCheckBytecode(self, instructions):
    """Packs bytecode instructions and checks that they decode cleanly.

    Args:
      instructions: list of instruction tuples accepted by
        self.sock_diag.PackBytecode.

    Returns:
      The packed bytecode.
    """
    bytecode = self.sock_diag.PackBytecode(instructions)
    decoded = self.sock_diag.DecodeBytecode(bytecode)
    # Each instruction must decode to exactly one entry, and none may be "???"
    # (presumably DecodeBytecode's marker for an undecodable op).
    self.assertEqual(len(instructions), len(decoded))
    self.assertNotIn("???", decoded)
    return bytecode

  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    """Simulates an external event during a blocking call on sock.

    Args:
      sock: The socket to use.
      call: A function, the call to make. Takes one parameter, sock.
      expected_errno: The value that call is expected to fail with, or None if
        call is expected to succeed.
      event: A function, the event that will happen during the blocking call.
        Takes one parameter, sock.
    """
    thread = SocketExceptionThread(sock, call)
    thread.start()
    # Give the thread time to enter the blocking call before the event fires.
    time.sleep(0.1)
    event(sock)
    thread.join(1)
    self.assertFalse(thread.is_alive())
    if expected_errno is not None:
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError,
                            "Expected IOError, got %s" % thread.exception)
      self.assertEqual(expected_errno, thread.exception.errno)
    else:
      self.assertIsNone(thread.exception)
    self.assertSocketClosed(sock)

  def CloseDuringBlockingCall(self, sock, call, expected_errno):
    """Destroys sock with SOCK_DESTROY while call(sock) is blocked on it.

    Args:
      sock: The socket to use.
      call: A function taking sock that blocks on it.
      expected_errno: The errno call is expected to fail with, or None.
    """
    self._EventDuringBlockingCall(
        sock, call, expected_errno,
        lambda sock: self.sock_diag.CloseSocketFromFd(sock))

  def setUp(self):
    super(SockDiagBaseTest, self).setUp()
    self.sock_diag = sock_diag.SockDiag()
    # Socket pairs stored here are closed by tearDown().
    self.socketpairs = {}

  def tearDown(self):
    for socketpair in list(self.socketpairs.values()):
      for s in socketpair:
        s.close()
    super(SockDiagBaseTest, self).tearDown()
168
169
class SockDiagTest(SockDiagBaseTest):
  """Tests for SOCK_DIAG dumps, lookups, bytecode filters and cookies."""

  @staticmethod
  def _CloseAll(*socketpairs):
    """Closes every socket in each of the given socket pairs."""
    for socketpair in socketpairs:
      for sock in socketpair:
        sock.close()

  def testFindsMappedSockets(self):
    """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                           "::ffff:127.0.0.1")
    try:
      for sock in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
        diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
        self.sock_diag.GetSockInfo(diag_req)
        # No errors? Good.
    finally:
      # Close even if a lookup fails. A leaked established socket would break
      # tests that require no other sockets, e.g. testCrossFamilyBytecode.
      self._CloseAll(socketpair)

  def CheckFindsAllMySockets(self, socktype, proto):
    """Tests that basic socket dumping works.

    Dumps all sockets of protocol proto and checks that every socket created by
    _CreateLotsOfSockets appears. It then checks that each socket can be found
    both by scanning a dump and by a direct lookup with its cookie.
    """
    self.socketpairs = self._CreateLotsOfSockets(socktype)
    sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
    self.assertGreaterEqual(len(sockets), NUM_SOCKETS)

    # Find the cookies for all of our sockets.
    cookies = {}
    for diag_msg, unused_attrs in sockets:
      addr = self.sock_diag.GetSourceAddress(diag_msg)
      sport = diag_msg.id.sport
      dport = diag_msg.id.dport
      if (addr, sport, dport) in self.socketpairs:
        cookies[(addr, sport, dport)] = diag_msg.id.cookie
      elif (addr, dport, sport) in self.socketpairs:
        cookies[(addr, sport, dport)] = diag_msg.id.cookie

    # Did we find all the cookies?
    self.assertEqual(2 * NUM_SOCKETS, len(cookies))

    socketpairs = list(self.socketpairs.values())
    random.shuffle(socketpairs)
    for socketpair in socketpairs:
      for sock in socketpair:
        # Check that we can find a diag_msg by scanning a dump.
        self.assertSockInfoMatchesSocket(
            sock,
            self.sock_diag.FindSockInfoFromFd(sock))
        cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie

        # Check that we can find a diag_msg once we know the cookie.
        req = self.sock_diag.DiagReqFromSocket(sock)
        req.id.cookie = cookie
        if proto == IPPROTO_UDP:
          # Kernel bug: for UDP sockets, the order of arguments must be swapped.
          # See testDemonstrateUdpGetSockIdBug.
          req.id.sport, req.id.dport = req.id.dport, req.id.sport
          req.id.src, req.id.dst = req.id.dst, req.id.src
        info = self.sock_diag.GetSockInfo(req)
        self.assertSockInfoMatchesSocket(sock, info)

    # These are also in self.socketpairs, so tearDown() would close them too.
    for socketpair in socketpairs:
      for sock in socketpair:
        sock.close()

  def assertItemsEqual(self, expected, actual):
    """Asserts both sequences hold the same items, regardless of order."""
    try:
      super(SockDiagTest, self).assertItemsEqual(expected, actual)
    except AttributeError:
      # This was renamed in python3 but has the same behaviour.
      super(SockDiagTest, self).assertCountEqual(expected, actual)

  def testFindsAllMySocketsTcp(self):
    self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)

  def testFindsAllMySocketsUdp(self):
    self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)

  def testBytecodeCompilation(self):
    # pylint: disable=bad-whitespace
    instructions = [
        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
                                                                       # 76 acc
                                                                       # 80 rej
    ]
    # pylint: enable=bad-whitespace
    bytecode = self.PackAndCheckBytecode(instructions)
    expected = (
        b"0208500000000000"
        b"050848000000ffff"
        b"071c20000a800000ffffffff00000000000000000000000000000001"
        b"01041c00"
        b"0718200002200000ffffffff7f000001"
        b"0508100000006566"
        b"00040400"
    )
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertEqual(expected, binascii.hexlify(bytecode))
    self.assertEqual(76, len(bytecode))
    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                        states=states)
    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                   states=states)
    self.assertItemsEqual(allsockets, filteredsockets)

    # Pick a few sockets in hash table order, and check that the bytecode we
    # compiled selects them properly.
    for socketpair in list(self.socketpairs.values())[:20]:
      for s in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
        instructions = [
            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
        ]
        bytecode = self.PackAndCheckBytecode(instructions)
        self.assertEqual(32, len(bytecode))
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
        self.assertEqual(1, len(sockets))

        # TODO: why doesn't comparing the cstructs work?
        self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack())

  def testCrossFamilyBytecode(self):
    """Checks for a cross-family bug in inet_diag_hostcond matching.

    Relevant kernel commits:
      android-3.4:
        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    """
    # TODO: this is only here because the test fails if there are any open
    # sockets other than the ones it creates itself. Make the bytecode more
    # specific and remove it.
    states = 1 << tcp_test.TCP_ESTABLISHED
    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
                                                       states=states))

    # Every pair created below is closed in the finally clause, in creation
    # order, even if an assertion fails.
    created = []
    try:
      unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM,
                                               "127.0.0.1")
      created.append(unused_pair4)
      unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
      created.append(unused_pair6)

      bytecode4 = self.PackAndCheckBytecode([
          (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
      bytecode6 = self.PackAndCheckBytecode([
          (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])

      # IPv4/v6 filters must never match IPv6/IPv4 sockets...
      v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
                                                  states=states)
      self.assertTrue(v4socks)
      self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))

      v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
                                                  states=states)
      self.assertTrue(v6socks)
      self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))

      # Except for mapped addresses, which match both IPv4 and IPv6.
      pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                        "::ffff:127.0.0.1")
      created.append(pair5)
      diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
      v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                 bytecode4,
                                                                 states=states)]
      v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
                                                                 bytecode6,
                                                                 states=states)]
      self.assertTrue(all(d in v4socks for d in diag_msgs))
      self.assertTrue(all(d in v6socks for d in diag_msgs))
    finally:
      self._CloseAll(*created)

  def testPortComparisonValidation(self):
    """Checks for a bug in validating port comparison bytecode.

    Relevant kernel commits:
      android-3.4:
        5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    """
    bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    self.assertEqual("???",
                      self.sock_diag.DecodeBytecode(bytecode))
    self.assertRaisesErrno(
        EINVAL,
        self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())

  def testNonSockDiagCommand(self):
    """Checks that a netlink message type other than SOCK_DIAG_BY_FAMILY fails with EINVAL."""
    def DiagDump(code):
      sock_id = self.sock_diag._EmptyInetDiagSockId()
      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
                                     sock_id))
      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)

    op = sock_diag.SOCK_DIAG_BY_FAMILY
    DiagDump(op)  # No errors? Good.
    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)

  def CheckSocketCookie(self, inet, addr):
    """Tests that getsockopt SO_COOKIE can get cookie for all sockets."""
    socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    try:
      for sock in socketpair:
        diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
        cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
        self.assertEqual(diag_msg.id.cookie, cookie)
    finally:
      self._CloseAll(socketpair)

  def testGetsockoptcookie(self):
    self.CheckSocketCookie(AF_INET, "127.0.0.1")
    self.CheckSocketCookie(AF_INET6, "::1")

  def testDemonstrateUdpGetSockIdBug(self):
    """Documents a bug: getting UDP sockets requires swapping src and dst."""
    # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup
    # by passing the source address as the source address argument.
    # Unfortunately those functions are intended to match local sockets based
    # on received packets, and the argument that ends up being compared with
    # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not
    # have this bug.  Upstream has confirmed that this will not be fixed:
    # https://www.mail-archive.com/[email protected]/msg248638.html
    for version in [4, 5, 6]:
      family = net_test.GetAddressFamily(version)
      s = socket(family, SOCK_DGRAM, 0)
      try:
        self.SelectInterface(s, self.RandomNetid(), "mark")
        s.connect((self.GetRemoteSocketAddress(version), 53))

        # Create a fully-specified diag req from our socket, including its
        # SO_COOKIE value.
        req = self.sock_diag.DiagReqFromSocket(s)
        req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)

        # As is, this request does not find anything.
        with self.assertRaisesErrno(ENOENT):
          self.sock_diag.GetSockInfo(req)

        # But if we swap src and dst, the kernel finds our socket.
        req.id.sport, req.id.dport = req.id.dport, req.id.sport
        req.id.src, req.id.dst = req.id.dst, req.id.src

        self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))
      finally:
        s.close()
421      s.close()
422
423
424class SockDestroyTest(SockDiagBaseTest):
425  """Tests that SOCK_DESTROY works correctly.
426
427  Relevant kernel commits:
428    net-next:
429      b613f56 net: diag: split inet_diag_dump_one_icsk into two
430      64be0ae net: diag: Add the ability to destroy a socket.
431      6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
432      c1e64e2 net: diag: Support destroying TCP sockets.
433      2010b93 net: tcp: deal with listen sockets properly in tcp_abort.
434
435    android-3.4:
436      d48ec88 net: diag: split inet_diag_dump_one_icsk into two
437      2438189 net: diag: Add the ability to destroy a socket.
438      7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
439      44047b2 net: diag: Support destroying TCP sockets.
440      200dae7 net: tcp: deal with listen sockets properly in tcp_abort.
441
442    android-3.10:
443      9eaff90 net: diag: split inet_diag_dump_one_icsk into two
444      d60326c net: diag: Add the ability to destroy a socket.
445      3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
446      529dfc6 net: diag: Support destroying TCP sockets.
447      9c712fe net: tcp: deal with listen sockets properly in tcp_abort.
448
449    android-3.18:
450      100263d net: diag: split inet_diag_dump_one_icsk into two
451      194c5f3 net: diag: Add the ability to destroy a socket.
452      8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
453      b80585a net: diag: Support destroying TCP sockets.
454      476c6ce net: tcp: deal with listen sockets properly in tcp_abort.
455
456    android-4.1:
457      56eebf8 net: diag: split inet_diag_dump_one_icsk into two
458      fb486c9 net: diag: Add the ability to destroy a socket.
459      0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
460      67c71d8 net: diag: Support destroying TCP sockets.
461      a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
462      e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
463
464    android-4.4:
465      76c83a9 net: diag: split inet_diag_dump_one_icsk into two
466      f7cf791 net: diag: Add the ability to destroy a socket.
467      1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
468      c9e8440d net: diag: Support destroying TCP sockets.
469      3d9502c tcp: diag: add support for request sockets to tcp_abort()
470      001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
471  """
472
473  def testClosesSockets(self):
474    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
475    for _, socketpair in self.socketpairs.items():
476      # Close one of the sockets.
477      # This will send a RST that will close the other side as well.
478      s = random.choice(socketpair)
479      if random.randrange(0, 2) == 1:
480        self.sock_diag.CloseSocketFromFd(s)
481      else:
482        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
483
484        # Get the cookie wrong and ensure that we get an error and the socket
485        # is not closed.
486        real_cookie = diag_msg.id.cookie
487        diag_msg.id.cookie = os.urandom(len(real_cookie))
488        req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
489        self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
490        self.assertSocketConnected(s)
491
492        # Now close it with the correct cookie.
493        req.id.cookie = real_cookie
494        self.sock_diag.CloseSocket(req)
495
496      # Check that both sockets in the pair are closed.
497      self.assertSocketsClosed(socketpair)
498
499  # TODO:
500  # Test that killing unix sockets returns EOPNOTSUPP.
501
502
503class SocketExceptionThread(threading.Thread):
504
505  def __init__(self, sock, operation):
506    self.exception = None
507    super(SocketExceptionThread, self).__init__()
508    self.daemon = True
509    self.sock = sock
510    self.operation = operation
511
512  def run(self):
513    try:
514      self.operation(self.sock)
515    except (IOError, AssertionError) as e:
516      self.exception = e
517
518
519class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
520
521  def testIpv4MappedSynRecvSocket(self):
522    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
523
524    Relevant kernel commits:
525         android-3.4:
526           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
527    """
528    netid = random.choice(list(self.tuns.keys()))
529    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
530    sock_id = self.sock_diag._EmptyInetDiagSockId()
531    sock_id.sport = self.port
532    states = 1 << tcp_test.TCP_SYN_RECV
533    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
534    children = self.sock_diag.Dump(req, NO_BYTECODE)
535
536    self.assertTrue(children)
537    for child, unused_args in children:
538      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
539      self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr),
540                       child.id.dst)
541      self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr),
542                       child.id.src)
543
544
545class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
546
547  RWND_SIZE = 64000
548  TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd"
549
550  def setUp(self):
551    super(TcpRcvWindowTest, self).setUp()
552    self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w")
553
554  def checkInitRwndSize(self, version, netid):
555    self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
556    tcpInfo = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP,
557                                               net_test.TCP_INFO, len(TcpInfo)))
558    self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh,
559                    "Tcp rwnd of netid=%d, version=%d is not enough. "
560                    "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE,
561                                                tcpInfo.tcpi_rcv_ssthresh))
562    self.CloseSockets()
563
564  def checkSynPacketWindowSize(self, version, netid):
565    s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark")
566    myaddr = self.MyAddress(version, netid)
567    dstaddr = self.GetRemoteAddress(version)
568    dstsockaddr = self.GetRemoteSocketAddress(version)
569    desc, expected = packets.SYN(53, version, myaddr, dstaddr,
570                                 sport=None, seq=None)
571    self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53))
572    msg = "IPv%s TCP connect: expected %s on %s" % (
573        version, desc, self.GetInterfaceName(netid))
574    syn = self.ExpectPacketOn(netid, msg, expected)
575    self.assertLess(self.RWND_SIZE, syn.window)
576    s.close()
577
578  def testTcpCwndSize(self):
579    for version in [4, 5, 6]:
580      for netid in self.NETIDS:
581        self.checkInitRwndSize(version, netid)
582        self.checkSynPacketWindowSize(version, netid)
583
584
585class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
586
587  def setUp(self):
588    super(SockDestroyTcpTest, self).setUp()
589    self.netid = random.choice(list(self.tuns.keys()))
590
591  def ExpectRst(self, msg):
592    desc, rst = self.RstPacket()
593    msg = "%s: expecting %s: " % (msg, desc)
594    self.ExpectPacketOn(self.netid, msg, rst)
595
596  def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
597    """Closes the socket and checks whether a RST is sent or not."""
598    if sock is not None:
599      self.assertIsNone(req, "Must specify sock or req, not both")
600      self.sock_diag.CloseSocketFromFd(sock)
601      self.assertRaisesErrno(EINVAL, sock.accept)
602    else:
603      self.assertIsNone(sock, "Must specify sock or req, not both")
604      self.sock_diag.CloseSocket(req)
605
606    if expect_reset:
607      self.ExpectRst(msg)
608    else:
609      msg = "%s: " % msg
610      self.ExpectNoPacketsOn(self.netid, msg)
611
612    if sock is not None and do_close:
613      sock.close()
614
615  def CheckTcpReset(self, state, statename):
616    for version in [4, 5, 6]:
617      msg = "Closing incoming IPv%d %s socket" % (version, statename)
618      self.IncomingConnection(version, state, self.netid)
619      self.CheckRstOnClose(self.s, None, False, msg)
620      if state != tcp_test.TCP_LISTEN:
621        msg = "Closing accepted IPv%d %s socket" % (version, statename)
622        self.CheckRstOnClose(self.accepted, None, True, msg)
623        self.CloseSockets()
624
625  def testTcpResets(self):
626    """Checks that closing sockets in appropriate states sends a RST."""
627    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
628    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
629    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
630
631  def testFinWait1Socket(self):
632    for version in [4, 5, 6]:
633      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
634
635      # Get the cookie so we can find this socket after we close it.
636      diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
637      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
638
639      # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
640      net_test.EnableFinWait(self.accepted)
641      self.accepted.close()
642      self.accepted = None
643      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
644      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
645      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
646      desc, fin = self.FinPacket()
647      msg = "Closing FIN_WAIT1 socket"
648      self.ExpectPacketOn(self.netid, msg, fin)
649
650      # Destroy the socket.
651      self.sock_diag.CloseSocketFromFd(self.s)
652      self.assertRaisesErrno(EINVAL, self.s.accept)
653      try:
654        diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
655      except Error as e:
656        # Newer kernels will have closed the socket and sent a RST.
657        self.assertEqual(ENOENT, e.errno)
658        self.ExpectRst(msg)
659        self.CloseSockets()
660        return
661
662      # Older kernels don't support closing FIN_WAIT1 sockets.
663      # Check that no RST is sent and that the socket is still in FIN_WAIT1, and
664      # advances to FIN_WAIT2 if the FIN is ACked.
665      msg = "%s: " % msg
666      self.ExpectNoPacketsOn(self.netid, msg)
667      self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
668
669      # ACK the FIN so we don't trip over retransmits in future tests.
670      finversion = 4 if version == 5 else version
671      desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
672      diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
673      self.ReceivePacketOn(self.netid, finack)
674
675      # See if we can find the resulting FIN_WAIT2 socket.
676      diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
677      infos = self.sock_diag.Dump(diag_req, NO_BYTECODE)
678      self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
679                          for diag_msg, attrs in infos),
680                      "Expected to find FIN_WAIT2 socket in %s" % infos)
681
682      self.CloseSockets()
683
684  def FindChildSockets(self, s):
685    """Finds the SYN_RECV child sockets of a given listening socket."""
686    d = self.sock_diag.FindSockDiagFromFd(self.s)
687    req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
688    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
689    req.id.cookie = b"\x00" * 8
690
691    bad_bytecode = self.PackAndCheckBytecode(
692        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
693    self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))
694
695    bytecode = self.PackAndCheckBytecode(
696        [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
697    children = self.sock_diag.Dump(req, bytecode)
698    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
699            for d, _ in children]
700
701  def CheckChildSocket(self, version, statename, parent_first):
702    state = getattr(tcp_test, statename)
703
704    self.IncomingConnection(version, state, self.netid)
705
706    d = self.sock_diag.FindSockDiagFromFd(self.s)
707    parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
708    children = self.FindChildSockets(self.s)
709    self.assertEqual(1, len(children))
710
711    is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
712    expected_state = tcp_test.TCP_ESTABLISHED if is_established else state
713
714    for child in children:
715      diag_msg, attrs = self.sock_diag.GetSockInfo(child)
716      self.assertEqual(diag_msg.state, expected_state)
717      self.assertMarkIs(self.netid, attrs)
718
719    def CloseParent(expect_reset):
720      msg = "Closing parent IPv%d %s socket %s child" % (
721          version, statename, "before" if parent_first else "after")
722      self.CheckRstOnClose(self.s, None, expect_reset, msg)
723      self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)
724
725    def CheckChildrenClosed():
726      for child in children:
727        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
728
729    def CloseChildren():
730      for child in children:
731        msg = "Closing child IPv%d %s socket %s parent" % (
732            version, statename, "after" if parent_first else "before")
733        self.sock_diag.GetSockInfo(child)
734        self.CheckRstOnClose(None, child, is_established, msg)
735        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
736      CheckChildrenClosed()
737
738    if parent_first:
739      # Closing the parent will close child sockets, which will send a RST,
740      # iff they are already established.
741      CloseParent(is_established)
742      if is_established:
743        CheckChildrenClosed()
744      else:
745        CloseChildren()
746        CheckChildrenClosed()
747    else:
748      CloseChildren()
749      CloseParent(False)
750
751    self.CloseSockets()
752
753  def testChildSockets(self):
754    for version in [4, 5, 6]:
755      self.CheckChildSocket(version, "TCP_SYN_RECV", False)
756      self.CheckChildSocket(version, "TCP_SYN_RECV", True)
757      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
758      self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
759
760  def testAcceptInterrupted(self):
761    """Tests that accept() is interrupted by SOCK_DESTROY."""
762    for version in [4, 5, 6]:
763      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
764      self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
765      self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
766      self.assertRaisesErrno(ECONNABORTED, self.s.send, b"foo")
767      self.assertRaisesErrno(EINVAL, self.s.accept)
768      # TODO: this should really return an error such as ENOTCONN...
769      self.assertEqual(b"", self.s.recv(4096))
770      self.CloseSockets()
771
772  def testReadInterrupted(self):
773    """Tests that read() is interrupted by SOCK_DESTROY."""
774    for version in [4, 5, 6]:
775      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
776      self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
777                                   ECONNABORTED)
778      # Writing returns EPIPE, and reading returns EOF.
779      self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
780      self.assertEqual(b"", self.accepted.recv(4096))
781      self.assertEqual(b"", self.accepted.recv(4096))
782      self.CloseSockets()
783
784  def testConnectInterrupted(self):
785    """Tests that connect() is interrupted by SOCK_DESTROY."""
786    for version in [4, 5, 6]:
787      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
788      s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
789      self.SelectInterface(s, self.netid, "mark")
790
791      remotesockaddr = self.GetRemoteSocketAddress(version)
792      remoteaddr = self.GetRemoteAddress(version)
793      s.bind(("", 0))
794      _, sport = s.getsockname()[:2]
795      self.CloseDuringBlockingCall(
796          s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
797      desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
798                              remoteaddr, sport=sport, seq=None)
799      self.ExpectPacketOn(self.netid, desc, syn)
800      msg = "SOCK_DESTROY of socket in connect, expected no RST"
801      self.ExpectNoPacketsOn(self.netid, msg)
802      s.close()
803
804
class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
  """Checks that SOCK_DESTROY affects poll() the same way a TCP RST does.

  The poll() results these tests expect are not what one might guess. When
  only POLLIN is requested, POLLIN|POLLERR|POLLHUP is returned. When POLLOUT
  is requested, on its own or together with POLLIN, only POLLOUT is returned.
  """

  POLLIN_OUT = select.POLLIN | select.POLLOUT
  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP

  # (bit, label) pairs used to render poll() event masks as readable strings.
  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

  def setUp(self):
    super(PollOnCloseTest, self).setUp()
    # Every connection in a test uses one randomly chosen network.
    self.netid = random.choice(list(self.tuns.keys()))

  def PollResultToString(self, poll_events, ignoremask):
    """Renders poll() output as a list of (fd, "FLAG|FLAG") pairs.

    Args:
      poll_events: iterable of (fd, eventmask) pairs, as returned by poll().
      ignoremask: event bits to leave out of the rendered strings.
    """
    kept_bits = ~ignoremask

    def Describe(event):
      return "|".join(label for bit, label in self.POLL_FLAGS
                      if event & bit & kept_bits)

    return [(fd, Describe(event)) for fd, event in poll_events]

  def BlockingPoll(self, sock, mask, expected, ignoremask):
    """poll()s sock for mask and asserts it reports exactly expected.

    Bits in ignoremask are excluded from the comparison.
    """
    poller = select.poll()
    poller.register(sock, mask)
    want = [(sock.fileno(), expected)]
    # Bounded wait (milliseconds) so a failure cannot hang continuous test
    # runs; 5 seconds should be long enough not to be flaky.
    got = poller.poll(5000)
    self.assertEqual(self.PollResultToString(want, ignoremask),
                     self.PollResultToString(got, ignoremask))

  def RstDuringBlockingCall(self, sock, call, expected_errno):
    """Runs a blocking call on sock, interrupted by an incoming TCP RST.

    The interrupting event is ReceiveRstPacketOn(self.netid); the rest of
    the behaviour is delegated to _EventDuringBlockingCall.
    """
    def InjectRst(_):
      return self.ReceiveRstPacketOn(self.netid)

    self._EventDuringBlockingCall(sock, call, expected_errno, InjectRst)

  def assertSocketErrors(self, errno):
    """Asserts how self.accepted behaves after it was torn down.

    The first recv() fails with errno. After that, send() fails with EPIPE
    and recv() keeps returning EOF.
    """
    conn = self.accepted
    self.assertRaisesErrno(errno, conn.recv, 4096)
    self.assertRaisesErrno(EPIPE, conn.send, b"foo")
    for _ in range(2):
      self.assertEqual(b"", conn.recv(4096))

  def _RunPollInterruption(self, interrupt, errno_after, mask, expected,
                           ignoremask):
    """Interrupts a blocked poll() on each IP version, then checks errors.

    Args:
      interrupt: CloseDuringBlockingCall or RstDuringBlockingCall.
      errno_after: errno the first recv() must fail with afterwards.
      mask, expected, ignoremask: forwarded to BlockingPoll.
    """
    def Poll(sock):
      return self.BlockingPoll(sock, mask, expected, ignoremask)

    for version in (4, 5, 6):
      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
      interrupt(self.accepted, Poll, None)
      self.assertSocketErrors(errno_after)
      self.CloseSockets()

  def CheckPollDestroy(self, mask, expected, ignoremask):
    """Interrupts a poll() with SOCK_DESTROY."""
    self._RunPollInterruption(self.CloseDuringBlockingCall, ECONNABORTED,
                              mask, expected, ignoremask)

  def CheckPollRst(self, mask, expected, ignoremask):
    """Interrupts a poll() by receiving a TCP RST."""
    self._RunPollInterruption(self.RstDuringBlockingCall, ECONNRESET,
                              mask, expected, ignoremask)

  def testReadPollRst(self):
    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, 0)

  def testWritePollRst(self):
    self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollRst(self):
    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)

  def testReadPollDestroy(self):
    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet,
    # so the POLLIN and POLLHUP bits are left out of the comparison.
    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP,
                          select.POLLIN | select.POLLHUP)

  def testWritePollDestroy(self):
    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)

  def testReadWritePollDestroy(self):
    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
896
897
class SockDestroyUdpTest(SockDiagBaseTest):

  """Tests SOCK_DESTROY on UDP sockets.

    Relevant kernel commits:
      upstream net-next:
        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
  """

  def testClosesUdpSockets(self):
    """Destroys both ends of each connected UDP socket pair."""
    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    for s1, s2 in self.socketpairs.values():
      self.assertSocketConnected(s1)
      self.sock_diag.CloseSocketFromFd(s1)
      self.assertSocketClosed(s1)

      self.assertSocketConnected(s2)
      self.sock_diag.CloseSocketFromFd(s2)
      self.assertSocketClosed(s2)

  def BindToRandomPort(self, s, addr, attempts=20):
    """Binds s to a randomly chosen port in [1024, 65535) on addr.

    Args:
      s: the socket to bind.
      addr: local address to bind to ("" for the wildcard address).
      attempts: how many random ports to try before giving up.

    Returns:
      The port that s was bound to.

    Raises:
      ValueError: if every attempted port was already in use.
      socket.error: for any bind() failure other than EADDRINUSE.
    """
    for _ in range(attempts):
      port = random.randrange(1024, 65535)
      try:
        s.bind((addr, port))
        return port
      except error as e:
        if e.errno != EADDRINUSE:
          raise
    raise ValueError("Could not find a free port on %s after %d attempts" %
                     (addr, attempts))

  def testSocketAddressesAfterClose(self):
    """Checks which local address/port survive SOCK_DESTROY on UDP sockets."""
    for version in (4, 5, 6):
      netid = random.choice(self.NETIDS)
      dst = self.GetRemoteSocketAddress(version)
      unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]

      # Closing a socket that was not explicitly bound (i.e., bound via
      # connect(), not bind()) clears the source address and port.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      self.SelectInterface(s, netid, "mark")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to an IP address leaves the address as is,
      # while the automatically chosen port is cleared.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      s.bind((src, 0))
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, 0), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to a port leaves the port as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      port = self.BindToRandomPort(s, "")
      s.connect((dst, 53))
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((unspec, port), s.getsockname()[:2])
      s.close()

      # Closing a socket bound to IP address and port leaves both as is.
      s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
      src = self.MySocketAddress(version, netid)
      port = self.BindToRandomPort(s, src)
      self.sock_diag.CloseSocketFromFd(s)
      self.assertEqual((src, port), s.getsockname()[:2])
      s.close()

  def testReadInterrupted(self):
    """Tests that read() is interrupted by SOCK_DESTROY."""
    for version in [4, 5, 6]:
      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
      s = net_test.UDPSocket(family)
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      addr = self.GetRemoteSocketAddress(version)

      # Check that reads on connected sockets are interrupted.
      s.connect((addr, 53))
      self.assertEqual(3, s.send(b"foo"))
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      # A destroyed socket is no longer connected, but still usable.
      self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo")
      self.assertEqual(3, s.sendto(b"foo", (addr, 53)))

      # Check that reads on unconnected sockets are also interrupted.
      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
                                   ECONNABORTED)

      s.close()
998      s.close()
999
class SockDestroyPermissionTest(SockDiagBaseTest):
  """Tests that unprivileged UIDs cannot SOCK_DESTROY sockets."""

  def CheckPermissions(self, socktype):
    """Checks SOCK_DESTROY permissions on an IPv6 socket of socktype.

    A SOCK_STREAM socket is put into the listening state; any other socktype
    is connected to the remote address. The steps are:
      1. Check that the socket's diag state is the one expected for that
         setup.
      2. Check that destroying it as UID 12345 fails with EPERM.
      3. Destroy it without the UID switch, which must succeed.
      4. Check that a second destroy raises ValueError (presumably because
         the destroyed socket can no longer be found).

    Args:
      socktype: SOCK_STREAM or SOCK_DGRAM.
    """
    s = socket(AF_INET6, socktype, 0)
    try:
      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
      if socktype == SOCK_STREAM:
        s.listen(1)
        expectedstate = tcp_test.TCP_LISTEN
      else:
        s.connect((self.GetRemoteAddress(6), 53))
        expectedstate = tcp_test.TCP_ESTABLISHED

      # Sanity-check the socket is in the expected state before destroying.
      diag_msg, _ = self.sock_diag.FindSockInfoFromFd(s)
      self.assertEqual(expectedstate, diag_msg.state)

      with net_test.RunAsUid(12345):
        self.assertRaisesErrno(
            EPERM, self.sock_diag.CloseSocketFromFd, s)

      self.sock_diag.CloseSocketFromFd(s)
      self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    finally:
      # Close the fd even if an assertion above fails.
      s.close()

  def testUdp(self):
    self.CheckPermissions(SOCK_DGRAM)

  def testTcp(self):
    self.CheckPermissions(SOCK_STREAM)
1027
1028
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):

  """Tests SOCK_DIAG bytecode filters that use marks.

    Relevant kernel commits:
      upstream net-next:
        627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
        a52e95a net: diag: allow socket bytecode filters to match socket marks
        d545cac net: inet: diag: expose the socket mark to privileged processes.
  """

  def FilterEstablishedSockets(self, mark, mask):
    """Dumps established TCP sockets selected by a mark/mask bytecode filter.

    The filter is a single INET_DIAG_BC_MARK_COND instruction. Judging from
    the expectations in testMarkBytecode, it presumably matches sockets
    whose (mark & mask) equals mark.
    """
    instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    bytecode = self.sock_diag.PackBytecode(instructions)
    return self.sock_diag.DumpAllInetSockets(
        IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))

  def assertSamePorts(self, ports, diag_msgs):
    expected = sorted(ports)
    actual = sorted([msg[0].id.sport for msg in diag_msgs])
    self.assertEqual(expected, actual)

  def SockInfoMatchesSocket(self, s, info):
    """Returns whether assertSockInfoMatchesSocket(s, info) passes."""
    try:
      self.assertSockInfoMatchesSocket(s, info)
      return True
    except AssertionError:
      return False

  @staticmethod
  def SocketDescription(s):
    return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))

  def assertFoundSockets(self, infos, sockets):
    """Asserts that infos and sockets match one-to-one.

    Fails in three cases:
      - a socket matches no entry in infos;
      - a socket matches more than one entry in infos;
      - infos contains an entry that matched none of the sockets.
    """
    matches = {}
    for s in sockets:
      match = None
      for info in infos:
        if self.SockInfoMatchesSocket(s, info):
          if match is not None:
            self.fail("Socket %s matched both %s and %s" %
                      (self.SocketDescription(s), match, info))
          # Remember the first match so a second one is reported above.
          match = info
          matches[s] = info
      self.assertTrue(s in matches, "Did not find socket %s in dump" %
                      self.SocketDescription(s))

    # Entries may be unhashable, so compare against a list built once.
    matched_infos = list(matches.values())
    for i in infos:
      if i not in matched_infos:
        self.fail("Too many sockets in dump, first unexpected: %s" % str(i))

  def testMarkBytecode(self):
    family, addr = random.choice([
        (AF_INET, "127.0.0.1"),
        (AF_INET6, "::1"),
        (AF_INET6, "::ffff:127.0.0.1")])
    s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)

    infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    self.assertFoundSockets(infos, [s1])

    infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    self.assertFoundSockets(infos, [s2])

    infos = self.FilterEstablishedSockets(0x0, 0x0)
    self.assertFoundSockets(infos, [s1, s2])

    infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    self.assertEqual(0, len(infos))

    # Unprivileged UIDs may not use mark filters.
    with net_test.RunAsUid(12345):
      self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
                             0xfff0000, 0xf0fed00)

    s1.close()
    s2.close()

  @staticmethod
  def SetRandomMark(s):
    # Python doesn't like marks that don't fit into a signed int.
    mark = random.randrange(0, 2**31 - 1)
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
    return mark

  def assertSocketMarkIs(self, s, mark):
    """Asserts s reports mark to root, and no mark to an unprivileged UID."""
    _, attrs = self.sock_diag.FindSockInfoFromFd(s)
    self.assertMarkIs(mark, attrs)
    with net_test.RunAsUid(12345):
      _, attrs = self.sock_diag.FindSockInfoFromFd(s)
      self.assertMarkIs(None, attrs)

  def testMarkInAttributes(self):
    testcases = [(AF_INET, "127.0.0.1"),
                 (AF_INET6, "::1"),
                 (AF_INET6, "::ffff:127.0.0.1")]
    for family, addr in testcases:
      # TCP listen sockets.
      server = socket(family, SOCK_STREAM, 0)
      server.bind((addr, 0))
      port = server.getsockname()[1]
      server.listen(1)  # Or the socket won't be in the hashtables.
      server_mark = self.SetRandomMark(server)
      self.assertSocketMarkIs(server, server_mark)

      # TCP client sockets.
      client = socket(family, SOCK_STREAM, 0)
      client_mark = self.SetRandomMark(client)
      client.connect((addr, port))
      self.assertSocketMarkIs(client, client_mark)

      # TCP server sockets.
      accepted, _ = server.accept()
      self.assertSocketMarkIs(accepted, server_mark)

      accepted_mark = self.SetRandomMark(accepted)
      self.assertSocketMarkIs(accepted, accepted_mark)
      self.assertSocketMarkIs(server, server_mark)

      accepted.close()
      server.close()
      client.close()

      # Other TCP states are tested in SockDestroyTcpTest.

      # UDP sockets.
      s = socket(family, SOCK_DGRAM, 0)
      mark = self.SetRandomMark(s)
      s.connect(("", 53))
      self.assertSocketMarkIs(s, mark)
      s.close()

      # Basic test for SCTP. sctp_diag was only added in 4.7.
      if HAVE_SCTP:
        s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
        s.bind((addr, 0))
        s.listen(1)
        mark = self.SetRandomMark(s)
        self.assertSocketMarkIs(s, mark)
        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
        self.assertEqual(1, len(sockets))
        self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
        s.close()
1175
1176
if __name__ == "__main__":
  # Run the test cases in this module with the standard unittest runner.
  unittest.main()
1179