/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package android.net.apf

import android.net.apf.ApfCounterTracker.Counter
import android.net.apf.ApfCounterTracker.Counter.APF_PROGRAM_ID
import android.net.apf.ApfCounterTracker.Counter.APF_VERSION
import android.net.apf.ApfCounterTracker.Counter.TOTAL_PACKETS
import android.net.apf.BaseApfGenerator.APF_VERSION_6
import android.net.ip.IpClient
import com.android.net.module.util.HexDump
import kotlin.test.assertEquals
import org.mockito.ArgumentCaptor
import org.mockito.Mockito.clearInvocations
import org.mockito.Mockito.timeout
import org.mockito.Mockito.verify

/**
 * Static-style assertion helpers for APF tests. All helpers live in the companion object; the
 * class itself cannot be instantiated.
 */
class ApfTestHelpers private constructor() {
    companion object {
        const val TIMEOUT_MS: Long = 1000
        const val PASS: Int = 1
        const val DROP: Int = 0

        // Interpreter will just accept packets without link layer headers, so pad fake packet to at
        // least the minimum packet size.
        const val MIN_PKT_SIZE: Int = 15

        /** Maps a verdict code to a readable name so assertion failures show names, not ints. */
        private fun label(code: Int): String = when (code) {
            PASS -> "PASS"
            DROP -> "DROP"
            else -> "UNKNOWN"
        }

        /**
         * Hex-encodes [bytes] for a failure message, or returns "null" when there is nothing to
         * encode. This keeps a null input from failing inside the formatter and hiding the real
         * assertion message.
         */
        private fun hexOrNull(bytes: ByteArray?): String =
            bytes?.let { HexDump.toHexString(it) } ?: "null"

        /** Compares two verdict codes by label, reporting [msg] on mismatch. */
        private fun assertReturnCodesEqual(msg: String, expected: Int, got: Int) {
            assertEquals(label(expected), label(got), msg)
        }

        /** Compares two verdict codes by label. */
        private fun assertReturnCodesEqual(expected: Int, got: Int) {
            assertEquals(label(expected), label(got))
        }

        /**
         * Runs [program] on [packet] without a data region and checks the verdict.
         *
         * Delegates to the data-aware overload so that there is a single source for the
         * "To debug" command line printed on failure.
         */
        private fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            program: ByteArray,
            packet: ByteArray,
            filterAge: Int
        ) {
            assertVerdict(apfVersion, expected, program, packet, null, filterAge)
        }

        /** Generates the program from [gen], runs it on [packet] and checks the verdict. */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        private fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            gen: ApfV4Generator,
            packet: ByteArray,
            filterAge: Int
        ) {
            assertVerdict(apfVersion, expected, gen.generate(), packet, null, filterAge)
        }

        /**
         * Runs [program] on [packet] with an optional [data] region and checks the verdict.
         * On mismatch the failure message contains an `apf_run` command line built from the
         * same inputs.
         */
        private fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            program: ByteArray,
            packet: ByteArray,
            data: ByteArray?,
            filterAge: Int
        ) {
            val msg = """Unexpected APF verdict. To debug:
                apf_run
                    --program ${HexDump.toHexString(program)}
                    --packet ${HexDump.toHexString(packet)}
                    ${if (data != null) "--data ${HexDump.toHexString(data)}" else ""}
                    --age $filterAge
                    ${if (apfVersion > 4) "--v6" else ""}
                    --trace | less
            """
            assertReturnCodesEqual(
                msg,
                expected,
                ApfJniUtils.apfSimulate(apfVersion, program, packet, data, filterAge)
            )
        }

        /**
         * Runs the APF program with customized data region and checks the return code.
         * The filter age passed to the simulator is 0.
         */
        fun assertVerdict(
            apfVersion: Int,
            expected: Int,
            program: ByteArray,
            packet: ByteArray,
            data: ByteArray?
        ) {
            assertVerdict(apfVersion, expected, program, packet, data, filterAge = 0)
        }

        /**
         * Runs the APF program and checks the return code equals the expected value. If not, the
         * customized message is printed.
         */
        @JvmStatic
        fun assertVerdict(
            apfVersion: Int,
            msg: String,
            expected: Int,
            program: ByteArray?,
            packet: ByteArray?,
            filterAge: Int
        ) {
            val verdict = ApfJniUtils.apfSimulate(apfVersion, program, packet, null, filterAge)
            assertReturnCodesEqual(msg, expected, verdict)
        }

        /**
         * Runs the APF program with a filter age of 0 and checks the return code equals the
         * expected value.
         */
        @JvmStatic
        fun assertVerdict(apfVersion: Int, expected: Int, program: ByteArray, packet: ByteArray) {
            assertVerdict(apfVersion, expected, program, packet, 0)
        }

        /**
         * Runs the APF program and checks the return code is PASS.
         */
        @JvmStatic
        fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, PASS, program, packet, filterAge)
        }

        /**
         * Runs the APF program with a filter age of 0 and checks the return code is PASS.
         */
        @JvmStatic
        fun assertPass(apfVersion: Int, program: ByteArray, packet: ByteArray) {
            assertVerdict(apfVersion, PASS, program, packet)
        }

        /**
         * Runs the APF program and checks the return code is DROP.
         */
        @JvmStatic
        fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, DROP, program, packet, filterAge)
        }

        /**
         * Runs the APF program with a filter age of 0 and checks the return code is DROP.
         */
        @JvmStatic
        fun assertDrop(apfVersion: Int, program: ByteArray, packet: ByteArray) {
            assertVerdict(apfVersion, DROP, program, packet)
        }

        /**
         * Generates the program from [gen], runs it and checks the return code is PASS.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertPass(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, PASS, gen, packet, filterAge)
        }

        /**
         * Generates the program from [gen], runs it and checks the return code is DROP.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertDrop(apfVersion: Int, gen: ApfV4Generator, packet: ByteArray, filterAge: Int) {
            assertVerdict(apfVersion, DROP, gen, packet, filterAge)
        }

        /**
         * Generates the program from [gen], runs it on a zero-filled packet of [MIN_PKT_SIZE]
         * bytes with a filter age of 0, and checks the return code is PASS.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertPass(apfVersion: Int, gen: ApfV4Generator) {
            assertVerdict(apfVersion, PASS, gen, ByteArray(MIN_PKT_SIZE), 0)
        }

        /**
         * Generates the program from [gen], runs it on a zero-filled packet of [MIN_PKT_SIZE]
         * bytes with a filter age of 0, and checks the return code is DROP.
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class)
        @JvmStatic
        fun assertDrop(apfVersion: Int, gen: ApfV4Generator) {
            assertVerdict(apfVersion, DROP, gen, ByteArray(MIN_PKT_SIZE), 0)
        }

        /**
         * Checks the generated APF program equals the expected value.
         *
         * @throws AssertionError with both byte arrays hex-encoded when they differ, including
         *         when [program] is null.
         */
        @Throws(AssertionError::class)
        @JvmStatic
        fun assertProgramEquals(expected: ByteArray, program: ByteArray?) {
            if (expected.contentEquals(program)) return
            // assertArrayEquals() would only print one byte, making debugging difficult.
            throw AssertionError(
                "\nexpected: " + HexDump.toHexString(expected) +
                "\nactual:   " + hexOrNull(program)
            )
        }

        /**
         * Runs the APF program with a filter age of 0 and checks that the return code and the
         * resulting data region equal the expected values.
         *
         * When [ignoreInterpreterVersion] is set, four bytes at each of the APF_VERSION and
         * APF_PROGRAM_ID counter positions in [data] are zeroed before the comparison. Note
         * that [data] is modified in place, both by the simulator run and by that zeroing.
         *
         * @throws Exception with program, actual and expected data hex-encoded when the data
         *         region does not match [expectedData].
         */
        @Throws(BaseApfGenerator.IllegalInstructionException::class, Exception::class)
        @JvmStatic
        fun assertDataMemoryContents(
            apfVersion: Int,
            expected: Int,
            program: ByteArray?,
            packet: ByteArray?,
            data: ByteArray,
            expectedData: ByteArray,
            ignoreInterpreterVersion: Boolean
        ) {
            assertReturnCodesEqual(
                expected,
                ApfJniUtils.apfSimulate(apfVersion, program, packet, data, 0)
            )

            if (ignoreInterpreterVersion) {
                val apfVersionIdx = Counter.totalSize() + APF_VERSION.offset()
                val apfProgramIdIdx = Counter.totalSize() + APF_PROGRAM_ID.offset()
                data.fill(0, apfVersionIdx, apfVersionIdx + 4)
                data.fill(0, apfProgramIdIdx, apfProgramIdIdx + 4)
            }

            if (expectedData.contentEquals(data)) return
            // assertArrayEquals() would only print one byte, making debugging difficult.
            throw Exception(
                "\nprogram:     " + hexOrNull(program) +
                "\ndata memory: " + HexDump.toHexString(data) +
                "\nexpected:    " + HexDump.toHexString(expectedData)
            )
        }

        /**
         * Runs [program] on [pkt] with [dataRegion], checks the verdict equals [result], then
         * checks the counters decoded from [dataRegion] match [cntMap].
         *
         * Before the comparison [cntMap] is updated in place: [targetCnt] is incremented by one
         * and, when [incTotal] is true, so is TOTAL_PACKETS. By default [result] is PASS when
         * the name of [targetCnt] starts with "PASSED" and DROP otherwise.
         */
        fun verifyProgramRun(
            version: Int,
            program: ByteArray,
            pkt: ByteArray,
            targetCnt: Counter,
            cntMap: MutableMap<Counter, Long> = mutableMapOf(),
            dataRegion: ByteArray = ByteArray(Counter.totalSize()) { 0 },
            incTotal: Boolean = true,
            result: Int = if (targetCnt.name.startsWith("PASSED")) PASS else DROP
        ) {
            assertVerdict(version, result, program, pkt, dataRegion)
            cntMap[targetCnt] = cntMap.getOrDefault(targetCnt, 0) + 1
            if (incTotal) {
                cntMap[TOTAL_PACKETS] = cntMap.getOrDefault(TOTAL_PACKETS, 0) + 1
            }
            val errMsg = "Counter is not increased properly. To debug: \n" +
                    " apf_run --program ${HexDump.toHexString(program)} " +
                    "--packet ${HexDump.toHexString(pkt)} " +
                    "--data ${HexDump.toHexString(dataRegion)} --age 0 " +
                    "${if (version == APF_VERSION_6) "--v6" else "" } --trace | less \n"
            assertEquals(cntMap, decodeCountersIntoMap(dataRegion), errMsg)
        }

        /**
         * Decodes [counterBytes] into a map of counter to value.
         *
         * The first two enum constants are skipped (endianness mark), as are APF_PROGRAM_ID and
         * APF_VERSION. Counters whose value is zero are omitted from the result.
         */
        fun decodeCountersIntoMap(counterBytes: ByteArray): Map<Counter, Long> {
            val skippedCounters = setOf(APF_PROGRAM_ID, APF_VERSION)
            val decoded = HashMap<Counter, Long>()
            // starting from index 2 to skip the endianness mark
            Counter::class.java.enumConstants.orEmpty()
                .drop(2)
                .filterNot { it in skippedCounters }
                .forEach { counter ->
                    val value = ApfCounterTracker.getCounterValue(counterBytes, counter)
                    if (value != 0L) {
                        decoded[counter] = value
                    }
                }
            return decoded
        }

        /**
         * Verifies, waiting up to [TIMEOUT_MS], that `installPacketFilter` was invoked
         * [installCnt] times on [ipClientCb], then clears the recorded invocations on the mock.
         *
         * @return the value held by the argument captor after verification.
         */
        @JvmStatic
        fun consumeInstalledProgram(
            ipClientCb: IpClient.IpClientCallbacksWrapper,
            installCnt: Int
        ): ByteArray {
            val programCaptor = ArgumentCaptor.forClass(ByteArray::class.java)

            verify(ipClientCb, timeout(TIMEOUT_MS).times(installCnt)).installPacketFilter(
                programCaptor.capture()
            )

            clearInvocations<Any>(ipClientCb)
            return programCaptor.value
        }

        /**
         * Fetches all packets recorded by [ApfJniUtils], asserts there are exactly [expectCnt]
         * of them, resets the recorded-packet memory and returns the fetched packets.
         */
        fun consumeTransmittedPackets(
            expectCnt: Int
        ): List<ByteArray> {
            val transmittedPackets = ApfJniUtils.getAllTransmittedPackets()
            assertEquals(expectCnt, transmittedPackets.size)
            ApfJniUtils.resetTransmittedPacketMemory()
            return transmittedPackets
        }
    }
}