1#!/usr/bin/python3 2# 3# Copyright 2015 The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import 18from errno import * # pylint: disable=wildcard-import 19import binascii 20import os 21import random 22import select 23from socket import * # pylint: disable=wildcard-import 24import struct 25import threading 26import time 27import unittest 28 29import cstruct 30import multinetwork_base 31import net_test 32import packets 33import sock_diag 34import tcp_test 35 36# Mostly empty structure definition containing only the fields we currently use. 37TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh") 38 39NUM_SOCKETS = 30 40NO_BYTECODE = b"" 41 42IPPROTO_SCTP = 132 43 44def HaveSctp(): 45 try: 46 s = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP) 47 s.close() 48 return True 49 except IOError: 50 return False 51 52HAVE_SCTP = HaveSctp() 53 54 55class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest): 56 """Basic tests for SOCK_DIAG functionality. 57 58 Relevant kernel commits: 59 android-3.4: 60 ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields 61 99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 62 63 android-3.10: 64 3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields 65 f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 66 67 android-3.18: 68 e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 69 70 android-4.4: 71 525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 72 """ 73 @staticmethod 74 def _CreateLotsOfSockets(socktype): 75 # Dict mapping (addr, sport, dport) tuples to socketpairs. 76 socketpairs = {} 77 for _ in range(NUM_SOCKETS): 78 family, addr = random.choice([ 79 (AF_INET, "127.0.0.1"), 80 (AF_INET6, "::1"), 81 (AF_INET6, "::ffff:127.0.0.1")]) 82 socketpair = net_test.CreateSocketPair(family, socktype, addr) 83 sport, dport = (socketpair[0].getsockname()[1], 84 socketpair[1].getsockname()[1]) 85 socketpairs[(addr, sport, dport)] = socketpair 86 return socketpairs 87 88 def assertSocketClosed(self, sock): 89 self.assertRaisesErrno(ENOTCONN, sock.getpeername) 90 91 def assertSocketConnected(self, sock): 92 sock.getpeername() # No errors? Socket is alive and connected. 93 94 def assertSocketsClosed(self, socketpair): 95 for sock in socketpair: 96 self.assertSocketClosed(sock) 97 98 def assertMarkIs(self, mark, attrs): 99 self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None)) 100 101 def assertSockInfoMatchesSocket(self, s, info): 102 diag_msg, attrs = info 103 family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) 104 self.assertEqual(diag_msg.family, family) 105 106 src, sport = s.getsockname()[0:2] 107 self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src)) 108 self.assertEqual(diag_msg.id.sport, sport) 109 110 if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]: 111 dst, dport = s.getpeername()[0:2] 112 self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst)) 113 self.assertEqual(diag_msg.id.dport, dport) 114 else: 115 self.assertRaisesErrno(ENOTCONN, s.getpeername) 116 117 mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK) 118 self.assertMarkIs(mark, attrs) 119 120 def PackAndCheckBytecode(self, instructions): 121 bytecode = self.sock_diag.PackBytecode(instructions) 122 decoded = self.sock_diag.DecodeBytecode(bytecode) 123 self.assertEqual(len(instructions), len(decoded)) 124 self.assertFalse("???" in decoded) 125 return bytecode 126 127 def _EventDuringBlockingCall(self, sock, call, expected_errno, event): 128 """Simulates an external event during a blocking call on sock. 129 130 Args: 131 sock: The socket to use. 132 call: A function, the call to make. Takes one parameter, sock. 133 expected_errno: The value that call is expected to fail with, or None if 134 call is expected to succeed. 135 event: A function, the event that will happen during the blocking call. 136 Takes one parameter, sock. 137 """ 138 thread = SocketExceptionThread(sock, call) 139 thread.start() 140 time.sleep(0.1) 141 event(sock) 142 thread.join(1) 143 self.assertFalse(thread.is_alive()) 144 if expected_errno is not None: 145 self.assertIsNotNone(thread.exception) 146 self.assertTrue(isinstance(thread.exception, IOError), 147 "Expected IOError, got %s" % thread.exception) 148 self.assertEqual(expected_errno, thread.exception.errno) 149 else: 150 self.assertIsNone(thread.exception) 151 self.assertSocketClosed(sock) 152 153 def CloseDuringBlockingCall(self, sock, call, expected_errno): 154 self._EventDuringBlockingCall( 155 sock, call, expected_errno, 156 lambda sock: self.sock_diag.CloseSocketFromFd(sock)) 157 158 def setUp(self): 159 super(SockDiagBaseTest, self).setUp() 160 self.sock_diag = sock_diag.SockDiag() 161 self.socketpairs = {} 162 163 def tearDown(self): 164 for socketpair in list(self.socketpairs.values()): 165 for s in socketpair: 166 s.close() 167 super(SockDiagBaseTest, self).tearDown() 168 169 170class SockDiagTest(SockDiagBaseTest): 171 172 def testFindsMappedSockets(self): 173 """Tests that inet_diag_find_one_icsk can find mapped sockets.""" 174 socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, 175 "::ffff:127.0.0.1") 176 for sock in socketpair: 177 diag_msg = self.sock_diag.FindSockDiagFromFd(sock) 178 diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 179 self.sock_diag.GetSockInfo(diag_req) 180 # No errors? Good. 181 182 for sock in socketpair: 183 sock.close() 184 185 def CheckFindsAllMySockets(self, socktype, proto): 186 """Tests that basic socket dumping works.""" 187 self.socketpairs = self._CreateLotsOfSockets(socktype) 188 sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE) 189 self.assertGreaterEqual(len(sockets), NUM_SOCKETS) 190 191 # Find the cookies for all of our sockets. 192 cookies = {} 193 for diag_msg, unused_attrs in sockets: 194 addr = self.sock_diag.GetSourceAddress(diag_msg) 195 sport = diag_msg.id.sport 196 dport = diag_msg.id.dport 197 if (addr, sport, dport) in self.socketpairs: 198 cookies[(addr, sport, dport)] = diag_msg.id.cookie 199 elif (addr, dport, sport) in self.socketpairs: 200 cookies[(addr, sport, dport)] = diag_msg.id.cookie 201 202 # Did we find all the cookies? 203 self.assertEqual(2 * NUM_SOCKETS, len(cookies)) 204 205 socketpairs = list(self.socketpairs.values()) 206 random.shuffle(socketpairs) 207 for socketpair in socketpairs: 208 for sock in socketpair: 209 # Check that we can find a diag_msg by scanning a dump. 210 self.assertSockInfoMatchesSocket( 211 sock, 212 self.sock_diag.FindSockInfoFromFd(sock)) 213 cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie 214 215 # Check that we can find a diag_msg once we know the cookie. 216 req = self.sock_diag.DiagReqFromSocket(sock) 217 req.id.cookie = cookie 218 if proto == IPPROTO_UDP: 219 # Kernel bug: for UDP sockets, the order of arguments must be swapped. 220 # See testDemonstrateUdpGetSockIdBug. 221 req.id.sport, req.id.dport = req.id.dport, req.id.sport 222 req.id.src, req.id.dst = req.id.dst, req.id.src 223 info = self.sock_diag.GetSockInfo(req) 224 self.assertSockInfoMatchesSocket(sock, info) 225 226 for socketpair in socketpairs: 227 for sock in socketpair: 228 sock.close() 229 230 def assertItemsEqual(self, expected, actual): 231 try: 232 super(SockDiagTest, self).assertItemsEqual(expected, actual) 233 except AttributeError: 234 # This was renamed in python3 but has the same behaviour. 235 super(SockDiagTest, self).assertCountEqual(expected, actual) 236 237 def testFindsAllMySocketsTcp(self): 238 self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP) 239 240 def testFindsAllMySocketsUdp(self): 241 self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP) 242 243 def testBytecodeCompilation(self): 244 # pylint: disable=bad-whitespace 245 instructions = [ 246 (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0 247 (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8 248 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16 249 (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44 250 (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48 251 (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64 252 (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72 253 # 76 acc 254 # 80 rej 255 ] 256 # pylint: enable=bad-whitespace 257 bytecode = self.PackAndCheckBytecode(instructions) 258 expected = ( 259 b"0208500000000000" 260 b"050848000000ffff" 261 b"071c20000a800000ffffffff00000000000000000000000000000001" 262 b"01041c00" 263 b"0718200002200000ffffffff7f000001" 264 b"0508100000006566" 265 b"00040400" 266 ) 267 states = 1 << tcp_test.TCP_ESTABLISHED 268 self.assertEqual(expected, binascii.hexlify(bytecode)) 269 self.assertEqual(76, len(bytecode)) 270 self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) 271 filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode, 272 states=states) 273 allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE, 274 states=states) 275 self.assertItemsEqual(allsockets, filteredsockets) 276 277 # Pick a few sockets in hash table order, and check that the bytecode we 278 # compiled selects them properly. 279 for socketpair in list(self.socketpairs.values())[:20]: 280 for s in socketpair: 281 diag_msg = self.sock_diag.FindSockDiagFromFd(s) 282 instructions = [ 283 (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport), 284 (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport), 285 (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport), 286 (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport), 287 ] 288 bytecode = self.PackAndCheckBytecode(instructions) 289 self.assertEqual(32, len(bytecode)) 290 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) 291 self.assertEqual(1, len(sockets)) 292 293 # TODO: why doesn't comparing the cstructs work? 294 self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack()) 295 296 def testCrossFamilyBytecode(self): 297 """Checks for a cross-family bug in inet_diag_hostcond matching. 298 299 Relevant kernel commits: 300 android-3.4: 301 f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run() 302 """ 303 # TODO: this is only here because the test fails if there are any open 304 # sockets other than the ones it creates itself. Make the bytecode more 305 # specific and remove it. 306 states = 1 << tcp_test.TCP_ESTABLISHED 307 self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE, 308 states=states)) 309 310 unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1") 311 unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1") 312 313 bytecode4 = self.PackAndCheckBytecode([ 314 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))]) 315 bytecode6 = self.PackAndCheckBytecode([ 316 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))]) 317 318 # IPv4/v6 filters must never match IPv6/IPv4 sockets... 319 v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4, 320 states=states) 321 self.assertTrue(v4socks) 322 self.assertTrue(all(d.family == AF_INET for d, _ in v4socks)) 323 324 v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6, 325 states=states) 326 self.assertTrue(v6socks) 327 self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks)) 328 329 # Except for mapped addresses, which match both IPv4 and IPv6. 330 pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, 331 "::ffff:127.0.0.1") 332 diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5] 333 v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, 334 bytecode4, 335 states=states)] 336 v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, 337 bytecode6, 338 states=states)] 339 self.assertTrue(all(d in v4socks for d in diag_msgs)) 340 self.assertTrue(all(d in v6socks for d in diag_msgs)) 341 342 for sock in unused_pair4: 343 sock.close() 344 345 for sock in unused_pair6: 346 sock.close() 347 348 for sock in pair5: 349 sock.close() 350 351 def testPortComparisonValidation(self): 352 """Checks for a bug in validating port comparison bytecode. 353 354 Relevant kernel commits: 355 android-3.4: 356 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads 357 """ 358 bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8)) 359 self.assertEqual("???", 360 self.sock_diag.DecodeBytecode(bytecode)) 361 self.assertRaisesErrno( 362 EINVAL, 363 self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack()) 364 365 def testNonSockDiagCommand(self): 366 def DiagDump(code): 367 sock_id = self.sock_diag._EmptyInetDiagSockId() 368 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff, 369 sock_id)) 370 self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg) 371 372 op = sock_diag.SOCK_DIAG_BY_FAMILY 373 DiagDump(op) # No errors? Good. 374 self.assertRaisesErrno(EINVAL, DiagDump, op + 17) 375 376 def CheckSocketCookie(self, inet, addr): 377 """Tests that getsockopt SO_COOKIE can get cookie for all sockets.""" 378 socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr) 379 for sock in socketpair: 380 diag_msg = self.sock_diag.FindSockDiagFromFd(sock) 381 cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) 382 self.assertEqual(diag_msg.id.cookie, cookie) 383 384 for sock in socketpair: 385 sock.close() 386 387 def testGetsockoptcookie(self): 388 self.CheckSocketCookie(AF_INET, "127.0.0.1") 389 self.CheckSocketCookie(AF_INET6, "::1") 390 391 def testDemonstrateUdpGetSockIdBug(self): 392 # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup 393 # by passing the source address as the source address argument. 394 # Unfortunately those functions are intended to match local sockets based 395 # on received packets, and the argument that ends up being compared with 396 # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not 397 # have this bug. Upstream has confirmed that this will not be fixed: 398 # https://www.mail-archive.com/[email protected]/msg248638.html 399 """Documents a bug: getting UDP sockets requires swapping src and dst.""" 400 for version in [4, 5, 6]: 401 family = net_test.GetAddressFamily(version) 402 s = socket(family, SOCK_DGRAM, 0) 403 self.SelectInterface(s, self.RandomNetid(), "mark") 404 s.connect((self.GetRemoteSocketAddress(version), 53)) 405 406 # Create a fully-specified diag req from our socket, including cookie if 407 # we can get it. 408 req = self.sock_diag.DiagReqFromSocket(s) 409 req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) 410 411 # As is, this request does not find anything. 412 with self.assertRaisesErrno(ENOENT): 413 self.sock_diag.GetSockInfo(req) 414 415 # But if we swap src and dst, the kernel finds our socket. 416 req.id.sport, req.id.dport = req.id.dport, req.id.sport 417 req.id.src, req.id.dst = req.id.dst, req.id.src 418 419 self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req)) 420 421 s.close() 422 423 424class SockDestroyTest(SockDiagBaseTest): 425 """Tests that SOCK_DESTROY works correctly. 426 427 Relevant kernel commits: 428 net-next: 429 b613f56 net: diag: split inet_diag_dump_one_icsk into two 430 64be0ae net: diag: Add the ability to destroy a socket. 431 6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets. 432 c1e64e2 net: diag: Support destroying TCP sockets. 433 2010b93 net: tcp: deal with listen sockets properly in tcp_abort. 434 435 android-3.4: 436 d48ec88 net: diag: split inet_diag_dump_one_icsk into two 437 2438189 net: diag: Add the ability to destroy a socket. 438 7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets. 439 44047b2 net: diag: Support destroying TCP sockets. 440 200dae7 net: tcp: deal with listen sockets properly in tcp_abort. 441 442 android-3.10: 443 9eaff90 net: diag: split inet_diag_dump_one_icsk into two 444 d60326c net: diag: Add the ability to destroy a socket. 445 3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets. 446 529dfc6 net: diag: Support destroying TCP sockets. 447 9c712fe net: tcp: deal with listen sockets properly in tcp_abort. 448 449 android-3.18: 450 100263d net: diag: split inet_diag_dump_one_icsk into two 451 194c5f3 net: diag: Add the ability to destroy a socket. 452 8387ea2 net: diag: Support SOCK_DESTROY for inet sockets. 453 b80585a net: diag: Support destroying TCP sockets. 454 476c6ce net: tcp: deal with listen sockets properly in tcp_abort. 455 456 android-4.1: 457 56eebf8 net: diag: split inet_diag_dump_one_icsk into two 458 fb486c9 net: diag: Add the ability to destroy a socket. 459 0c02b7e net: diag: Support SOCK_DESTROY for inet sockets. 460 67c71d8 net: diag: Support destroying TCP sockets. 461 a76e0ec net: tcp: deal with listen sockets properly in tcp_abort. 462 e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk() 463 464 android-4.4: 465 76c83a9 net: diag: split inet_diag_dump_one_icsk into two 466 f7cf791 net: diag: Add the ability to destroy a socket. 467 1c42248 net: diag: Support SOCK_DESTROY for inet sockets. 468 c9e8440d net: diag: Support destroying TCP sockets. 469 3d9502c tcp: diag: add support for request sockets to tcp_abort() 470 001cf75 net: tcp: deal with listen sockets properly in tcp_abort. 471 """ 472 473 def testClosesSockets(self): 474 self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) 475 for _, socketpair in self.socketpairs.items(): 476 # Close one of the sockets. 477 # This will send a RST that will close the other side as well. 478 s = random.choice(socketpair) 479 if random.randrange(0, 2) == 1: 480 self.sock_diag.CloseSocketFromFd(s) 481 else: 482 diag_msg = self.sock_diag.FindSockDiagFromFd(s) 483 484 # Get the cookie wrong and ensure that we get an error and the socket 485 # is not closed. 486 real_cookie = diag_msg.id.cookie 487 diag_msg.id.cookie = os.urandom(len(real_cookie)) 488 req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 489 self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req) 490 self.assertSocketConnected(s) 491 492 # Now close it with the correct cookie. 493 req.id.cookie = real_cookie 494 self.sock_diag.CloseSocket(req) 495 496 # Check that both sockets in the pair are closed. 497 self.assertSocketsClosed(socketpair) 498 499 # TODO: 500 # Test that killing unix sockets returns EOPNOTSUPP. 501 502 503class SocketExceptionThread(threading.Thread): 504 505 def __init__(self, sock, operation): 506 self.exception = None 507 super(SocketExceptionThread, self).__init__() 508 self.daemon = True 509 self.sock = sock 510 self.operation = operation 511 512 def run(self): 513 try: 514 self.operation(self.sock) 515 except (IOError, AssertionError) as e: 516 self.exception = e 517 518 519class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 520 521 def testIpv4MappedSynRecvSocket(self): 522 """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets. 523 524 Relevant kernel commits: 525 android-3.4: 526 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state 527 """ 528 netid = random.choice(list(self.tuns.keys())) 529 self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid) 530 sock_id = self.sock_diag._EmptyInetDiagSockId() 531 sock_id.sport = self.port 532 states = 1 << tcp_test.TCP_SYN_RECV 533 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id)) 534 children = self.sock_diag.Dump(req, NO_BYTECODE) 535 536 self.assertTrue(children) 537 for child, unused_args in children: 538 self.assertEqual(tcp_test.TCP_SYN_RECV, child.state) 539 self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr), 540 child.id.dst) 541 self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr), 542 child.id.src) 543 544 545class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 546 547 RWND_SIZE = 64000 548 TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd" 549 550 def setUp(self): 551 super(TcpRcvWindowTest, self).setUp() 552 self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w") 553 554 def checkInitRwndSize(self, version, netid): 555 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid) 556 tcpInfo = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP, 557 net_test.TCP_INFO, len(TcpInfo))) 558 self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh, 559 "Tcp rwnd of netid=%d, version=%d is not enough. " 560 "Expect: %d, actual: %d" % (netid, version, self.RWND_SIZE, 561 tcpInfo.tcpi_rcv_ssthresh)) 562 self.CloseSockets() 563 564 def checkSynPacketWindowSize(self, version, netid): 565 s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark") 566 myaddr = self.MyAddress(version, netid) 567 dstaddr = self.GetRemoteAddress(version) 568 dstsockaddr = self.GetRemoteSocketAddress(version) 569 desc, expected = packets.SYN(53, version, myaddr, dstaddr, 570 sport=None, seq=None) 571 self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53)) 572 msg = "IPv%s TCP connect: expected %s on %s" % ( 573 version, desc, self.GetInterfaceName(netid)) 574 syn = self.ExpectPacketOn(netid, msg, expected) 575 self.assertLess(self.RWND_SIZE, syn.window) 576 s.close() 577 578 def testTcpCwndSize(self): 579 for version in [4, 5, 6]: 580 for netid in self.NETIDS: 581 self.checkInitRwndSize(version, netid) 582 self.checkSynPacketWindowSize(version, netid) 583 584 585class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 586 587 def setUp(self): 588 super(SockDestroyTcpTest, self).setUp() 589 self.netid = random.choice(list(self.tuns.keys())) 590 591 def ExpectRst(self, msg): 592 desc, rst = self.RstPacket() 593 msg = "%s: expecting %s: " % (msg, desc) 594 self.ExpectPacketOn(self.netid, msg, rst) 595 596 def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): 597 """Closes the socket and checks whether a RST is sent or not.""" 598 if sock is not None: 599 self.assertIsNone(req, "Must specify sock or req, not both") 600 self.sock_diag.CloseSocketFromFd(sock) 601 self.assertRaisesErrno(EINVAL, sock.accept) 602 else: 603 self.assertIsNone(sock, "Must specify sock or req, not both") 604 self.sock_diag.CloseSocket(req) 605 606 if expect_reset: 607 self.ExpectRst(msg) 608 else: 609 msg = "%s: " % msg 610 self.ExpectNoPacketsOn(self.netid, msg) 611 612 if sock is not None and do_close: 613 sock.close() 614 615 def CheckTcpReset(self, state, statename): 616 for version in [4, 5, 6]: 617 msg = "Closing incoming IPv%d %s socket" % (version, statename) 618 self.IncomingConnection(version, state, self.netid) 619 self.CheckRstOnClose(self.s, None, False, msg) 620 if state != tcp_test.TCP_LISTEN: 621 msg = "Closing accepted IPv%d %s socket" % (version, statename) 622 self.CheckRstOnClose(self.accepted, None, True, msg) 623 self.CloseSockets() 624 625 def testTcpResets(self): 626 """Checks that closing sockets in appropriate states sends a RST.""" 627 self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN") 628 self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED") 629 self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT") 630 631 def testFinWait1Socket(self): 632 for version in [4, 5, 6]: 633 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 634 635 # Get the cookie so we can find this socket after we close it. 636 diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted) 637 diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) 638 639 # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN. 640 net_test.EnableFinWait(self.accepted) 641 self.accepted.close() 642 self.accepted = None 643 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1 644 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 645 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 646 desc, fin = self.FinPacket() 647 msg = "Closing FIN_WAIT1 socket" 648 self.ExpectPacketOn(self.netid, msg, fin) 649 650 # Destroy the socket. 651 self.sock_diag.CloseSocketFromFd(self.s) 652 self.assertRaisesErrno(EINVAL, self.s.accept) 653 try: 654 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 655 except Error as e: 656 # Newer kernels will have closed the socket and sent a RST. 657 self.assertEqual(ENOENT, e.errno) 658 self.ExpectRst(msg) 659 self.CloseSockets() 660 return 661 662 # Older kernels don't support closing FIN_WAIT1 sockets. 663 # Check that no RST is sent and that the socket is still in FIN_WAIT1, and 664 # advances to FIN_WAIT2 if the FIN is ACked. 665 msg = "%s: " % msg 666 self.ExpectNoPacketsOn(self.netid, msg) 667 self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) 668 669 # ACK the FIN so we don't trip over retransmits in future tests. 670 finversion = 4 if version == 5 else version 671 desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin) 672 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) 673 self.ReceivePacketOn(self.netid, finack) 674 675 # See if we can find the resulting FIN_WAIT2 socket. 676 diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2 677 infos = self.sock_diag.Dump(diag_req, NO_BYTECODE) 678 self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2 679 for diag_msg, attrs in infos), 680 "Expected to find FIN_WAIT2 socket in %s" % infos) 681 682 self.CloseSockets() 683 684 def FindChildSockets(self, s): 685 """Finds the SYN_RECV child sockets of a given listening socket.""" 686 d = self.sock_diag.FindSockDiagFromFd(self.s) 687 req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 688 req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED 689 req.id.cookie = b"\x00" * 8 690 691 bad_bytecode = self.PackAndCheckBytecode( 692 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))]) 693 self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode)) 694 695 bytecode = self.PackAndCheckBytecode( 696 [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))]) 697 children = self.sock_diag.Dump(req, bytecode) 698 return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 699 for d, _ in children] 700 701 def CheckChildSocket(self, version, statename, parent_first): 702 state = getattr(tcp_test, statename) 703 704 self.IncomingConnection(version, state, self.netid) 705 706 d = self.sock_diag.FindSockDiagFromFd(self.s) 707 parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) 708 children = self.FindChildSockets(self.s) 709 self.assertEqual(1, len(children)) 710 711 is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED) 712 expected_state = tcp_test.TCP_ESTABLISHED if is_established else state 713 714 for child in children: 715 diag_msg, attrs = self.sock_diag.GetSockInfo(child) 716 self.assertEqual(diag_msg.state, expected_state) 717 self.assertMarkIs(self.netid, attrs) 718 719 def CloseParent(expect_reset): 720 msg = "Closing parent IPv%d %s socket %s child" % ( 721 version, statename, "before" if parent_first else "after") 722 self.CheckRstOnClose(self.s, None, expect_reset, msg) 723 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent) 724 725 def CheckChildrenClosed(): 726 for child in children: 727 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 728 729 def CloseChildren(): 730 for child in children: 731 msg = "Closing child IPv%d %s socket %s parent" % ( 732 version, statename, "after" if parent_first else "before") 733 self.sock_diag.GetSockInfo(child) 734 self.CheckRstOnClose(None, child, is_established, msg) 735 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) 736 CheckChildrenClosed() 737 738 if parent_first: 739 # Closing the parent will close child sockets, which will send a RST, 740 # iff they are already established. 741 CloseParent(is_established) 742 if is_established: 743 CheckChildrenClosed() 744 else: 745 CloseChildren() 746 CheckChildrenClosed() 747 else: 748 CloseChildren() 749 CloseParent(False) 750 751 self.CloseSockets() 752 753 def testChildSockets(self): 754 for version in [4, 5, 6]: 755 self.CheckChildSocket(version, "TCP_SYN_RECV", False) 756 self.CheckChildSocket(version, "TCP_SYN_RECV", True) 757 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False) 758 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True) 759 760 def testAcceptInterrupted(self): 761 """Tests that accept() is interrupted by SOCK_DESTROY.""" 762 for version in [4, 5, 6]: 763 self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid) 764 self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096) 765 self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) 766 self.assertRaisesErrno(ECONNABORTED, self.s.send, b"foo") 767 self.assertRaisesErrno(EINVAL, self.s.accept) 768 # TODO: this should really return an error such as ENOTCONN... 769 self.assertEqual(b"", self.s.recv(4096)) 770 self.CloseSockets() 771 772 def testReadInterrupted(self): 773 """Tests that read() is interrupted by SOCK_DESTROY.""" 774 for version in [4, 5, 6]: 775 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 776 self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), 777 ECONNABORTED) 778 # Writing returns EPIPE, and reading returns EOF. 779 self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo") 780 self.assertEqual(b"", self.accepted.recv(4096)) 781 self.assertEqual(b"", self.accepted.recv(4096)) 782 self.CloseSockets() 783 784 def testConnectInterrupted(self): 785 """Tests that connect() is interrupted by SOCK_DESTROY.""" 786 for version in [4, 5, 6]: 787 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 788 s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) 789 self.SelectInterface(s, self.netid, "mark") 790 791 remotesockaddr = self.GetRemoteSocketAddress(version) 792 remoteaddr = self.GetRemoteAddress(version) 793 s.bind(("", 0)) 794 _, sport = s.getsockname()[:2] 795 self.CloseDuringBlockingCall( 796 s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED) 797 desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), 798 remoteaddr, sport=sport, seq=None) 799 self.ExpectPacketOn(self.netid, desc, syn) 800 msg = "SOCK_DESTROY of socket in connect, expected no RST" 801 self.ExpectNoPacketsOn(self.netid, msg) 802 s.close() 803 804 805class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 806 """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs. 807 808 The behaviour of poll() in these cases is not what we might expect: if only 809 POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT 810 is (also) specified, it will only return POLLOUT. 811 """ 812 813 POLLIN_OUT = select.POLLIN | select.POLLOUT 814 POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP 815 816 def setUp(self): 817 super(PollOnCloseTest, self).setUp() 818 self.netid = random.choice(list(self.tuns.keys())) 819 820 POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"), 821 (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")] 822 823 def PollResultToString(self, poll_events, ignoremask): 824 out = [] 825 for fd, event in poll_events: 826 flags = [name for (flag, name) in self.POLL_FLAGS 827 if event & flag & ~ignoremask != 0] 828 out.append((fd, "|".join(flags))) 829 return out 830 831 def BlockingPoll(self, sock, mask, expected, ignoremask): 832 p = select.poll() 833 p.register(sock, mask) 834 expected_fds = [(sock.fileno(), expected)] 835 # Don't block forever or we'll hang continuous test runs on failure. 836 # A 5-second timeout should be long enough not to be flaky. 837 actual_fds = p.poll(5000) 838 self.assertEqual(self.PollResultToString(expected_fds, ignoremask), 839 self.PollResultToString(actual_fds, ignoremask)) 840 841 def RstDuringBlockingCall(self, sock, call, expected_errno): 842 self._EventDuringBlockingCall( 843 sock, call, expected_errno, 844 lambda _: self.ReceiveRstPacketOn(self.netid)) 845 846 def assertSocketErrors(self, errno): 847 # The first operation returns the expected errno. 848 self.assertRaisesErrno(errno, self.accepted.recv, 4096) 849 850 # Subsequent operations behave as normal. 851 self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo") 852 self.assertEqual(b"", self.accepted.recv(4096)) 853 self.assertEqual(b"", self.accepted.recv(4096)) 854 855 def CheckPollDestroy(self, mask, expected, ignoremask): 856 """Interrupts a poll() with SOCK_DESTROY.""" 857 for version in [4, 5, 6]: 858 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 859 self.CloseDuringBlockingCall( 860 self.accepted, 861 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 862 None) 863 self.assertSocketErrors(ECONNABORTED) 864 self.CloseSockets() 865 866 def CheckPollRst(self, mask, expected, ignoremask): 867 """Interrupts a poll() by receiving a TCP RST.""" 868 for version in [4, 5, 6]: 869 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid) 870 self.RstDuringBlockingCall( 871 self.accepted, 872 lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask), 873 None) 874 self.assertSocketErrors(ECONNRESET) 875 self.CloseSockets() 876 877 def testReadPollRst(self): 878 self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, 0) 879 880 def testWritePollRst(self): 881 self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0) 882 883 def testReadWritePollRst(self): 884 self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0) 885 886 def testReadPollDestroy(self): 887 # tcp_abort has the same race that tcp_reset has, but it's not fixed yet. 888 ignoremask = select.POLLIN | select.POLLHUP 889 self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask) 890 891 def testWritePollDestroy(self): 892 self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0) 893 894 def testReadWritePollDestroy(self): 895 self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0) 896 897 898class SockDestroyUdpTest(SockDiagBaseTest): 899 900 """Tests SOCK_DESTROY on UDP sockets. 901 902 Relevant kernel commits: 903 upstream net-next: 904 5d77dca net: diag: support SOCK_DESTROY for UDP sockets 905 f95bf34 net: diag: make udp_diag_destroy work for mapped addresses. 906 """ 907 908 def testClosesUdpSockets(self): 909 self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM) 910 for _, socketpair in self.socketpairs.items(): 911 s1, s2 = socketpair 912 913 self.assertSocketConnected(s1) 914 self.sock_diag.CloseSocketFromFd(s1) 915 self.assertSocketClosed(s1) 916 917 self.assertSocketConnected(s2) 918 self.sock_diag.CloseSocketFromFd(s2) 919 self.assertSocketClosed(s2) 920 921 def BindToRandomPort(self, s, addr): 922 ATTEMPTS = 20 923 for i in range(20): 924 port = random.randrange(1024, 65535) 925 try: 926 s.bind((addr, port)) 927 return port 928 except error as e: 929 if e.errno != EADDRINUSE: 930 raise e 931 raise ValueError("Could not find a free port on %s after %d attempts" % 932 (addr, ATTEMPTS)) 933 934 def testSocketAddressesAfterClose(self): 935 for version in 4, 5, 6: 936 netid = random.choice(self.NETIDS) 937 dst = self.GetRemoteSocketAddress(version) 938 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 939 unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version] 940 941 # Closing a socket that was not explicitly bound (i.e., bound via 942 # connect(), not bind()) clears the source address and port. 943 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 944 self.SelectInterface(s, netid, "mark") 945 s.connect((dst, 53)) 946 self.sock_diag.CloseSocketFromFd(s) 947 self.assertEqual((unspec, 0), s.getsockname()[:2]) 948 s.close() 949 950 # Closing a socket bound to an IP address leaves the address as is. 951 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 952 src = self.MySocketAddress(version, netid) 953 s.bind((src, 0)) 954 s.connect((dst, 53)) 955 port = s.getsockname()[1] 956 self.sock_diag.CloseSocketFromFd(s) 957 self.assertEqual((src, 0), s.getsockname()[:2]) 958 s.close() 959 960 # Closing a socket bound to a port leaves the port as is. 961 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 962 port = self.BindToRandomPort(s, "") 963 s.connect((dst, 53)) 964 self.sock_diag.CloseSocketFromFd(s) 965 self.assertEqual((unspec, port), s.getsockname()[:2]) 966 s.close() 967 968 # Closing a socket bound to IP address and port leaves both as is. 969 s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") 970 src = self.MySocketAddress(version, netid) 971 port = self.BindToRandomPort(s, src) 972 self.sock_diag.CloseSocketFromFd(s) 973 self.assertEqual((src, port), s.getsockname()[:2]) 974 s.close() 975 976 def testReadInterrupted(self): 977 """Tests that read() is interrupted by SOCK_DESTROY.""" 978 for version in [4, 5, 6]: 979 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version] 980 s = net_test.UDPSocket(family) 981 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 982 addr = self.GetRemoteSocketAddress(version) 983 984 # Check that reads on connected sockets are interrupted. 985 s.connect((addr, 53)) 986 self.assertEqual(3, s.send(b"foo")) 987 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 988 ECONNABORTED) 989 990 # A destroyed socket is no longer connected, but still usable. 991 self.assertRaisesErrno(EDESTADDRREQ, s.send, b"foo") 992 self.assertEqual(3, s.sendto(b"foo", (addr, 53))) 993 994 # Check that reads on unconnected sockets are also interrupted. 995 self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), 996 ECONNABORTED) 997 998 s.close() 999 1000class SockDestroyPermissionTest(SockDiagBaseTest): 1001 1002 def CheckPermissions(self, socktype): 1003 s = socket(AF_INET6, socktype, 0) 1004 self.SelectInterface(s, random.choice(self.NETIDS), "mark") 1005 if socktype == SOCK_STREAM: 1006 s.listen(1) 1007 expectedstate = tcp_test.TCP_LISTEN 1008 else: 1009 s.connect((self.GetRemoteAddress(6), 53)) 1010 expectedstate = tcp_test.TCP_ESTABLISHED 1011 1012 with net_test.RunAsUid(12345): 1013 self.assertRaisesErrno( 1014 EPERM, self.sock_diag.CloseSocketFromFd, s) 1015 1016 self.sock_diag.CloseSocketFromFd(s) 1017 self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s) 1018 1019 s.close() 1020 1021 1022 def testUdp(self): 1023 self.CheckPermissions(SOCK_DGRAM) 1024 1025 def testTcp(self): 1026 self.CheckPermissions(SOCK_STREAM) 1027 1028 1029class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): 1030 1031 """Tests SOCK_DIAG bytecode filters that use marks. 1032 1033 Relevant kernel commits: 1034 upstream net-next: 1035 627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks. 1036 a52e95a net: diag: allow socket bytecode filters to match socket marks 1037 d545cac net: inet: diag: expose the socket mark to privileged processes. 1038 """ 1039 1040 def FilterEstablishedSockets(self, mark, mask): 1041 instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))] 1042 bytecode = self.sock_diag.PackBytecode(instructions) 1043 return self.sock_diag.DumpAllInetSockets( 1044 IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED)) 1045 1046 def assertSamePorts(self, ports, diag_msgs): 1047 expected = sorted(ports) 1048 actual = sorted([msg[0].id.sport for msg in diag_msgs]) 1049 self.assertEqual(expected, actual) 1050 1051 def SockInfoMatchesSocket(self, s, info): 1052 try: 1053 self.assertSockInfoMatchesSocket(s, info) 1054 return True 1055 except AssertionError: 1056 return False 1057 1058 @staticmethod 1059 def SocketDescription(s): 1060 return "%s -> %s" % (str(s.getsockname()), str(s.getpeername())) 1061 1062 def assertFoundSockets(self, infos, sockets): 1063 matches = {} 1064 for s in sockets: 1065 match = None 1066 for info in infos: 1067 if self.SockInfoMatchesSocket(s, info): 1068 if match: 1069 self.fail("Socket %s matched both %s and %s" % 1070 (self.SocketDescription(s), match, info)) 1071 matches[s] = info 1072 self.assertTrue(s in matches, "Did not find socket %s in dump" % 1073 self.SocketDescription(s)) 1074 1075 for i in infos: 1076 if i not in list(matches.values()): 1077 self.fail("Too many sockets in dump, first unexpected: %s" % str(i)) 1078 1079 def testMarkBytecode(self): 1080 family, addr = random.choice([ 1081 (AF_INET, "127.0.0.1"), 1082 (AF_INET6, "::1"), 1083 (AF_INET6, "::ffff:127.0.0.1")]) 1084 s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr) 1085 s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234) 1086 s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235) 1087 1088 infos = self.FilterEstablishedSockets(0x1234, 0xffff) 1089 self.assertFoundSockets(infos, [s1]) 1090 1091 infos = self.FilterEstablishedSockets(0x1234, 0xfffe) 1092 self.assertFoundSockets(infos, [s1, s2]) 1093 1094 infos = self.FilterEstablishedSockets(0x1235, 0xffff) 1095 self.assertFoundSockets(infos, [s2]) 1096 1097 infos = self.FilterEstablishedSockets(0x0, 0x0) 1098 self.assertFoundSockets(infos, [s1, s2]) 1099 1100 infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00) 1101 self.assertEqual(0, len(infos)) 1102 1103 with net_test.RunAsUid(12345): 1104 self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets, 1105 0xfff0000, 0xf0fed00) 1106 1107 s1.close() 1108 s2.close() 1109 1110 @staticmethod 1111 def SetRandomMark(s): 1112 # Python doesn't like marks that don't fit into a signed int. 1113 mark = random.randrange(0, 2**31 - 1) 1114 s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark) 1115 return mark 1116 1117 def assertSocketMarkIs(self, s, mark): 1118 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1119 self.assertMarkIs(mark, attrs) 1120 with net_test.RunAsUid(12345): 1121 diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s) 1122 self.assertMarkIs(None, attrs) 1123 1124 def testMarkInAttributes(self): 1125 testcases = [(AF_INET, "127.0.0.1"), 1126 (AF_INET6, "::1"), 1127 (AF_INET6, "::ffff:127.0.0.1")] 1128 for family, addr in testcases: 1129 # TCP listen sockets. 1130 server = socket(family, SOCK_STREAM, 0) 1131 server.bind((addr, 0)) 1132 port = server.getsockname()[1] 1133 server.listen(1) # Or the socket won't be in the hashtables. 1134 server_mark = self.SetRandomMark(server) 1135 self.assertSocketMarkIs(server, server_mark) 1136 1137 # TCP client sockets. 1138 client = socket(family, SOCK_STREAM, 0) 1139 client_mark = self.SetRandomMark(client) 1140 client.connect((addr, port)) 1141 self.assertSocketMarkIs(client, client_mark) 1142 1143 # TCP server sockets. 1144 accepted, _ = server.accept() 1145 self.assertSocketMarkIs(accepted, server_mark) 1146 1147 accepted_mark = self.SetRandomMark(accepted) 1148 self.assertSocketMarkIs(accepted, accepted_mark) 1149 self.assertSocketMarkIs(server, server_mark) 1150 1151 accepted.close() 1152 server.close() 1153 client.close() 1154 1155 # Other TCP states are tested in SockDestroyTcpTest. 1156 1157 # UDP sockets. 1158 s = socket(family, SOCK_DGRAM, 0) 1159 mark = self.SetRandomMark(s) 1160 s.connect(("", 53)) 1161 self.assertSocketMarkIs(s, mark) 1162 s.close() 1163 1164 # Basic test for SCTP. sctp_diag was only added in 4.7. 1165 if HAVE_SCTP: 1166 s = socket(family, SOCK_STREAM, IPPROTO_SCTP) 1167 s.bind((addr, 0)) 1168 s.listen(1) 1169 mark = self.SetRandomMark(s) 1170 self.assertSocketMarkIs(s, mark) 1171 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE) 1172 self.assertEqual(1, len(sockets)) 1173 self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None)) 1174 s.close() 1175 1176 1177if __name__ == "__main__": 1178 unittest.main() 1179