1 /*
 * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License")
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity.mdns.util
18 
19 import android.net.InetAddresses
20 import android.os.Build
21 import com.android.server.connectivity.mdns.MdnsConstants
22 import com.android.server.connectivity.mdns.MdnsConstants.FLAG_TRUNCATED
23 import com.android.server.connectivity.mdns.MdnsConstants.IPV4_SOCKET_ADDR
24 import com.android.server.connectivity.mdns.MdnsConstants.IPV6_SOCKET_ADDR
25 import com.android.server.connectivity.mdns.MdnsInetAddressRecord
26 import com.android.server.connectivity.mdns.MdnsPacket
27 import com.android.server.connectivity.mdns.MdnsPacketReader
28 import com.android.server.connectivity.mdns.MdnsPointerRecord
29 import com.android.server.connectivity.mdns.MdnsRecord
30 import com.android.server.connectivity.mdns.MdnsResponse
31 import com.android.server.connectivity.mdns.MdnsServiceInfo
32 import com.android.server.connectivity.mdns.MdnsServiceRecord
33 import com.android.server.connectivity.mdns.MdnsTextRecord
34 import com.android.server.connectivity.mdns.util.MdnsUtils.createQueryDatagramPackets
35 import com.android.server.connectivity.mdns.util.MdnsUtils.truncateServiceName
36 import com.android.testutils.DevSdkIgnoreRule
37 import com.android.testutils.DevSdkIgnoreRunner
38 import org.junit.Assert.assertArrayEquals
39 import org.junit.Assert.assertEquals
40 import org.junit.Assert.assertFalse
41 import org.junit.Assert.assertTrue
42 import org.junit.Test
43 import org.junit.runner.RunWith
44 import java.net.DatagramPacket
45 import kotlin.test.assertContentEquals
46 
@RunWith(DevSdkIgnoreRunner::class)
@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.S_V2)
class MdnsUtilsTest {

    @Test
    fun testTruncateServiceName() {
        // The limit is counted in UTF-8 bytes: each of "测试" takes 3 bytes, so a 7-byte limit
        // keeps both characters (6 bytes) plus a single ASCII character.
        assertEquals("测试a", truncateServiceName("测试abcde", 7))
        // A limit larger than the encoded name leaves it untouched.
        assertEquals("测试abcde", truncateServiceName("测试abcde", 100))
    }

    @Test
    fun testTypeEqualsOrIsSubtype() {
        // Labels are compared case-insensitively.
        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
                arrayOf("_type", "_tcp", "local"),
                arrayOf("_type", "_TCP", "local")
        ))
        // A "<subtype>._sub.<type>" name matches its base type.
        assertTrue(MdnsUtils.typeEqualsOrIsSubtype(
                arrayOf("_type", "_tcp", "local"),
                arrayOf("a", "_SUB", "_type", "_TCP", "local")
        ))
        // Extra leading labels on the first argument prevent a match with the plain type.
        assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
                arrayOf("_sub", "_type", "_tcp", "local"),
                arrayOf("_type", "_TCP", "local")
        ))
        // A label other than "_sub" in the subtype position does not match.
        assertFalse(MdnsUtils.typeEqualsOrIsSubtype(
                arrayOf("a", "_other", "_type", "_tcp", "local"),
                arrayOf("a", "_SUB", "_type", "_TCP", "local")
        ))
    }

    @Test
    fun testCreateQueryDatagramPackets() {
        // Approximate wire size per record, assuming repeated labels ("_tcp.local") are
        // compressed into pointers:
        //
        // Question:
        // name(17) + type(2) + class(2) = 21 bytes
        //
        // Known answer (PTR record):
        // name(17) + type(2) + class(2) + TTL(4) + RDLENGTH(2) + RDATA(18) = 45 bytes
        val serviceCount = 100
        val questions: List<MdnsRecord> = (1..serviceCount).map { i ->
            MdnsPointerRecord(arrayOf("_testservice$i", "_tcp", "local"), false)
        }
        val knownAnswers: List<MdnsRecord> = (1..serviceCount).map { i ->
            MdnsPointerRecord(
                    arrayOf("_testservice$i", "_tcp", "local"),
                    0L /* receiptTimeMillis */,
                    false /* cacheFlush */,
                    4_500_000L /* ttlMillis */,
                    arrayOf("MyTestService$i", "_testservice$i", "_tcp", "local")
            )
        }
        // Total payload: questions (21 * 100) + answers (45 * 100) = 6600 bytes, which needs
        // at least 5 datagrams of at most 1500 bytes each.
        val query = MdnsPacket(
                MdnsConstants.FLAGS_QUERY,
                questions,
                knownAnswers,
                emptyList(),
                emptyList()
        )
        val bufferSize = 1500
        val packets = createQueryDatagramPackets(
                ByteArray(bufferSize),
                query,
                IPV4_SOCKET_ADDR
        )
        // The oversized query must be split into 5 datagrams, each smaller than the buffer.
        assertEquals(5, packets.size)
        assertTrue(packets.all { packet -> packet.length < bufferSize })

        // Reassembling the datagrams must reproduce the original query content.
        val mdnsPacket = createMdnsPacketFromMultipleDatagramPackets(packets)
        assertEquals(query.flags, mdnsPacket.flags)
        assertContentEquals(query.questions, mdnsPacket.questions)
        assertContentEquals(query.answers, mdnsPacket.answers)
    }

    /**
     * Parses every datagram in [packets] and merges them into a single [MdnsPacket].
     *
     * Asserts that every datagram except the last has [FLAG_TRUNCATED] set. The merged packet
     * takes the flags of the last datagram (0 if [packets] is empty) and concatenates the
     * questions and answers in datagram order; authority and additional records are not kept.
     */
    private fun createMdnsPacketFromMultipleDatagramPackets(
            packets: List<DatagramPacket>
    ): MdnsPacket {
        val parsedPackets = packets.map { packet -> MdnsPacket.parse(MdnsPacketReader(packet)) }
        // Every datagram but the last must signal that more known answers follow.
        parsedPackets.dropLast(1).forEach { parsed ->
            assertTrue((parsed.flags and FLAG_TRUNCATED) == FLAG_TRUNCATED)
        }
        return MdnsPacket(
                parsedPackets.lastOrNull()?.flags ?: 0,
                parsedPackets.flatMap { it.questions },
                parsedPackets.flatMap { it.answers },
                emptyList(),
                emptyList()
        )
    }

    @Test
    fun testCheckAllPacketsWithSameAddress() {
        // Only the destination address matters here, so all packets share one buffer.
        val buffer = ByteArray(10)
        val v4Packet = DatagramPacket(buffer, buffer.size, IPV4_SOCKET_ADDR)
        val otherV4Packet = DatagramPacket(
                buffer,
                buffer.size,
                InetAddresses.parseNumericAddress("192.0.2.1"),
                1234
        )
        val v6Packet = DatagramPacket(buffer, buffer.size, IPV6_SOCKET_ADDR)
        val otherV6Packet = DatagramPacket(
                buffer,
                buffer.size,
                InetAddresses.parseNumericAddress("2001:db8::"),
                1234
        )
        // An empty list and lists of identical destinations are accepted.
        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf()))
        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet)))
        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet, v4Packet)))
        assertFalse(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet, otherV4Packet)))
        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v6Packet)))
        assertTrue(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v6Packet, v6Packet)))
        assertFalse(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v6Packet, otherV6Packet)))
        // Mixing address families is rejected.
        assertFalse(MdnsUtils.checkAllPacketsWithSameAddress(listOf(v4Packet, v6Packet)))
    }

    @Test
    fun testBuildMdnsServiceInfoFromResponse() {
        val serviceInstanceName = "MyTestService"
        val serviceType = "_testservice._tcp.local"
        val hostName = "Android_000102030405060708090A0B0C0D0E0F.local"
        val port = 12345
        val ttlTime = 120000L
        val testElapsedRealtime = 123L
        val serviceName = labelsOf("$serviceInstanceName.$serviceType")
        val v4Address = "192.0.2.1"
        val v6Address = "2001:db8::1"
        val interfaceIndex = 99
        val response = MdnsResponse(0 /* now */, serviceName, interfaceIndex, null /* network */)
        // Set PTR record.
        response.addPointerRecord(MdnsPointerRecord(labelsOf(serviceType),
                testElapsedRealtime, false /* cacheFlush */, ttlTime, serviceName))
        // Set SRV record.
        response.serviceRecord = MdnsServiceRecord(serviceName, testElapsedRealtime,
                false /* cacheFlush */, ttlTime, 0 /* servicePriority */, 0 /* serviceWeight */,
                port, labelsOf(hostName))
        // Set TXT record.
        response.textRecord = MdnsTextRecord(serviceName,
                testElapsedRealtime, true /* cacheFlush */, 0L /* ttlMillis */,
                listOf(MdnsServiceInfo.TextEntry.fromString("somedifferent=entry")))
        // Set A and AAAA records.
        response.addInet4AddressRecord(MdnsInetAddressRecord(labelsOf(hostName),
                testElapsedRealtime, true /* cacheFlush */,
                0L /* ttlMillis */, InetAddresses.parseNumericAddress(v4Address)))
        response.addInet6AddressRecord(MdnsInetAddressRecord(labelsOf(hostName),
                testElapsedRealtime, true /* cacheFlush */,
                0L /* ttlMillis */, InetAddresses.parseNumericAddress(v6Address)))

        // Convert the MdnsResponse to an MdnsServiceInfo.
        val serviceInfo = MdnsUtils.buildMdnsServiceInfoFromResponse(
                response, labelsOf(serviceType), testElapsedRealtime)

        // Every field populated above must be reflected in the service info.
        assertEquals(serviceInstanceName, serviceInfo.serviceInstanceName)
        assertArrayEquals(labelsOf(serviceType), serviceInfo.serviceType)
        assertArrayEquals(labelsOf(hostName), serviceInfo.hostName)
        assertEquals(port, serviceInfo.port)
        assertEquals(1, serviceInfo.ipv4Addresses.size)
        assertEquals(v4Address, serviceInfo.ipv4Addresses[0])
        assertEquals(1, serviceInfo.ipv6Addresses.size)
        assertEquals(v6Address, serviceInfo.ipv6Addresses[0])
        assertEquals(interfaceIndex, serviceInfo.interfaceIndex)
        assertEquals(null, serviceInfo.network)
        assertEquals(mapOf("somedifferent" to "entry"),
                serviceInfo.attributes)
    }

    /** Splits a dotted name such as "_testservice._tcp.local" into a fresh array of labels. */
    private fun labelsOf(name: String): Array<String> = name.split(".").toTypedArray()
}
217