1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package android.net.apf
17 
18 import android.net.apf.ApfCounterTracker.Counter
19 import android.net.apf.ApfCounterTracker.Counter.APF_PROGRAM_ID
20 import android.net.apf.ApfCounterTracker.Counter.APF_VERSION
21 import android.net.apf.ApfCounterTracker.Counter.TOTAL_PACKETS
22 import android.net.apf.BaseApfGenerator.APF_VERSION_6
23 import android.net.ip.IpClient
24 import com.android.net.module.util.HexDump
25 import kotlin.test.assertEquals
26 import org.mockito.ArgumentCaptor
27 import org.mockito.Mockito.clearInvocations
28 import org.mockito.Mockito.timeout
29 import org.mockito.Mockito.verify
30 
class ApfTestHelpers private constructor() {
    companion object {
        /** Timeout used when waiting for a packet filter to be installed. */
        const val TIMEOUT_MS: Long = 1000
        const val PASS: Int = 1
        const val DROP: Int = 0

        // The interpreter accepts packets without link-layer headers, so fake packets are padded
        // to at least this minimum packet size.
        const val MIN_PKT_SIZE: Int = 15

        /**
         * Returns a readable name for an APF return code. Codes other than [PASS] and [DROP] keep
         * their numeric value, so two different unexpected codes never compare as equal.
         */
        private fun label(code: Int): String = when (code) {
            PASS -> "PASS"
            DROP -> "DROP"
            else -> "UNKNOWN($code)"
        }

        /** Asserts that two APF return codes are equal, reporting [msg] on mismatch. */
        private fun assertReturnCodesEqual(msg: String, expected: Int, got: Int) {
            assertEquals(label(expected), label(got), msg)
        }

        /** Asserts that two APF return codes are equal. */
        private fun assertReturnCodesEqual(expected: Int, got: Int) {
            assertEquals(label(expected), label(got))
        }

        /**
         * Builds a single-line `apf_run` command that reproduces a simulation with the given
         * inputs. Call this before running the simulation, because the simulator can write into
         * [data] (see [assertDataMemoryContents], which inspects [data] after the run).
         */
        private fun apfRunCommand(
            apfVersion: Int,
            program: ByteArray,
            packet: ByteArray,
            data: ByteArray?,
            filterAge: Int
        ): String = buildString {
            append("apf_run --program ").append(HexDump.toHexString(program))
            append(" --packet ").append(HexDump.toHexString(packet))
            if (data != null) append(" --data ").append(HexDump.toHexString(data))
            append(" --age ").append(filterAge)
            if (apfVersion >= APF_VERSION_6) append(" --v6")
            append(" --trace | less")
        }

        /** Runs [program] on [packet] without a data region and checks the verdict. */
        private fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            program: ByteArray,
            packet: ByteArray,
            filterAge: Int
        ) {
            assertVerdict(apfVersion, expected, program, packet, null, filterAge)
        }

        /** Generates the program from [gen], runs it on [packet], and checks the verdict. */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        private fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            gen: ApfV4Generator,
            packet: ByteArray,
            filterAge: Int
        ) {
            assertVerdict(apfVersion, expected, gen.generate(), packet, null, filterAge)
        }

        /**
         * Runs [program] on [packet] with an optional [data] region and checks the verdict. On
         * failure, the message contains an `apf_run` command that reproduces the run.
         */
        private fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            program: ByteArray,
            packet: ByteArray,
            data: ByteArray?,
            filterAge: Int
        ) {
            // Build the message first: the simulator may modify `data` in place.
            val msg = "Unexpected APF verdict. To debug:\n" +
                    apfRunCommand(apfVersion, program, packet, data, filterAge)
            assertReturnCodesEqual(
                msg,
                expected,
                ApfJniUtils.apfSimulate(apfVersion, program, packet, data, filterAge)
            )
        }

        /**
         * Runs the APF program with a customized data region and checks the return code.
         */
        fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            program: ByteArray,
            packet: ByteArray,
            data: ByteArray?
        ) {
            assertVerdict(apfVersion, expected, program, packet, data, filterAge = 0)
        }

        /**
         * Runs the APF program and checks that the return code equals the expected value. On
         * mismatch, the caller-supplied [msg] is reported.
         */
        @JvmStatic
        fun assertVerdict(
            apfVersion: Int,
            msg: String,
            expected: Int,
            program: ByteArray?,
            packet: ByteArray?,
            filterAge: Int
        ) {
            assertReturnCodesEqual(
                msg,
                expected,
                ApfJniUtils.apfSimulate(apfVersion, program, packet, null, filterAge)
            )
        }

        /**
         * Runs the APF program with filter age 0 and checks that the return code equals the
         * expected value.
         */
        @JvmStatic
        fun assertVerdict(apfVersion: Int, expected: Int, program: ByteArray, packet: ByteArray) {
            assertVerdict(apfVersion, expected, program, packet, 0)
        }

        /**
         * Runs the APF program and checks that the return code is PASS.
         */
        @JvmStatic
        fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, PASS, program, packet, filterAge)
        }

        /**
         * Runs the APF program with filter age 0 and checks that the return code is PASS.
         */
        @JvmStatic
        fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray) {
            assertVerdict(apfVersion, PASS, program, packet)
        }

        /**
         * Runs the APF program and checks that the return code is DROP.
         */
        @JvmStatic
        fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, DROP, program, packet, filterAge)
        }

        /**
         * Runs the APF program with filter age 0 and checks that the return code is DROP.
         */
        @JvmStatic
        fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray) {
            assertVerdict(apfVersion, DROP, program, packet)
        }

        /**
         * Generates the APF program from [gen], runs it, and checks that the return code is PASS.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertPass(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, PASS, gen, packet, filterAge)
        }

        /**
         * Generates the APF program from [gen], runs it, and checks that the return code is DROP.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertDrop(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, DROP, gen, packet, filterAge)
        }

        /**
         * Generates the APF program from [gen], runs it on a zero-filled packet of
         * [MIN_PKT_SIZE] bytes, and checks that the return code is PASS.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertPass(apfVersion: Int, gen: ApfV4Generator) {
            assertVerdict(apfVersion, PASS, gen, ByteArray(MIN_PKT_SIZE), 0)
        }

        /**
         * Generates the APF program from [gen], runs it on a zero-filled packet of
         * [MIN_PKT_SIZE] bytes, and checks that the return code is DROP.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertDrop(apfVersion: Int, gen: ApfV4Generator) {
            assertVerdict(apfVersion, DROP, gen, ByteArray(MIN_PKT_SIZE), 0)
        }

        /**
         * Checks that the generated APF program equals the expected bytes.
         *
         * @throws AssertionError with both programs hex-encoded if they differ.
         */
        @Throws(AssertionError::class)
        @JvmStatic
        fun assertProgramEquals(expected: ByteArray, program: ByteArray?) {
            // assertArrayEquals() would only print one byte, making debugging difficult.
            if (!expected.contentEquals(program)) {
                throw AssertionError(
                    "\nexpected: " + HexDump.toHexString(expected) +
                    "\nactual:   " + HexDump.toHexString(program)
                )
            }
        }

        /**
         * Runs the APF program with filter age 0 and checks the return code. It then checks
         * that the data region, as modified by the run, equals [expectedData].
         *
         * When [ignoreInterpreterVersion] is true, the 4-byte APF_VERSION and APF_PROGRAM_ID
         * slots in [data] are zeroed after the run and before the comparison.
         *
         * @throws Exception with the program, actual data, and expected data hex-encoded if the
         *     data regions differ.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class, Exception::class)
        @JvmStatic
        fun assertDataMemoryContents(
            apfVersion: Int,
            expected: Int,
            program: ByteArray?,
            packet: ByteArray?,
            data: ByteArray,
            expectedData: ByteArray,
            ignoreInterpreterVersion: Boolean
        ) {
            assertReturnCodesEqual(
                expected,
                ApfJniUtils.apfSimulate(apfVersion, program, packet, data, 0)
            )

            if (ignoreInterpreterVersion) {
                // NOTE(review): indices are relative to Counter.totalSize(), which presumably
                // assumes data.size == Counter.totalSize(); confirm for larger data regions.
                val apfVersionIdx = Counter.totalSize() + APF_VERSION.offset()
                val apfProgramIdIdx = Counter.totalSize() + APF_PROGRAM_ID.offset()
                data.fill(0, apfVersionIdx, apfVersionIdx + 4)
                data.fill(0, apfProgramIdIdx, apfProgramIdIdx + 4)
            }
            // assertArrayEquals() would only print one byte, making debugging difficult.
            if (!expectedData.contentEquals(data)) {
                throw Exception(
                    "\nprogram:     " + HexDump.toHexString(program) +
                    "\ndata memory: " + HexDump.toHexString(data) +
                    "\nexpected:    " + HexDump.toHexString(expectedData)
                )
            }
        }

        /**
         * Runs [program] on [pkt] against [dataRegion] and checks the verdict. It then checks
         * that the counters decoded from [dataRegion] match [cntMap] after [targetCnt] (and,
         * if [incTotal], TOTAL_PACKETS) has been incremented.
         *
         * [cntMap] and [dataRegion] are updated in place, so callers can pass the same instances
         * across successive runs to accumulate counter expectations.
         *
         * @param result expected verdict; defaults to PASS for counters named `PASSED*` and to
         *     DROP otherwise.
         */
        fun verifyProgramRun(
            version: Int,
            program: ByteArray,
            pkt: ByteArray,
            targetCnt: Counter,
            cntMap: MutableMap<Counter, Long> = mutableMapOf(),
            dataRegion: ByteArray = ByteArray(Counter.totalSize()),
            incTotal: Boolean = true,
            result: Int = if (targetCnt.name.startsWith("PASSED")) PASS else DROP
        ) {
            // Capture the reproduction command before the run: the run updates the counters in
            // dataRegion, and replaying with the post-run data would not reproduce the failure.
            val debugCommand = apfRunCommand(version, program, pkt, dataRegion, filterAge = 0)
            assertVerdict(version, result, program, pkt, dataRegion)
            cntMap[targetCnt] = (cntMap[targetCnt] ?: 0L) + 1
            if (incTotal) {
                cntMap[TOTAL_PACKETS] = (cntMap[TOTAL_PACKETS] ?: 0L) + 1
            }
            val errMsg = "Counter is not increased properly. To debug: \n $debugCommand\n"
            assertEquals(cntMap, decodeCountersIntoMap(dataRegion), errMsg)
        }

        /**
         * Decodes the counters stored in [counterBytes] into a map. Only non-zero counters are
         * included. APF_PROGRAM_ID and APF_VERSION are omitted.
         */
        fun decodeCountersIntoMap(counterBytes: ByteArray): Map<Counter, Long> {
            val skippedCounters = setOf(APF_PROGRAM_ID, APF_VERSION)
            // Start at index 2 to skip the endianness mark.
            return Counter.values()
                .drop(2)
                .filterNot { it in skippedCounters }
                .associateWith { ApfCounterTracker.getCounterValue(counterBytes, it) }
                .filterValues { it != 0L }
        }

        /**
         * Waits up to [TIMEOUT_MS] until [ipClientCb] has received exactly [installCnt]
         * `installPacketFilter` calls. It then clears the recorded invocations and returns the
         * most recently installed program.
         */
        @JvmStatic
        fun consumeInstalledProgram(
            ipClientCb: IpClient.IpClientCallbacksWrapper,
            installCnt: Int
        ): ByteArray {
            val programCaptor = ArgumentCaptor.forClass(ByteArray::class.java)

            verify(ipClientCb, timeout(TIMEOUT_MS).times(installCnt)).installPacketFilter(
                programCaptor.capture()
            )

            clearInvocations<Any>(ipClientCb)
            return programCaptor.value
        }

        /**
         * Returns all packets transmitted by the simulator and checks that there are exactly
         * [expectCnt] of them. It then resets the simulator's transmitted-packet memory.
         */
        fun consumeTransmittedPackets(
            expectCnt: Int
        ): List<ByteArray> {
            val transmittedPackets = ApfJniUtils.getAllTransmittedPackets()
            assertEquals(expectCnt, transmittedPackets.size)
            ApfJniUtils.resetTransmittedPacketMemory()
            return transmittedPackets
        }
    }
}
344