1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.net.ip
17 
18 import android.net.MacAddress
19 import android.net.ip.ConnectivityPacketTracker.Dependencies
20 import android.os.Handler
21 import android.os.HandlerThread
22 import android.system.ErrnoException
23 import android.system.Os
24 import android.system.OsConstants.AF_UNIX
25 import android.system.OsConstants.SOCK_NONBLOCK
26 import android.system.OsConstants.SOCK_STREAM
27 import android.util.LocalLog
28 import androidx.test.filters.SmallTest
29 import com.android.net.module.util.HexDump
30 import com.android.net.module.util.InterfaceParams
31 import com.android.testutils.DevSdkIgnoreRunner
32 import com.android.testutils.waitForIdle
33 import java.io.FileDescriptor
34 import java.io.InterruptedIOException
35 import java.util.concurrent.CompletableFuture
36 import java.util.concurrent.TimeUnit
37 import kotlin.test.assertEquals
38 import libcore.io.IoUtils
39 import org.junit.After
40 import org.junit.Before
41 import org.junit.Test
42 import org.mockito.ArgumentMatchers.anyInt
43 import org.mockito.Mock
44 import org.mockito.Mockito
45 import org.mockito.Mockito.doReturn
46 import org.mockito.MockitoAnnotations
47 
/**
 * Tests for [ConnectivityPacketTracker].
 *
 * The tracker is created and driven on a dedicated [HandlerThread]. The mocked [Dependencies]
 * hand the tracker one end of a UNIX socket pair as its packet reader socket, so a test can
 * inject a packet by writing raw bytes to the other end (see [pretendPacketReceive]).
 */
@SmallTest
@DevSdkIgnoreRunner.MonitorThreadLeak
class ConnectivityPacketTrackerTest {
    companion object {
        private const val TIMEOUT_MS: Long = 10000
        private const val SLEEP_TIMEOUT_MS: Long = 500
        private const val TEST_MAX_CAPTURE_PKT_SIZE: Int = 100
        private const val TAG = "ConnectivityPacketTrackerTest"

        // ARP request packet generated with scapy:
        // eth = Ether(src="00:01:02:03:04:05", dst="01:02:03:04:05:06")
        // arp = ARP()
        // pkt = eth/arp
        // Upper-cased because the tests use this string as the lookup key for
        // getMatchedPacketCount (presumably the tracker stores upper-case hex; confirm).
        private val ARP_PKT_HEX =
            "010203040506000102030405080600010800060400015c857e3c74e1c0a8012200000000000000000000"
                .uppercase()
    }

    // getByName comes from Java and is seen as a platform type; fail with a clear message
    // instead of a bare NullPointerException if the loopback interface cannot be resolved.
    private val loInterfaceParams =
        checkNotNull(InterfaceParams.getByName("lo")) { "Cannot get InterfaceParams for lo" }

    // Loopback parameters with a fixed, test-chosen MAC address.
    private val ifParams =
        InterfaceParams(
            "lo",
            loInterfaceParams.index,
            MacAddress.fromBytes(byteArrayOf(2, 3, 4, 5, 6, 7)),
            loInterfaceParams.defaultMtu
        )

    // Test-side end of the socket pair; the other end is returned by the mocked Dependencies.
    private val writeSocket = FileDescriptor()
    private val handlerThread by lazy {
        HandlerThread("$TAG-handler-thread").apply { start() }
    }
    private val handler by lazy { Handler(handlerThread.looper) }

    @Mock private lateinit var mDependencies: Dependencies
    @Mock private lateinit var localLog: LocalLog

    /**
     * Initializes the mocks and wires a fresh socket pair into [mDependencies] so that the
     * tracker reads whatever the test writes to [writeSocket].
     */
    @Before
    fun setUp() {
        MockitoAnnotations.initMocks(this)
        val readSocket = FileDescriptor()
        Os.socketpair(AF_UNIX, SOCK_STREAM or SOCK_NONBLOCK, 0, writeSocket, readSocket)
        doReturn(readSocket).`when`(mDependencies).createPacketReaderSocket(anyInt())
        doReturn(TEST_MAX_CAPTURE_PKT_SIZE).`when`(mDependencies).maxCapturePktSize
    }

    /**
     * Closes the write end of the socket pair, drains the handler and stops the handler thread.
     *
     * The thread shutdown runs in a `finally` block so that a failure while draining the handler
     * (for example a timeout in `waitForIdle`) cannot leave the handler thread running.
     */
    @After
    fun tearDown() {
        try {
            // NOTE(review): closing the write end presumably lets the tracker's reader observe
            // the closed socket and stop itself before the thread quits; confirm.
            IoUtils.closeQuietly(writeSocket)
            handler.waitForIdle(TIMEOUT_MS)
            Mockito.framework().clearInlineMocks()
        } finally {
            handlerThread.quitSafely()
            handlerThread.join()
        }
    }

    /**
     * Verifies that packets are only counted while capture is enabled, that identical packets
     * are counted under a single entry, and that disabling capture resets the counters.
     */
    @Test
    fun testCapturePacket() {
        val packetTracker = getConnectivityPacketTracker()
        val arpPktByteArray = HexDump.hexStringToByteArray(ARP_PKT_HEX)
        val pktCnt = 5

        // Nothing is recorded before capture is enabled.
        assertEquals(0, getCapturePacketTypeCount(packetTracker))
        assertEquals(0, getMatchedPacketCount(packetTracker, ARP_PKT_HEX))

        // start capture packet
        setCapture(packetTracker, true)

        repeat(pktCnt) {
            pretendPacketReceive(arpPktByteArray)
            // Presumably gives the handler thread time to read the packet before the next write.
            Thread.sleep(SLEEP_TIMEOUT_MS)
        }

        // The same packet sent pktCnt times: one distinct packet type, pktCnt matches.
        assertEquals(1, getCapturePacketTypeCount(packetTracker))
        assertEquals(pktCnt, getMatchedPacketCount(packetTracker, ARP_PKT_HEX))

        // stop capture packet
        setCapture(packetTracker, false)
        assertEquals(0, getCapturePacketTypeCount(packetTracker))
        assertEquals(0, getMatchedPacketCount(packetTracker, ARP_PKT_HEX))
    }

    /**
     * Verifies that the number of distinct captured packets is capped by
     * `Dependencies.maxCapturePktSize` and that the oldest packets are the ones dropped.
     */
    @Test
    fun testMaxCapturePacketSize() {
        val maxCaptureSize = 3
        doReturn(maxCaptureSize).`when`(mDependencies).maxCapturePktSize
        val packetTracker = getConnectivityPacketTracker(mDependencies)
        val arpPktByteArray = HexDump.hexStringToByteArray(ARP_PKT_HEX)

        // start capture packet
        setCapture(packetTracker, true)

        val pktCnt = 5
        val sentPackets = mutableListOf<String>()
        for (i in 0..<pktCnt) {
            // modify the original packet's last byte so that every packet is distinct
            val modPkt = arpPktByteArray.copyOf()
            modPkt[modPkt.lastIndex] = i.toByte()
            pretendPacketReceive(modPkt)
            sentPackets.add(HexDump.toHexString(modPkt))
            Thread.sleep(SLEEP_TIMEOUT_MS)
        }

        // The old packets are evicted due to LruCache size
        val evictedCnt = pktCnt - maxCaptureSize
        sentPackets.take(evictedCnt).forEach { evicted ->
            assertEquals(0, getMatchedPacketCount(packetTracker, evicted))
        }

        // The most recent packets are still tracked, once each.
        sentPackets.drop(evictedCnt).forEach { retained ->
            assertEquals(1, getMatchedPacketCount(packetTracker, retained))
        }

        assertEquals(maxCaptureSize, getCapturePacketTypeCount(packetTracker))
    }

    /** Injects [packet] by writing its bytes to the test end of the socket pair. */
    @Throws(InterruptedIOException::class, ErrnoException::class)
    private fun pretendPacketReceive(packet: ByteArray) {
        Os.write(writeSocket, packet, 0, packet.size)
    }

    /**
     * Runs [task] on the handler thread and blocks until it has finished.
     *
     * @return the value produced by [task].
     * @throws java.util.concurrent.ExecutionException if [task] threw an exception.
     * @throws java.util.concurrent.TimeoutException if [task] did not finish in [TIMEOUT_MS].
     */
    private fun <T> runOnHandler(task: () -> T): T {
        val result = CompletableFuture<T>()
        handler.post {
            try {
                result.complete(task())
            } catch (e: Exception) {
                // Hand the failure to the waiting test thread instead of the handler thread.
                result.completeExceptionally(e)
            }
        }

        return result.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
    }

    /**
     * Creates a [ConnectivityPacketTracker] on the handler thread and starts it.
     *
     * @param dependencies the dependencies injected into the tracker.
     * @return the started tracker.
     */
    private fun getConnectivityPacketTracker(
        dependencies: Dependencies = mDependencies
    ): ConnectivityPacketTracker = runOnHandler {
        val tracker = ConnectivityPacketTracker(handler, ifParams, localLog, dependencies)
        tracker.start(TAG)
        tracker
    }

    /** Enables or disables capture on [packetTracker] from the handler thread. */
    private fun setCapture(
        packetTracker: ConnectivityPacketTracker,
        isCapturing: Boolean
    ) {
        runOnHandler<Unit> { packetTracker.setCapture(isCapturing) }
    }

    /**
     * Queries, on the handler thread, the tracker's matched-packet count for [packet]
     * (a hex string).
     */
    private fun getMatchedPacketCount(
        packetTracker: ConnectivityPacketTracker,
        packet: String
    ): Int = runOnHandler { packetTracker.getMatchedPacketCount(packet) }

    /** Queries, on the handler thread, the tracker's `capturePacketTypeCount`. */
    private fun getCapturePacketTypeCount(
        packetTracker: ConnectivityPacketTracker
    ): Int = runOnHandler { packetTracker.capturePacketTypeCount }
}