1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /*
18  * This Java source file was generated by the Gradle 'init' task.
19  */
20 package com.google.security.cryptauth.lib.securegcm.ukey2
21 
22 import java.nio.charset.StandardCharsets
23 import org.junit.jupiter.api.Assertions.assertArrayEquals
24 import org.junit.jupiter.api.Assertions.assertEquals
25 import org.junit.jupiter.api.Assertions.assertFalse
26 import org.junit.jupiter.api.Assertions.assertTrue
27 import org.junit.jupiter.api.Test
28 import org.junit.jupiter.api.assertDoesNotThrow
29 import org.junit.jupiter.api.assertThrows
30 
// Driver code for the UKEY2 protocol bindings.
// Tests exception handling, the handshake routine, and session save/restore, as well as
// encrypting/decrypting short messages between the server and initiator contexts.
class TestUkey2Protocol {
  @Test
  fun testHandshake() {
    val initiatorContext = D2DHandshakeContext(D2DHandshakeContext.Role.INITIATOR)
    assertFalse(initiatorContext.isHandshakeComplete)
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    assertFalse(serverContext.isHandshakeComplete)
    completeHandshake(initiatorContext, serverContext)
    // Assert outside assertDoesNotThrow so a failure is reported as an assertion, not as an
    // "unexpected exception".
    assertTrue(initiatorContext.isHandshakeComplete)
    assertTrue(serverContext.isHandshakeComplete)
  }

  @Test
  fun testSendReceiveMessage() {
    val initiatorContext = D2DHandshakeContext(D2DHandshakeContext.Role.INITIATOR)
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    completeHandshake(initiatorContext, serverContext)
    // Initiator encrypts, server decrypts; the decoded text must match what was sent.
    val response = assertDoesNotThrow {
      val connContext = initiatorContext.toConnectionContext()
      val serverConnContext = serverContext.toConnectionContext()
      val encoded =
          connContext.encodeMessageToPeer(SHARE_MESSAGE.toByteArray(StandardCharsets.UTF_8), null)
      String(serverConnContext.decodeMessageFromPeer(encoded, null), StandardCharsets.UTF_8)
    }
    assertEquals(SHARE_MESSAGE, response)
  }

  @Test
  fun testSaveRestoreSession() {
    val initiatorContext = D2DHandshakeContext(D2DHandshakeContext.Role.INITIATOR)
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    completeHandshake(initiatorContext, serverContext)
    val connContext = assertDoesNotThrow { initiatorContext.toConnectionContext() }
    val serverConnContext = assertDoesNotThrow { serverContext.toConnectionContext() }
    // Round-trip the initiator's connection through saveSession/fromSavedSession.
    val restored = assertDoesNotThrow {
      D2DConnectionContextV1.fromSavedSession(connContext.saveSession())
    }
    assertArrayEquals(connContext.sessionUnique, restored.sessionUnique)
    // The restored initiator must still decrypt a message encrypted by the live server context.
    val response = assertDoesNotThrow {
      val encoded =
          serverConnContext.encodeMessageToPeer(
              SHARE_MESSAGE.toByteArray(StandardCharsets.UTF_8), null)
      String(restored.decodeMessageFromPeer(encoded, null), StandardCharsets.UTF_8)
    }
    assertEquals(SHARE_MESSAGE, response)
  }

  @Test
  fun testSaveRestoreBadSession() {
    val initiatorContext = D2DHandshakeContext(D2DHandshakeContext.Role.INITIATOR)
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    completeHandshake(initiatorContext, serverContext)
    // Produce the saved session up front so that only the restore step is inside assertThrows.
    val savedSession = assertDoesNotThrow {
      val connContext = initiatorContext.toConnectionContext()
      serverContext.toConnectionContext()
      connContext.saveSession()
    }
    // A blob truncated to its first 20 bytes is expected to be rejected on restore.
    val truncatedSession = savedSession.copyOfRange(0, 20)
    assertThrows<SessionRestoreException> {
      D2DConnectionContextV1.fromSavedSession(truncatedSession)
    }
  }

  @Test
  fun tryReuseHandshakeContext() {
    val initiatorContext = D2DHandshakeContext(D2DHandshakeContext.Role.INITIATOR)
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    completeHandshake(initiatorContext, serverContext)
    assertDoesNotThrow {
      initiatorContext.toConnectionContext()
      serverContext.toConnectionContext()
    }
    // Once converted to a connection context, the handshake context must not be usable again.
    assertThrows<BadHandleException> { serverContext.nextHandshakeMessage }
  }

  @Test
  fun testSendReceiveMessageWithAssociatedData() {
    val initiatorContext = D2DHandshakeContext(D2DHandshakeContext.Role.INITIATOR)
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    val associatedData = "Associated data.".toByteArray()
    completeHandshake(initiatorContext, serverContext)
    // The same associated data is supplied on both the encode and the decode side.
    val response = assertDoesNotThrow {
      val connContext = initiatorContext.toConnectionContext()
      val serverConnContext = serverContext.toConnectionContext()
      val encoded =
          connContext.encodeMessageToPeer(
              SHARE_MESSAGE.toByteArray(StandardCharsets.UTF_8), associatedData)
      String(
          serverConnContext.decodeMessageFromPeer(encoded, associatedData), StandardCharsets.UTF_8)
    }
    assertEquals(SHARE_MESSAGE, response)
  }

  @Test
  fun testVerificationString() {
    val initiatorContext = D2DHandshakeContext(D2DHandshakeContext.Role.INITIATOR)
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    completeHandshake(initiatorContext, serverContext)
    assertTrue(serverContext.isHandshakeComplete)
    assertTrue(initiatorContext.isHandshakeComplete)
    // Both sides must derive identical verification strings for the same requested length.
    assertArrayEquals(
        serverContext.getVerificationString(32), initiatorContext.getVerificationString(32))
  }

  @Test
  fun throwsAlertExceptionWhenBadMessage() {
    val serverContext = D2DHandshakeContext(D2DHandshakeContext.Role.RESPONDER)
    val exception =
        assertThrows<AlertException> {
          serverContext.parseHandshakeMessage("Hello UKEY2".toByteArray())
        }
    assertTrue(exception.alertMessageToSend.isNotEmpty())
  }

  @Test
  fun testGcm() {
    val initiatorContext =
        D2DHandshakeContext(
            D2DHandshakeContext.Role.INITIATOR,
            arrayOf(D2DHandshakeContext.NextProtocol.AES_256_GCM_SIV))
    val serverContext =
        D2DHandshakeContext(
            D2DHandshakeContext.Role.RESPONDER,
            arrayOf(D2DHandshakeContext.NextProtocol.AES_256_GCM_SIV))
    completeHandshake(initiatorContext, serverContext)
    assertTrue(serverContext.isHandshakeComplete)
    assertTrue(initiatorContext.isHandshakeComplete)
  }

  @Test
  fun testGcmServer_cbcClient() {
    // The initiator offers only CBC; the responder accepts both CBC and GCM-SIV.
    val initiatorContext =
        D2DHandshakeContext(
            D2DHandshakeContext.Role.INITIATOR,
            arrayOf(D2DHandshakeContext.NextProtocol.AES_256_CBC_HMAC_SHA256))
    val serverContext =
        D2DHandshakeContext(
            D2DHandshakeContext.Role.RESPONDER,
            arrayOf(
                D2DHandshakeContext.NextProtocol.AES_256_CBC_HMAC_SHA256,
                D2DHandshakeContext.NextProtocol.AES_256_GCM_SIV))
    completeHandshake(initiatorContext, serverContext)
    assertTrue(serverContext.isHandshakeComplete)
    assertTrue(initiatorContext.isHandshakeComplete)
  }

  @Test
  fun testGcmClient_cbcServer() {
    // The initiator offers both CBC and GCM-SIV; the responder accepts only CBC.
    val initiatorContext =
        D2DHandshakeContext(
            D2DHandshakeContext.Role.INITIATOR,
            arrayOf(
                D2DHandshakeContext.NextProtocol.AES_256_CBC_HMAC_SHA256,
                D2DHandshakeContext.NextProtocol.AES_256_GCM_SIV))
    val serverContext =
        D2DHandshakeContext(
            D2DHandshakeContext.Role.RESPONDER,
            arrayOf(D2DHandshakeContext.NextProtocol.AES_256_CBC_HMAC_SHA256))
    completeHandshake(initiatorContext, serverContext)
    assertTrue(serverContext.isHandshakeComplete)
    assertTrue(initiatorContext.isHandshakeComplete)
  }

  /**
   * Runs the three-message exchange used by every test.
   *
   * The steps are: the initiator's first message goes to [responder], the responder's reply goes
   * back to [initiator], and the initiator's final message goes to [responder]. Any exception
   * thrown along the way fails the calling test.
   */
  private fun completeHandshake(initiator: D2DHandshakeContext, responder: D2DHandshakeContext) {
    assertDoesNotThrow {
      responder.parseHandshakeMessage(initiator.nextHandshakeMessage)
      initiator.parseHandshakeMessage(responder.nextHandshakeMessage)
      responder.parseHandshakeMessage(initiator.nextHandshakeMessage)
    }
  }

  private companion object {
    /** Plaintext payload sent over the established connections. */
    const val SHARE_MESSAGE = "Nearby sharing to server"
  }
}
238