1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net.cts;
18 
19 import static android.net.cts.PacketUtils.IP4_HDRLEN;
20 import static android.net.cts.PacketUtils.IP6_HDRLEN;
21 import static android.net.cts.PacketUtils.IPPROTO_ESP;
22 import static android.net.cts.PacketUtils.UDP_HDRLEN;
23 import static android.system.OsConstants.IPPROTO_UDP;
24 
25 import static org.junit.Assert.fail;
26 
27 import android.os.ParcelFileDescriptor;
28 
29 import com.android.net.module.util.CollectionUtils;
30 
31 import java.io.FileInputStream;
32 import java.io.FileOutputStream;
33 import java.io.IOException;
34 import java.nio.ByteBuffer;
35 import java.util.ArrayList;
36 import java.util.Arrays;
37 import java.util.List;
38 import java.util.function.Predicate;
39 
40 public class TunUtils {
41     private static final String TAG = TunUtils.class.getSimpleName();
42 
43     protected static final int IP4_ADDR_OFFSET = 12;
44     protected static final int IP4_ADDR_LEN = 4;
45     protected static final int IP6_ADDR_OFFSET = 8;
46     protected static final int IP6_ADDR_LEN = 16;
47     protected static final int IP4_PROTO_OFFSET = 9;
48     protected static final int IP6_PROTO_OFFSET = 6;
49 
50     private static final int SEQ_NUM_MATCH_NOT_REQUIRED = -1;
51 
52     private static final int DATA_BUFFER_LEN = 4096;
53     private static final int TIMEOUT = 2000;
54 
55     private final List<byte[]> mPackets = new ArrayList<>();
56     private final ParcelFileDescriptor mTunFd;
57     private final Thread mReaderThread;
58 
TunUtils(ParcelFileDescriptor tunFd)59     public TunUtils(ParcelFileDescriptor tunFd) {
60         mTunFd = tunFd;
61 
62         // Start background reader thread
63         mReaderThread =
64                 new Thread(
65                         () -> {
66                             try {
67                                 // Loop will exit and thread will quit when tunFd is closed.
68                                 // Receiving either EOF or an exception will exit this reader loop.
69                                 // FileInputStream in uninterruptable, so there's no good way to
70                                 // ensure that this thread shuts down except upon FD closure.
71                                 while (true) {
72                                     byte[] intercepted = receiveFromTun();
73                                     if (intercepted == null) {
74                                         // Exit once we've hit EOF
75                                         return;
76                                     } else if (intercepted.length > 0) {
77                                         // Only save packet if we've received any bytes.
78                                         synchronized (mPackets) {
79                                             mPackets.add(intercepted);
80                                             mPackets.notifyAll();
81                                         }
82                                     }
83                                 }
84                             } catch (IOException ignored) {
85                                 // Simply exit this reader thread
86                                 return;
87                             }
88                         });
89         mReaderThread.start();
90     }
91 
receiveFromTun()92     private byte[] receiveFromTun() throws IOException {
93         FileInputStream in = new FileInputStream(mTunFd.getFileDescriptor());
94         byte[] inBytes = new byte[DATA_BUFFER_LEN];
95         int bytesRead = in.read(inBytes);
96 
97         if (bytesRead < 0) {
98             return null; // return null for EOF
99         } else if (bytesRead >= DATA_BUFFER_LEN) {
100             throw new IllegalStateException("Too big packet. Fragmentation unsupported");
101         }
102         return Arrays.copyOf(inBytes, bytesRead);
103     }
104 
getFirstMatchingPacket(Predicate<byte[]> verifier, int startIndex)105     private byte[] getFirstMatchingPacket(Predicate<byte[]> verifier, int startIndex) {
106         synchronized (mPackets) {
107             for (int i = startIndex; i < mPackets.size(); i++) {
108                 byte[] pkt = mPackets.get(i);
109                 if (verifier.test(pkt)) {
110                     return pkt;
111                 }
112             }
113         }
114         return null;
115     }
116 
awaitPacket(Predicate<byte[]> verifier)117     protected byte[] awaitPacket(Predicate<byte[]> verifier) throws Exception {
118         long endTime = System.currentTimeMillis() + TIMEOUT;
119         int startIndex = 0;
120 
121         synchronized (mPackets) {
122             while (System.currentTimeMillis() < endTime) {
123                 final byte[] pkt = getFirstMatchingPacket(verifier, startIndex);
124                 if (pkt != null) {
125                     return pkt; // We've found the packet we're looking for.
126                 }
127 
128                 startIndex = mPackets.size();
129 
130                 // Try to prevent waiting too long. If waitTimeout <= 0, we've already hit timeout
131                 long waitTimeout = endTime - System.currentTimeMillis();
132                 if (waitTimeout > 0) {
133                     mPackets.wait(waitTimeout);
134                 }
135             }
136         }
137 
138         fail("No packet found matching verifier");
139         throw new IllegalStateException("Impossible condition; should have thrown in fail()");
140     }
141 
awaitEspPacketNoPlaintext( int spi, byte[] plaintext, boolean useEncap, int expectedPacketSize)142     public byte[] awaitEspPacketNoPlaintext(
143             int spi, byte[] plaintext, boolean useEncap, int expectedPacketSize) throws Exception {
144         final byte[] espPkt = awaitPacket(
145             (pkt) -> expectedPacketSize == pkt.length
146                     && isEspFailIfSpecifiedPlaintextFound(pkt, spi, useEncap, plaintext));
147 
148         return espPkt; // We've found the packet we're looking for.
149     }
150 
151     /** Await the expected ESP packet */
awaitEspPacket(int spi, boolean useEncap)152     public byte[] awaitEspPacket(int spi, boolean useEncap) throws Exception {
153         return awaitEspPacket(spi, useEncap, SEQ_NUM_MATCH_NOT_REQUIRED);
154     }
155 
156     /** Await the expected ESP packet with a matching sequence number */
awaitEspPacket(int spi, boolean useEncap, int seqNum)157     public byte[] awaitEspPacket(int spi, boolean useEncap, int seqNum) throws Exception {
158         return awaitPacket((pkt) -> isEsp(pkt, spi, seqNum, useEncap));
159     }
160 
isMatchingEspPacket(byte[] pkt, int espOffset, int spi, int seqNum)161     private static boolean isMatchingEspPacket(byte[] pkt, int espOffset, int spi, int seqNum) {
162         ByteBuffer buffer = ByteBuffer.wrap(pkt);
163         buffer.get(new byte[espOffset]); // Skip IP, UDP header
164         int actualSpi = buffer.getInt();
165         int actualSeqNum = buffer.getInt();
166 
167         if (actualSeqNum < 0) {
168             throw new UnsupportedOperationException(
169                     "actualSeqNum overflowed and needs to be converted to an unsigned integer");
170         }
171 
172         boolean isSeqNumMatched = (seqNum == SEQ_NUM_MATCH_NOT_REQUIRED || seqNum == actualSeqNum);
173 
174         return actualSpi == spi && isSeqNumMatched;
175     }
176 
177     /**
178      * Variant of isEsp that also fails the test if the provided plaintext is found
179      *
180      * @param pkt the packet bytes to verify
181      * @param spi the expected SPI to look for
182      * @param encap whether encap was enabled, and the packet has a UDP header
183      * @param plaintext the plaintext packet before outbound encryption, which MUST not appear in
184      *     the provided packet.
185      */
isEspFailIfSpecifiedPlaintextFound( byte[] pkt, int spi, boolean encap, byte[] plaintext)186     private static boolean isEspFailIfSpecifiedPlaintextFound(
187             byte[] pkt, int spi, boolean encap, byte[] plaintext) {
188         if (CollectionUtils.indexOfSubArray(pkt, plaintext) != -1) {
189             fail("Banned plaintext packet found");
190         }
191 
192         return isEsp(pkt, spi, SEQ_NUM_MATCH_NOT_REQUIRED, encap);
193     }
194 
isEsp(byte[] pkt, int spi, int seqNum, boolean encap)195     private static boolean isEsp(byte[] pkt, int spi, int seqNum, boolean encap) {
196         if (isIpv6(pkt)) {
197             if (encap) {
198                 return pkt[IP6_PROTO_OFFSET] == IPPROTO_UDP
199                         && isMatchingEspPacket(pkt, IP6_HDRLEN + UDP_HDRLEN, spi, seqNum);
200             } else {
201                 return pkt[IP6_PROTO_OFFSET] == IPPROTO_ESP
202                         && isMatchingEspPacket(pkt, IP6_HDRLEN, spi, seqNum);
203             }
204 
205         } else {
206             // Use default IPv4 header length (assuming no options)
207             if (encap) {
208                 return pkt[IP4_PROTO_OFFSET] == IPPROTO_UDP
209                         && isMatchingEspPacket(pkt, IP4_HDRLEN + UDP_HDRLEN, spi, seqNum);
210             } else {
211                 return pkt[IP4_PROTO_OFFSET] == IPPROTO_ESP
212                         && isMatchingEspPacket(pkt, IP4_HDRLEN, spi, seqNum);
213             }
214         }
215     }
216 
217 
isIpv6(byte[] pkt)218     public static boolean isIpv6(byte[] pkt) {
219         // First nibble shows IP version. 0x60 for IPv6
220         return (pkt[0] & (byte) 0xF0) == (byte) 0x60;
221     }
222 
getReflectedPacket(byte[] pkt)223     private static byte[] getReflectedPacket(byte[] pkt) {
224         byte[] reflected = Arrays.copyOf(pkt, pkt.length);
225 
226         if (isIpv6(pkt)) {
227             // Set reflected packet's dst to that of the original's src
228             System.arraycopy(
229                     pkt, // src
230                     IP6_ADDR_OFFSET + IP6_ADDR_LEN, // src offset
231                     reflected, // dst
232                     IP6_ADDR_OFFSET, // dst offset
233                     IP6_ADDR_LEN); // len
234             // Set reflected packet's src IP to that of the original's dst IP
235             System.arraycopy(
236                     pkt, // src
237                     IP6_ADDR_OFFSET, // src offset
238                     reflected, // dst
239                     IP6_ADDR_OFFSET + IP6_ADDR_LEN, // dst offset
240                     IP6_ADDR_LEN); // len
241         } else {
242             // Set reflected packet's dst to that of the original's src
243             System.arraycopy(
244                     pkt, // src
245                     IP4_ADDR_OFFSET + IP4_ADDR_LEN, // src offset
246                     reflected, // dst
247                     IP4_ADDR_OFFSET, // dst offset
248                     IP4_ADDR_LEN); // len
249             // Set reflected packet's src IP to that of the original's dst IP
250             System.arraycopy(
251                     pkt, // src
252                     IP4_ADDR_OFFSET, // src offset
253                     reflected, // dst
254                     IP4_ADDR_OFFSET + IP4_ADDR_LEN, // dst offset
255                     IP4_ADDR_LEN); // len
256         }
257         return reflected;
258     }
259 
260     /** Takes all captured packets, flips the src/dst, and re-injects them. */
reflectPackets()261     public void reflectPackets() throws IOException {
262         synchronized (mPackets) {
263             for (byte[] pkt : mPackets) {
264                 injectPacket(getReflectedPacket(pkt));
265             }
266         }
267     }
268 
injectPacket(byte[] pkt)269     public void injectPacket(byte[] pkt) throws IOException {
270         FileOutputStream out = new FileOutputStream(mTunFd.getFileDescriptor());
271         out.write(pkt);
272         out.flush();
273     }
274 
275     /** Resets the intercepted packets. */
reset()276     public void reset() throws IOException {
277         synchronized (mPackets) {
278             mPackets.clear();
279         }
280     }
281 }
282