1 /*
<lambda>null2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 package com.android.testutils
17
18 import android.net.DnsResolver
19 import android.net.Network
20 import android.net.nsd.NsdManager
21 import android.net.nsd.NsdServiceInfo
22 import android.os.Process
23 import com.android.net.module.util.ArrayTrackRecord
24 import com.android.net.module.util.DnsPacket
25 import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
26 import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
27 import com.android.net.module.util.NetworkStackConstants.IPV6_DST_ADDR_OFFSET
28 import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
29 import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
30 import com.android.net.module.util.TrackRecord
31 import java.net.Inet6Address
32 import java.net.InetAddress
33 import kotlin.test.assertEquals
34 import kotlin.test.assertNotNull
35 import kotlin.test.assertNull
36 import kotlin.test.assertTrue
37 import kotlin.test.fail
38
/** Default time, in milliseconds, that the `pollFor*` helpers wait for a matching mDNS packet. */
private const val MDNS_REGISTRATION_TIMEOUT_MS: Long = 10_000

/** mDNS UDP port; used as both the source and destination port when filtering packets. */
private const val MDNS_PORT: Short = 5353

/** Default time, in milliseconds, to wait for an expected NSD callback. */
const val MDNS_CALLBACK_TIMEOUT: Long = 2_000

/** Default time, in milliseconds, to wait when asserting that no NSD callback arrives. */
const val MDNS_NO_CALLBACK_TIMEOUT_MS: Long = 200

/** Marker interface for the callback events stored by [NsdRecord] subclasses. */
interface NsdEvent
/**
 * Base class for NSD listener test doubles: every callback is stored as an event of type [T]
 * in an [ArrayTrackRecord], so tests can assert on the events received.
 *
 * @param expectedThreadId if non-null, [add] asserts that each event is recorded on the thread
 *        whose id (as returned by [Process.myTid]) equals this value.
 */
open class NsdRecord<T : NsdEvent> private constructor(
    private val history: ArrayTrackRecord<T>,
    private val expectedThreadId: Int? = null
) : TrackRecord<T> by history {
    constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)

    // Read head over [history]. It is public rather than private because the public inline
    // helpers in this class read it, and public inline functions may only access public API.
    val nextEvents = history.newReadHead()

    /**
     * Records [e] in the history, first asserting that the call happens on the expected thread
     * when an expected thread id was supplied.
     */
    override fun add(e: T): Boolean {
        if (expectedThreadId != null) {
            assertEquals(
                expectedThreadId, Process.myTid(),
                "Callback is running on the wrong thread"
            )
        }
        return history.add(e)
    }

    /**
     * Waits up to [timeoutMs] for [nextEvents] to yield an event of type [V] that matches
     * [predicate], and fails the test if none is seen.
     *
     * @return the matching event.
     */
    inline fun <reified V : NsdEvent> expectCallbackEventually(
        timeoutMs: Long = MDNS_CALLBACK_TIMEOUT,
        crossinline predicate: (V) -> Boolean = { true }
    ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
        ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")

    /**
     * Polls the next event from [nextEvents] within [timeoutMs]. Fails the test if there is none,
     * or if it is not of type [V].
     *
     * @return the next event, typed as [V].
     */
    inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = MDNS_CALLBACK_TIMEOUT): V {
        val nextEvent = nextEvents.poll(timeoutMs)
        assertNotNull(
            nextEvent, "No callback received after $timeoutMs ms, expected " +
                "${V::class.java.simpleName}"
        )
        // kotlin.test's assertTrue declares a contract, so nextEvent is smart-cast to V below.
        assertTrue(
            nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
                nextEvent.javaClass.simpleName
        )
        return nextEvent
    }

    /**
     * Asserts that polling [nextEvents] for [timeoutMs] yields no event.
     *
     * This is deliberately not `inline`: it has neither lambda nor reified type parameters,
     * so inlining gives no benefit.
     */
    fun assertNoCallback(timeoutMs: Long = MDNS_NO_CALLBACK_TIMEOUT_MS) {
        val cb = nextEvents.poll(timeoutMs)
        assertNull(cb, "Expected no callback but got $cb")
    }
}
87
/**
 * [NsdManager.DiscoveryListener] that stores every callback it receives as a [DiscoveryEvent].
 *
 * @param expectedThreadId if non-null, each callback is asserted to run on this thread id.
 */
class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
    NsdManager.DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {

    /** One event type per [NsdManager.DiscoveryListener] callback. */
    sealed class DiscoveryEvent : NsdEvent {
        data class StartDiscoveryFailed(val serviceType: String, val errorCode: Int) :
            DiscoveryEvent()

        data class StopDiscoveryFailed(val serviceType: String, val errorCode: Int) :
            DiscoveryEvent()

        data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()
        data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()
        data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
        data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
    }

    /** Single funnel through which every listener callback is stored. */
    private fun recordDiscoveryEvent(event: DiscoveryEvent) {
        add(event)
    }

    override fun onStartDiscoveryFailed(serviceType: String, err: Int) =
        recordDiscoveryEvent(DiscoveryEvent.StartDiscoveryFailed(serviceType, err))

    override fun onStopDiscoveryFailed(serviceType: String, err: Int) =
        recordDiscoveryEvent(DiscoveryEvent.StopDiscoveryFailed(serviceType, err))

    override fun onDiscoveryStarted(serviceType: String) =
        recordDiscoveryEvent(DiscoveryEvent.DiscoveryStarted(serviceType))

    override fun onDiscoveryStopped(serviceType: String) =
        recordDiscoveryEvent(DiscoveryEvent.DiscoveryStopped(serviceType))

    override fun onServiceFound(si: NsdServiceInfo) =
        recordDiscoveryEvent(DiscoveryEvent.ServiceFound(si))

    override fun onServiceLost(si: NsdServiceInfo) =
        recordDiscoveryEvent(DiscoveryEvent.ServiceLost(si))

    /**
     * Waits, with the default [MDNS_CALLBACK_TIMEOUT], for a [DiscoveryEvent.ServiceFound] whose
     * service is named [serviceName]. When [expectedNetwork] is given, the service must also be
     * on that network. The method then asserts that the found service type is [serviceType]
     * followed by a dot.
     *
     * @return the discovered service's [NsdServiceInfo].
     */
    fun waitForServiceDiscovered(
        serviceName: String,
        serviceType: String,
        expectedNetwork: Network? = null
    ): NsdServiceInfo {
        val foundEvent = expectCallbackEventually<DiscoveryEvent.ServiceFound> { event ->
            val candidate = event.serviceInfo
            val nameMatches = candidate.serviceName == serviceName
            val networkMatches = expectedNetwork == null || expectedNetwork == candidate.network
            nameMatches && networkMatches
        }
        val discovered = foundEvent.serviceInfo
        // Discovered service types have a dot at the end
        assertEquals("$serviceType.", discovered.serviceType)
        return discovered
    }
}
142
/**
 * [NsdManager.RegistrationListener] that stores every callback it receives as a
 * [RegistrationEvent].
 *
 * @param expectedThreadId if non-null, each callback is asserted to run on this thread id.
 */
class NsdRegistrationRecord(expectedThreadId: Int? = null) : NsdManager.RegistrationListener,
    NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {

    /** One event type per [NsdManager.RegistrationListener] callback; each carries its service. */
    sealed class RegistrationEvent : NsdEvent {
        abstract val serviceInfo: NsdServiceInfo

        data class RegistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class UnregistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class ServiceRegistered(override val serviceInfo: NsdServiceInfo) :
            RegistrationEvent()

        data class ServiceUnregistered(override val serviceInfo: NsdServiceInfo) :
            RegistrationEvent()
    }

    /** Single funnel through which every listener callback is stored. */
    private fun recordRegistrationEvent(event: RegistrationEvent) {
        add(event)
    }

    override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) =
        recordRegistrationEvent(RegistrationEvent.RegistrationFailed(si, err))

    override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) =
        recordRegistrationEvent(RegistrationEvent.UnregistrationFailed(si, err))

    override fun onServiceRegistered(si: NsdServiceInfo) =
        recordRegistrationEvent(RegistrationEvent.ServiceRegistered(si))

    override fun onServiceUnregistered(si: NsdServiceInfo) =
        recordRegistrationEvent(RegistrationEvent.ServiceUnregistered(si))
}
181
/**
 * [NsdManager.ResolveListener] that stores every callback it receives as a [ResolveEvent].
 *
 * @param expectedThreadId if non-null, each callback is asserted to run on this thread id,
 *        matching what [NsdDiscoveryRecord] and [NsdRegistrationRecord] already allow. The
 *        default keeps `NsdResolveRecord()` working unchanged.
 */
class NsdResolveRecord(expectedThreadId: Int? = null) : NsdManager.ResolveListener,
    NsdRecord<NsdResolveRecord.ResolveEvent>(expectedThreadId) {

    /** One event type per [NsdManager.ResolveListener] callback. */
    sealed class ResolveEvent : NsdEvent {
        data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()

        data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
        data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
        data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()
    }

    override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
        add(ResolveEvent.ResolveFailed(si, err))
    }

    override fun onServiceResolved(si: NsdServiceInfo) {
        add(ResolveEvent.ServiceResolved(si))
    }

    override fun onResolutionStopped(si: NsdServiceInfo) {
        add(ResolveEvent.ResolutionStopped(si))
    }

    override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
        // Invoke the interface's default implementation before recording the event.
        super.onStopResolutionFailed(si, err)
        add(ResolveEvent.StopResolutionFailed(si, err))
    }
}
211
/**
 * [NsdManager.ServiceInfoCallback] that stores every callback it receives as a
 * [ServiceInfoCallbackEvent].
 *
 * @param expectedThreadId if non-null, each callback is asserted to run on this thread id,
 *        matching what [NsdDiscoveryRecord] and [NsdRegistrationRecord] already allow. The
 *        default keeps `NsdServiceInfoCallbackRecord()` working unchanged.
 */
class NsdServiceInfoCallbackRecord(expectedThreadId: Int? = null) :
    NsdManager.ServiceInfoCallback,
    NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>(expectedThreadId) {

    /** One event type per [NsdManager.ServiceInfoCallback] callback. */
    sealed class ServiceInfoCallbackEvent : NsdEvent {
        data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
        data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()

        /** Recorded from [onServiceLost]. */
        object ServiceUpdatedLost : ServiceInfoCallbackEvent()

        /** Recorded from [onServiceInfoCallbackUnregistered]. */
        object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
    }

    override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
        add(ServiceInfoCallbackEvent.RegisterCallbackFailed(err))
    }

    override fun onServiceUpdated(si: NsdServiceInfo) {
        add(ServiceInfoCallbackEvent.ServiceUpdated(si))
    }

    override fun onServiceLost() {
        add(ServiceInfoCallbackEvent.ServiceUpdatedLost)
    }

    override fun onServiceInfoCallbackUnregistered() {
        add(ServiceInfoCallbackEvent.UnregisterCallbackSucceeded)
    }
}
237
/**
 * Returns the bytes of [packet] that follow the Ethernet, IPv6 and UDP headers, i.e. the UDP
 * payload.
 */
private fun getMdnsPayload(packet: ByteArray): ByteArray {
    // NOTE(review): uses a fixed-size IPv6 header, so this assumes no IPv6 extension headers.
    val payloadStart = ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN
    return packet.copyOfRange(payloadStart, packet.size)
}
240
/** Reads the IPv6 destination address out of an Ethernet frame that carries an IPv6 packet. */
private fun getDstAddr(packet: ByteArray): Inet6Address {
    val addrStart = ETHER_HEADER_LEN + IPV6_DST_ADDR_OFFSET
    val addrEnd = addrStart + IPV6_ADDR_LEN
    val addrBytes = packet.copyOfRange(addrStart, addrEnd)
    // getByAddress is the inherited InetAddress factory. NOTE(review): it returns an
    // Inet4Address for IPv4-mapped IPv6 addresses, which would make this cast throw.
    val parsed = Inet6Address.getByAddress(addrBytes)
    return parsed as Inet6Address
}
246
/**
 * Polls for an IPv6 UDP packet whose source and destination ports are both [MDNS_PORT] and
 * whose parsed payload is accepted by [predicate].
 *
 * A candidate for which parsing (or [predicate]) throws [DnsPacket.ParseException] is treated
 * as non-matching instead of failing the poll.
 *
 * @param timeoutMs how long to poll for a matching packet.
 * @param predicate decides whether a parsed candidate packet matches.
 * @return the matching packet re-parsed as a [TestDnsPacket], or null if [poll] returned none.
 */
fun PollPacketReader.pollForMdnsPacket(
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
    predicate: (TestDnsPacket) -> Boolean
): TestDnsPacket? {
    val mdnsFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT).and { frame ->
        val destination = getDstAddr(frame)
        val payload = getMdnsPayload(frame)
        try {
            predicate(TestDnsPacket(payload, destination))
        } catch (e: DnsPacket.ParseException) {
            false
        }
    }
    val matchedFrame = poll(timeoutMs, mdnsFilter) ?: return null
    return TestDnsPacket(getMdnsPayload(matchedFrame), getDstAddr(matchedFrame))
}
264
/**
 * Polls for a packet with a type-ANY question for `serviceName.serviceType.local`
 * (see [TestDnsPacket.isProbeFor]).
 */
fun PollPacketReader.pollForProbe(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { candidate -> candidate.isProbeFor(instanceName) }
}
272
/**
 * Polls for a packet whose answer section has a record of [TestDnsPacket.isReplyFor]'s default
 * type for `serviceName.serviceType.local`.
 */
fun PollPacketReader.pollForAdvertisement(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { candidate -> candidate.isReplyFor(instanceName) }
}
280
/**
 * Polls for a packet for which [TestDnsPacket.isQueryFor] holds with [recordName] and
 * [requiredTypes].
 *
 * Because [timeoutMs] follows a vararg parameter, callers must pass it by name.
 */
fun PollPacketReader.pollForQuery(
    recordName: String,
    vararg requiredTypes: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? = pollForMdnsPacket(timeoutMs = timeoutMs) { candidate ->
    candidate.isQueryFor(recordName, *requiredTypes)
}
286
/** Polls for a packet whose answer section has a record named [recordName] of [type]. */
fun PollPacketReader.pollForReply(
    recordName: String,
    type: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    return pollForMdnsPacket(timeoutMs) { candidate -> candidate.isReplyFor(recordName, type) }
}
292
/**
 * Polls for a packet whose answer section has an SRV record for
 * `serviceName.serviceType.local`. SRV is the default type used by [TestDnsPacket.isReplyFor].
 */
fun PollPacketReader.pollForReply(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? =
    pollForReply("$serviceName.$serviceType.local", DnsResolver.TYPE_SRV, timeoutMs)
300
/**
 * [DnsPacket] parsed from [data], paired with the destination address [dstAddr] it was sent to.
 * It exposes the parsed header and record sections, plus matchers used by the polling helpers.
 */
class TestDnsPacket(data: ByteArray, val dstAddr: InetAddress) : DnsPacket(data) {
    val header: DnsHeader
        get() = mHeader
    val records: Array<List<DnsRecord>>
        get() = mRecords

    /** True if the question section contains a record named [name] with type ANY. */
    fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
        it.dName == name && it.nsType == DnsResolver.TYPE_ANY
    }

    /** True if the answer section contains a record named [name] of [type] (SRV by default). */
    fun isReplyFor(name: String, type: Int = DnsResolver.TYPE_SRV): Boolean =
        mRecords[ANSECTION].any {
            it.dName == name && it.nsType == type
        }

    /**
     * True if the question section asks about [name] for every type in [requiredTypes].
     *
     * When [requiredTypes] is empty, this requires at least one question about [name], so that
     * an empty type list does not match unrelated packets.
     */
    fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean {
        val questionsForName = mRecords[QDSECTION].filter { it.dName == name }
        // `all` on an empty array is vacuously true, so handle that case explicitly.
        if (requiredTypes.isEmpty()) return questionsForName.isNotEmpty()
        return requiredTypes.all { type -> questionsForName.any { it.nsType == type } }
    }
}
321