1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <string>
16 
17 #include "gtest/gtest.h"
18 #include "ukey2_ffi.h"
19 
20 namespace rust {
21 namespace {
22 
RunHandshake(Ukey2Handshake initiator_handle,Ukey2Handshake responder_handle)23 void RunHandshake(Ukey2Handshake initiator_handle,
24                   Ukey2Handshake responder_handle) {
25   ParseResult parse_result = responder_handle.ParseHandshakeMessage(
26       initiator_handle.GetNextHandshakeMessage());
27   ASSERT_TRUE(parse_result.success);
28   parse_result = initiator_handle.ParseHandshakeMessage(
29       responder_handle.GetNextHandshakeMessage());
30   ASSERT_TRUE(parse_result.success);
31   parse_result = responder_handle.ParseHandshakeMessage(
32       initiator_handle.GetNextHandshakeMessage());
33   ASSERT_TRUE(parse_result.success);
34 }
35 
// A freshly created responder or initiator must not report a completed
// handshake before any messages have been exchanged.
TEST(Ukey2RustTest, HandshakeStartsIncomplete) {
  Ukey2Handshake responder_handle = Ukey2Handshake::ForResponder();
  Ukey2Handshake initiator_handle = Ukey2Handshake::ForInitiator();
  // The two checks are independent; EXPECT_* reports both on failure instead
  // of stopping at the first one.
  EXPECT_FALSE(responder_handle.IsHandshakeComplete());
  EXPECT_FALSE(initiator_handle.IsHandshakeComplete());
}
42 
// After RunHandshake exchanges all handshake messages, both sides must report
// the handshake as complete.
TEST(Ukey2RustTest, HandshakeComplete) {
  Ukey2Handshake responder_handle = Ukey2Handshake::ForResponder();
  Ukey2Handshake initiator_handle = Ukey2Handshake::ForInitiator();
  // RunHandshake asserts internally, and a fatal failure there only returns
  // from the helper. ASSERT_NO_FATAL_FAILURE propagates it so this test stops
  // instead of reporting follow-on failures.
  ASSERT_NO_FATAL_FAILURE(RunHandshake(initiator_handle, responder_handle));
  EXPECT_TRUE(responder_handle.IsHandshakeComplete());
  EXPECT_TRUE(initiator_handle.IsHandshakeComplete());
}
50 
// Completes a handshake and converts both sides into connection contexts. It
// then checks that a payload encoded by the responder decodes back to the same
// text on the initiator when both sides use the same associated data.
TEST(Ukey2RustTest, CanSendReceiveMessage) {
  Ukey2Handshake responder_handle = Ukey2Handshake::ForResponder();
  Ukey2Handshake initiator_handle = Ukey2Handshake::ForInitiator();
  // Abort this test (not just the helper) if any handshake step fails.
  ASSERT_NO_FATAL_FAILURE(RunHandshake(initiator_handle, responder_handle));
  ASSERT_TRUE(responder_handle.IsHandshakeComplete());
  ASSERT_TRUE(initiator_handle.IsHandshakeComplete());
  D2DConnectionContextV1 responder_connection =
      responder_handle.ToConnectionContext();
  D2DConnectionContextV1 initiator_connection =
      initiator_handle.ToConnectionContext();
  const std::string message = "hello world";
  // Shared by encode and decode so the two sides cannot drift apart.
  constexpr char kAssociatedData[] = "assocdata";
  const std::string encoded =
      responder_connection.EncodeMessageToPeer(message, kAssociatedData);
  ASSERT_FALSE(encoded.empty());
  const std::string decoded =
      initiator_connection.DecodeMessageFromPeer(encoded, kAssociatedData);
  EXPECT_EQ(message, decoded);
}
68 
// Completes a handshake, saves the responder's connection context, and
// restores it through D2DConnectionContextV1::FromSavedSession. It then checks
// that a message encoded by the restored context is decoded correctly by the
// initiator's original context.
TEST(Ukey2RustTest, TestSaveRestoreSession) {
  Ukey2Handshake responder_handle = Ukey2Handshake::ForResponder();
  Ukey2Handshake initiator_handle = Ukey2Handshake::ForInitiator();
  // Abort this test (not just the helper) if any handshake step fails.
  ASSERT_NO_FATAL_FAILURE(RunHandshake(initiator_handle, responder_handle));
  ASSERT_TRUE(responder_handle.IsHandshakeComplete());
  ASSERT_TRUE(initiator_handle.IsHandshakeComplete());
  D2DConnectionContextV1 responder_connection =
      responder_handle.ToConnectionContext();
  D2DConnectionContextV1 initiator_connection =
      initiator_handle.ToConnectionContext();
  auto saved_responder = responder_connection.SaveSession();
  D2DRestoreConnectionContextV1Result restore_result =
      D2DConnectionContextV1::FromSavedSession(saved_responder);
  ASSERT_EQ(restore_result.status,
            CD2DRestoreConnectionContextV1Status::STATUS_GOOD);
  auto new_responder = restore_result.handle;
  const std::string message = "hello world";
  const std::string encoded = new_responder.EncodeMessageToPeer(message, "");
  ASSERT_FALSE(encoded.empty());
  const std::string decoded =
      initiator_connection.DecodeMessageFromPeer(encoded, "");
  EXPECT_EQ(message, decoded);
}
89 
90 }  // namespace
91 }  // namespace rust
92