1 /*
<lambda>null2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.testutils
17 
18 import android.net.DnsResolver
19 import android.net.Network
20 import android.net.nsd.NsdManager
21 import android.net.nsd.NsdServiceInfo
22 import android.os.Process
23 import com.android.net.module.util.ArrayTrackRecord
24 import com.android.net.module.util.DnsPacket
25 import com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN
26 import com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN
27 import com.android.net.module.util.NetworkStackConstants.IPV6_DST_ADDR_OFFSET
28 import com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN
29 import com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN
30 import com.android.net.module.util.TrackRecord
31 import java.net.Inet6Address
32 import java.net.InetAddress
33 import kotlin.test.assertEquals
34 import kotlin.test.assertNotNull
35 import kotlin.test.assertNull
36 import kotlin.test.assertTrue
37 import kotlin.test.fail
38 
// Default time the pollFor* helpers in this file wait for a matching mDNS packet.
private const val MDNS_REGISTRATION_TIMEOUT_MS = 10000L

// Standard mDNS UDP port (RFC 6762); used as both source and destination port filter.
private const val MDNS_PORT: Short = 5353

/** Default timeout, in ms, used by NsdRecord.expectCallback / expectCallbackEventually. */
const val MDNS_CALLBACK_TIMEOUT = 2_000L

/** Default window, in ms, during which NsdRecord.assertNoCallback expects no event. */
const val MDNS_NO_CALLBACK_TIMEOUT_MS = 200L
43 
/** Marker interface for the callback events recorded by an [NsdRecord]. */
interface NsdEvent

/**
 * Base class for NSD listener test doubles: each callback is recorded as an event of type [T]
 * in an [ArrayTrackRecord], and the helpers below let a test wait for or assert on those events.
 *
 * @param expectedThreadId if non-null, [add] asserts that it runs on the thread whose tid
 *        ([Process.myTid]) equals this value; if null (the default) no thread check is done.
 */
open class NsdRecord<T : NsdEvent> private constructor(
    private val history: ArrayTrackRecord<T>,
    private val expectedThreadId: Int? = null
) : TrackRecord<T> by history {
    constructor(expectedThreadId: Int? = null) : this(ArrayTrackRecord(), expectedThreadId)

    /** Read head over the recorded events; the expect/assert helpers below poll from it. */
    val nextEvents = history.newReadHead()

    /**
     * Records [e], first checking the calling thread when an expected thread id was given.
     *
     * Note: a thread mismatch is reported by throwing on the thread that invoked the callback,
     * which is not necessarily the test thread.
     */
    override fun add(e: T): Boolean {
        if (expectedThreadId != null) {
            assertEquals(
                expectedThreadId, Process.myTid(),
                "Callback is running on the wrong thread"
            )
        }
        return history.add(e)
    }

    /**
     * Polls [nextEvents] for an event of type [V] that satisfies [predicate], failing the test
     * if none is found within [timeoutMs].
     *
     * @return the matching event.
     */
    inline fun <reified V : NsdEvent> expectCallbackEventually(
        timeoutMs: Long = MDNS_CALLBACK_TIMEOUT,
        crossinline predicate: (V) -> Boolean = { true }
    ): V = nextEvents.poll(timeoutMs) { e -> e is V && predicate(e) } as V?
        ?: fail("Callback for ${V::class.java.simpleName} not seen after $timeoutMs ms")

    /**
     * Polls the next event from [nextEvents] and asserts that it is of type [V].
     *
     * Fails the test if no event arrives within [timeoutMs], or if the next event has a
     * different type.
     *
     * @return the event, cast to [V].
     */
    inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = MDNS_CALLBACK_TIMEOUT): V {
        val nextEvent = nextEvents.poll(timeoutMs)
        assertNotNull(
            nextEvent, "No callback received after $timeoutMs ms, expected " +
                    "${V::class.java.simpleName}"
        )
        // assertTrue's contract smart-casts nextEvent to V below.
        assertTrue(
            nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
                    nextEvent.javaClass.simpleName
        )
        return nextEvent
    }

    /**
     * Asserts that no event can be polled from [nextEvents] within [timeoutMs].
     *
     * Not `inline`: it takes no lambda and no reified type parameter, so inlining gains nothing
     * (it only produced a NOTHING_TO_INLINE warning).
     */
    fun assertNoCallback(timeoutMs: Long = MDNS_NO_CALLBACK_TIMEOUT_MS) {
        val cb = nextEvents.poll(timeoutMs)
        assertNull(cb, "Expected no callback but got $cb")
    }
}
87 
/**
 * [NsdManager.DiscoveryListener] that records each callback as a [DiscoveryEvent] so tests can
 * wait for them.
 *
 * @param expectedThreadId if non-null, each callback asserts that it runs on the thread with
 *        this tid; null (the default) disables the check.
 */
class NsdDiscoveryRecord(expectedThreadId: Int? = null) :
    NsdManager.DiscoveryListener, NsdRecord<NsdDiscoveryRecord.DiscoveryEvent>(expectedThreadId) {

    /** One event type per [NsdManager.DiscoveryListener] callback, carrying its arguments. */
    sealed class DiscoveryEvent : NsdEvent {
        data class StartDiscoveryFailed(
            val serviceType: String,
            val errorCode: Int
        ) : DiscoveryEvent()

        data class StopDiscoveryFailed(
            val serviceType: String,
            val errorCode: Int
        ) : DiscoveryEvent()

        data class DiscoveryStarted(val serviceType: String) : DiscoveryEvent()

        data class DiscoveryStopped(val serviceType: String) : DiscoveryEvent()

        data class ServiceFound(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()

        data class ServiceLost(val serviceInfo: NsdServiceInfo) : DiscoveryEvent()
    }

    override fun onStartDiscoveryFailed(serviceType: String, err: Int) {
        add(DiscoveryEvent.StartDiscoveryFailed(serviceType, err))
    }

    override fun onStopDiscoveryFailed(serviceType: String, err: Int) {
        add(DiscoveryEvent.StopDiscoveryFailed(serviceType, err))
    }

    override fun onDiscoveryStarted(serviceType: String) { add(DiscoveryEvent.DiscoveryStarted(serviceType)) }

    override fun onDiscoveryStopped(serviceType: String) { add(DiscoveryEvent.DiscoveryStopped(serviceType)) }

    override fun onServiceFound(si: NsdServiceInfo) { add(DiscoveryEvent.ServiceFound(si)) }

    override fun onServiceLost(si: NsdServiceInfo) { add(DiscoveryEvent.ServiceLost(si)) }

    /**
     * Waits for a [DiscoveryEvent.ServiceFound] event named [serviceName] and, when
     * [expectedNetwork] is given, found on that network.
     *
     * Also asserts that the found service's type is [serviceType] followed by a trailing dot.
     *
     * @return the discovered service's [NsdServiceInfo].
     */
    fun waitForServiceDiscovered(
        serviceName: String,
        serviceType: String,
        expectedNetwork: Network? = null
    ): NsdServiceInfo {
        val foundEvent = expectCallbackEventually<DiscoveryEvent.ServiceFound> { event ->
            val candidate = event.serviceInfo
            candidate.serviceName == serviceName &&
                    (expectedNetwork == null || expectedNetwork == candidate.network)
        }
        val discovered = foundEvent.serviceInfo
        // Service types reported by discovery carry a trailing dot.
        val expectedType = "$serviceType."
        assertEquals(expectedType, discovered.serviceType)
        return discovered
    }
}
142 
/**
 * [NsdManager.RegistrationListener] that records each callback as a [RegistrationEvent] so tests
 * can wait for them.
 *
 * @param expectedThreadId if non-null, each callback asserts that it runs on the thread with
 *        this tid; null (the default) disables the check.
 */
class NsdRegistrationRecord(expectedThreadId: Int? = null) :
    NsdManager.RegistrationListener,
    NsdRecord<NsdRegistrationRecord.RegistrationEvent>(expectedThreadId) {

    /** One event type per [NsdManager.RegistrationListener] callback; all carry the service info. */
    sealed class RegistrationEvent : NsdEvent {
        /** The [NsdServiceInfo] passed to the corresponding callback. */
        abstract val serviceInfo: NsdServiceInfo

        data class RegistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class UnregistrationFailed(
            override val serviceInfo: NsdServiceInfo,
            val errorCode: Int
        ) : RegistrationEvent()

        data class ServiceRegistered(
            override val serviceInfo: NsdServiceInfo
        ) : RegistrationEvent()

        data class ServiceUnregistered(
            override val serviceInfo: NsdServiceInfo
        ) : RegistrationEvent()
    }

    override fun onRegistrationFailed(si: NsdServiceInfo, err: Int) { add(RegistrationEvent.RegistrationFailed(si, err)) }

    override fun onUnregistrationFailed(si: NsdServiceInfo, err: Int) { add(RegistrationEvent.UnregistrationFailed(si, err)) }

    override fun onServiceRegistered(si: NsdServiceInfo) { add(RegistrationEvent.ServiceRegistered(si)) }

    override fun onServiceUnregistered(si: NsdServiceInfo) { add(RegistrationEvent.ServiceUnregistered(si)) }
}
181 
/**
 * [NsdManager.ResolveListener] that records each callback as a [ResolveEvent] so tests can wait
 * for them.
 *
 * @param expectedThreadId if non-null, each callback asserts that it runs on the thread with
 *        this tid; null (the default) disables the check. This mirrors the constructors of
 *        [NsdDiscoveryRecord] and [NsdRegistrationRecord]. Because the parameter has a default,
 *        `NsdResolveRecord()` keeps working.
 */
class NsdResolveRecord(expectedThreadId: Int? = null) : NsdManager.ResolveListener,
    NsdRecord<NsdResolveRecord.ResolveEvent>(expectedThreadId) {

    /** One event type per [NsdManager.ResolveListener] callback, carrying its arguments. */
    sealed class ResolveEvent : NsdEvent {
        data class ResolveFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()

        data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
        data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
        data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
            ResolveEvent()
    }

    override fun onResolveFailed(si: NsdServiceInfo, err: Int) {
        add(ResolveEvent.ResolveFailed(si, err))
    }

    override fun onServiceResolved(si: NsdServiceInfo) {
        add(ResolveEvent.ServiceResolved(si))
    }

    override fun onResolutionStopped(si: NsdServiceInfo) {
        add(ResolveEvent.ResolutionStopped(si))
    }

    // Like onResolutionStopped, this only records the event; the interface's default
    // implementation is intentionally not invoked.
    override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
        add(ResolveEvent.StopResolutionFailed(si, err))
    }
}
211 
/**
 * [NsdManager.ServiceInfoCallback] that records each callback as a [ServiceInfoCallbackEvent]
 * so tests can wait for them.
 *
 * @param expectedThreadId if non-null, each callback asserts that it runs on the thread with
 *        this tid; null (the default) disables the check. This mirrors the constructors of
 *        [NsdDiscoveryRecord] and [NsdRegistrationRecord]. Because the parameter has a default,
 *        `NsdServiceInfoCallbackRecord()` keeps working.
 */
class NsdServiceInfoCallbackRecord(expectedThreadId: Int? = null) :
    NsdManager.ServiceInfoCallback,
    NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>(expectedThreadId) {

    /** One event type per [NsdManager.ServiceInfoCallback] callback. */
    sealed class ServiceInfoCallbackEvent : NsdEvent {
        data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
        data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()

        /** Recorded by [onServiceLost]. */
        object ServiceUpdatedLost : ServiceInfoCallbackEvent()

        /** Recorded by [onServiceInfoCallbackUnregistered]. */
        object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
    }

    override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
        add(ServiceInfoCallbackEvent.RegisterCallbackFailed(err))
    }

    override fun onServiceUpdated(si: NsdServiceInfo) {
        add(ServiceInfoCallbackEvent.ServiceUpdated(si))
    }

    override fun onServiceLost() {
        add(ServiceInfoCallbackEvent.ServiceUpdatedLost)
    }

    override fun onServiceInfoCallbackUnregistered() {
        add(ServiceInfoCallbackEvent.UnregisterCallbackSucceeded)
    }
}
237 
/**
 * Returns a copy of the UDP payload of [packet].
 *
 * Assumes an Ethernet frame carrying a fixed-size IPv6 header (no extension headers) followed
 * by a UDP header.
 */
private fun getMdnsPayload(packet: ByteArray): ByteArray {
    val payloadStart = ETHER_HEADER_LEN + IPV6_HEADER_LEN + UDP_HEADER_LEN
    return packet.copyOfRange(payloadStart, packet.size)
}
240 
/** Extracts the IPv6 destination address from an Ethernet frame carrying an IPv6 packet. */
private fun getDstAddr(packet: ByteArray): Inet6Address {
    val addrStart = ETHER_HEADER_LEN + IPV6_DST_ADDR_OFFSET
    val addrEnd = addrStart + IPV6_ADDR_LEN
    val addrBytes = packet.copyOfRange(addrStart, addrEnd)
    return InetAddress.getByAddress(addrBytes) as Inet6Address
}
246 
/**
 * Polls this reader for an IPv6 UDP packet with source and destination port [MDNS_PORT] whose
 * DNS payload satisfies [predicate].
 *
 * Packets whose payload fails to parse as DNS count as non-matching; they do not fail the poll.
 *
 * @param timeoutMs how long to wait for a matching packet.
 * @param predicate called with each candidate's parsed DNS message.
 * @return the first matching packet (parsed again), or null if none matched in [timeoutMs].
 */
fun PollPacketReader.pollForMdnsPacket(
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS,
    predicate: (TestDnsPacket) -> Boolean
): TestDnsPacket? {
    val portFilter = IPv6UdpFilter(srcPort = MDNS_PORT, dstPort = MDNS_PORT)
    val mdnsFilter = portFilter.and { raw ->
        val dstAddr = getDstAddr(raw)
        val payload = getMdnsPayload(raw)
        try {
            predicate(TestDnsPacket(payload, dstAddr))
        } catch (e: DnsPacket.ParseException) {
            // Not a parseable DNS message: treat as non-matching.
            false
        }
    }
    val matched = poll(timeoutMs, mdnsFilter) ?: return null
    return TestDnsPacket(getMdnsPayload(matched), getDstAddr(matched))
}
264 
/**
 * Waits for an mDNS probe for the instance `serviceName.serviceType.local`
 * (see [TestDnsPacket.isProbeFor]).
 *
 * @return the probe packet, or null if none is seen within [timeoutMs].
 */
fun PollPacketReader.pollForProbe(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs) { packet -> packet.isProbeFor(instanceName) }
}
272 
/**
 * Waits for an mDNS packet whose answer section holds a record of the default
 * [TestDnsPacket.isReplyFor] type for `serviceName.serviceType.local`.
 *
 * @return the advertisement packet, or null if none is seen within [timeoutMs].
 */
fun PollPacketReader.pollForAdvertisement(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs = timeoutMs) { packet -> packet.isReplyFor(instanceName) }
}
280 
/**
 * Waits for an mDNS packet whose question section asks for [recordName] with each type in
 * [requiredTypes] (see [TestDnsPacket.isQueryFor]).
 *
 * Because [requiredTypes] is vararg, [timeoutMs] must be passed by name.
 *
 * @return the query packet, or null if none is seen within [timeoutMs].
 */
fun PollPacketReader.pollForQuery(
    recordName: String,
    vararg requiredTypes: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    return pollForMdnsPacket(timeoutMs = timeoutMs) { packet ->
        packet.isQueryFor(recordName, *requiredTypes)
    }
}
286 
/**
 * Waits for an mDNS packet whose answer section holds a record named [recordName] of [type].
 *
 * @return the reply packet, or null if none is seen within [timeoutMs].
 */
fun PollPacketReader.pollForReply(
    recordName: String,
    type: Int,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    return pollForMdnsPacket(timeoutMs = timeoutMs) { packet ->
        packet.isReplyFor(recordName, type)
    }
}
292 
/**
 * Waits for an mDNS packet whose answer section holds a record of the default
 * [TestDnsPacket.isReplyFor] type for `serviceName.serviceType.local`.
 *
 * @return the reply packet, or null if none is seen within [timeoutMs].
 */
fun PollPacketReader.pollForReply(
    serviceName: String,
    serviceType: String,
    timeoutMs: Long = MDNS_REGISTRATION_TIMEOUT_MS
): TestDnsPacket? {
    val instanceName = "$serviceName.$serviceType.local"
    return pollForMdnsPacket(timeoutMs = timeoutMs) { packet -> packet.isReplyFor(instanceName) }
}
300 
/**
 * A [DnsPacket] parsed from a captured mDNS payload, with helpers for classifying it in tests.
 *
 * @param data the DNS message bytes (the UDP payload).
 * @property dstAddr destination address of the IP packet that carried this message.
 */
class TestDnsPacket(data: ByteArray, val dstAddr: InetAddress) : DnsPacket(data) {
    /** The parsed DNS header. */
    val header: DnsHeader
        get() = mHeader

    /** Parsed records, indexed by section (e.g. [QDSECTION], [ANSECTION]). */
    val records: Array<List<DnsRecord>>
        get() = mRecords

    /**
     * True if the question section contains a [DnsResolver.TYPE_ANY] question for [name].
     * This is how mDNS probes query a name (RFC 6762 section 8.1).
     */
    fun isProbeFor(name: String): Boolean = mRecords[QDSECTION].any {
        it.dName == name && it.nsType == DnsResolver.TYPE_ANY
    }

    /** True if the answer section contains a record named [name] of [type] (SRV by default). */
    fun isReplyFor(name: String, type: Int = DnsResolver.TYPE_SRV): Boolean =
        mRecords[ANSECTION].any {
            it.dName == name && it.nsType == type
        }

    /**
     * True if the question section asks for [name] with every type in [requiredTypes].
     *
     * With no [requiredTypes], this is true if there is at least one question for [name] of
     * any type.
     */
    fun isQueryFor(name: String, vararg requiredTypes: Int): Boolean {
        val questions = mRecords[QDSECTION]
        // `all` over an empty array is vacuously true, which made this match every packet;
        // require at least one question for the name instead.
        if (requiredTypes.isEmpty()) return questions.any { it.dName == name }
        return requiredTypes.all { type ->
            questions.any { it.dName == name && it.nsType == type }
        }
    }
}
321