#!/usr/bin/python3
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from socket import *  # pylint: disable=wildcard-import

import net_test
import multinetwork_base
import packets

# TCP states. See include/net/tcp_states.h.
TCP_ESTABLISHED = 1
TCP_SYN_SENT = 2
TCP_SYN_RECV = 3
TCP_FIN_WAIT1 = 4
TCP_FIN_WAIT2 = 5
TCP_TIME_WAIT = 6
TCP_CLOSE = 7
TCP_CLOSE_WAIT = 8
TCP_LAST_ACK = 9
TCP_LISTEN = 10
TCP_CLOSING = 11
TCP_NEW_SYN_RECV = 12

# Pseudo-state used by IncomingConnection: the three-way handshake has been
# completed but the connection has not been returned by accept().
TCP_NOT_YET_ACCEPTED = -1

# End states that TcpBaseTest.IncomingConnection knows how to reach. They are
# checked before any socket is opened or any packet is exchanged, so an
# unsupported state fails fast instead of after a full connection setup.
_INCOMING_CONNECTION_END_STATES = frozenset((
    TCP_LISTEN,
    TCP_SYN_RECV,
    TCP_NOT_YET_ACCEPTED,
    TCP_ESTABLISHED,
    TCP_CLOSE_WAIT,
))


class TcpBaseTest(multinetwork_base.MultiNetworkBaseTest):
  """Base class for tests that drive TCP connections with crafted packets.

  Per-test state:
    s: the listening socket opened by IncomingConnection, or None.
    accepted: the socket returned by s.accept(), or None.
    last_packet: the most recent packet recorded by ReceivePacketOn or
      _ReceiveAndExpectResponse; used as the basis for RST/FIN packets.
    sent_fin: True once any packet passed to ExpectPacketOn had FIN set.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.accepted = None
    self.s = None
    self.last_packet = None
    self.sent_fin = False

  def CloseSockets(self):
    """Close the accepted and listening sockets, if open, and drop them."""
    if self.accepted:
      self.accepted.close()
      self.accepted = None
    if self.s:
      self.s.close()
      self.s = None

  def tearDown(self):
    self.CloseSockets()
    super().tearDown()

  def OpenListenSocket(self, version, netid):
    """Create a TCP socket on `netid` bound to a random port.

    Args:
      version: 4 for AF_INET; 5 or 6 for AF_INET6 bound to "::".
      netid: the network to select on the socket via its mark.

    Returns:
      The socket. The chosen port is stored in self.port.
      NOTE(review): presumably net_test.BindRandomPort also calls listen()
      on stream sockets (IncomingConnection later calls accept() on it) --
      confirm in net_test.
    """
    family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
    address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    # We haven't configured inbound iptables marking, so bind explicitly.
    self.SelectInterface(s, netid, "mark")
    self.port = net_test.BindRandomPort(version, s)
    return s

  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    """Delegate to the base class and remember the packet it returns."""
    pkt = super()._ReceiveAndExpectResponse(netid, packet, reply, msg)
    self.last_packet = pkt
    return pkt

  def ReceivePacketOn(self, netid, packet):
    """Inject `packet` on `netid` and remember it as self.last_packet."""
    super().ReceivePacketOn(netid, packet)
    self.last_packet = packet

  def ReceiveRstPacketOn(self, netid):
    """Inject a RST from the remote side, derived from self.last_packet.

    Calls the base-class ReceivePacketOn directly, so self.last_packet is
    NOT replaced by the RST.
    """
    # self.last_packet is the last packet we received. Invert direction twice:
    # build our ACK to it (my -> remote), then the remote's RST to that ACK
    # (remote -> my).
    _, ack = packets.ACK(self.version, self.myaddr, self.remoteaddr,
                         self.last_packet)
    _, rst = packets.RST(self.version, self.remoteaddr, self.myaddr, ack)
    super().ReceivePacketOn(netid, rst)

  def RstPacket(self):
    """Return packets.RST(...) for a RST from us answering self.last_packet.

    self.sent_fin is passed through; presumably packets.RST uses it to
    account for a FIN we already sent -- confirm in packets.
    """
    return packets.RST(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet, self.sent_fin)

  def FinPacket(self):
    """Return packets.FIN(...) for a FIN from us answering self.last_packet."""
    return packets.FIN(self.version, self.myaddr, self.remoteaddr,
                       self.last_packet)

  def ExpectPacketOn(self, netid, msg, pkt):
    """Record whether `pkt` carries FIN, then delegate to the base class."""
    self.sent_fin |= (pkt.getlayer("TCP").flags & packets.TCP_FIN) != 0
    return super().ExpectPacketOn(netid, msg, pkt)

  def IncomingConnection(self, version, end_state, netid):
    """Drive an incoming TCP connection on `netid` until `end_state`.

    Steps, stopping after the one that reaches `end_state`:
      TCP_LISTEN: open and bind the listening socket (self.s).
      TCP_SYN_RECV: inject a SYN and expect a SYN-ACK.
      TCP_NOT_YET_ACCEPTED: inject the ACK completing the handshake.
      TCP_ESTABLISHED: accept() the connection into self.accepted.
      TCP_CLOSE_WAIT: send data, inject a remote FIN and expect our ACK.

    Args:
      version: 4, 5 or 6. 5 uses an AF_INET6 socket (see OpenListenSocket)
        but IPv4 packets; self.version is set to 4 in that case.
      end_state: one of the states listed above.
      netid: the network the connection arrives on.

    Raises:
      ValueError: if end_state is not one of the supported states. This is
        checked before any socket is opened or packet exchanged.
    """
    if end_state not in _INCOMING_CONNECTION_END_STATES:
      raise ValueError("Invalid TCP state %d specified" % end_state)

    self.s = self.OpenListenSocket(version, netid)
    self.end_state = end_state

    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
    self.remotesockaddr = self.GetRemoteSocketAddress(version)

    myaddr = self.myaddr = self.MyAddress(version, netid)
    self.mysockaddr = self.MySocketAddress(version, netid)

    # Version 5 uses an AF_INET6 socket, but the packets are built as IPv4.
    if version == 5:
      version = 4
    self.version = version

    if end_state == TCP_LISTEN:
      return

    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    if end_state == TCP_SYN_RECV:
      return

    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    self.ReceivePacketOn(netid, establishing_ack)

    if end_state == TCP_NOT_YET_ACCEPTED:
      return

    self.accepted, _ = self.s.accept()
    net_test.DisableFinWait(self.accepted)

    if end_state == TCP_ESTABLISHED:
      return

    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                             payload=net_test.UDP_PAYLOAD)
    self.accepted.send(net_test.UDP_PAYLOAD)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)

    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    # Round-trip the FIN through its serialized bytes; presumably so derived
    # fields (lengths/checksums) are populated before building the ACK --
    # NOTE(review): confirm this is why it is needed.
    fin = packets._GetIpLayer(version)(bytes(fin))
    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)

    # TODO: Why can't we use this?
    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    self.ReceivePacketOn(netid, fin)
    # Presumably gives the kernel time to emit the ACK before we look for it.
    time.sleep(0.1)
    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    # end_state is TCP_CLOSE_WAIT here: every other supported state returned
    # above, and unsupported states were rejected at the top.