1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <cassert>
16 #include <cstring>
17 #include <iostream>
18 #include <string>
19 
20 #include "ukey2_bindings.h"
21 #include "ukey2_ffi.h"
22 
nullByteArray()23 CFFIByteArray nullByteArray() {
24   return {
25       .handle = nullptr,
26       .len = 0,
27   };
28 }
29 
// Out-of-line definitions of the rust::Ukey2Handshake and
// rust::D2DConnectionContextV1 member functions.
ForInitiator()31 rust::Ukey2Handshake rust::Ukey2Handshake::ForInitiator() {
32   return Ukey2Handshake(initiator_new());
33 }
34 
ForResponder()35 rust::Ukey2Handshake rust::Ukey2Handshake::ForResponder() {
36   return Ukey2Handshake(responder_new());
37 }
38 
IsHandshakeComplete()39 bool rust::Ukey2Handshake::IsHandshakeComplete() {
40   return is_handshake_complete(handle_);
41 }
42 
GetNextHandshakeMessage()43 std::string rust::Ukey2Handshake::GetNextHandshakeMessage() {
44   RustFFIByteArray array = get_next_handshake_message(handle_);
45   std::string ret = std::string((const char*)array.handle, array.len);
46   rust_dealloc_ffi_byte_array(array);
47   return ret;
48 }
49 
ParseHandshakeMessage(std::string message)50 rust::ParseResult rust::Ukey2Handshake::ParseHandshakeMessage(
51     std::string message) {
52   CFFIByteArray messageRaw{
53       .handle = (uint8_t*)message.c_str(),
54       .len = message.length(),
55   };
56   CMessageParseResult result = parse_handshake_message(handle_, messageRaw);
57   std::string alert;
58   if (!result.success) {
59     std::cout << "parse failed" << std::endl;
60     RustFFIByteArray array = result.alert_to_send;
61     if (array.handle != nullptr) {
62       alert = std::string((const char*)array.handle, array.len);
63       rust_dealloc_ffi_byte_array(array);
64     }
65   }
66   return ParseResult{
67       .success = result.success,
68       .alert_to_send = alert,
69   };
70 }
71 
GetVerificationString(size_t output_length)72 std::string rust::Ukey2Handshake::GetVerificationString(size_t output_length) {
73   RustFFIByteArray array = get_verification_string(handle_, output_length);
74   std::string ret = std::string((const char*)array.handle, array.len);
75   rust_dealloc_ffi_byte_array(array);
76   return ret;
77 }
78 
ToConnectionContext()79 rust::D2DConnectionContextV1 rust::Ukey2Handshake::ToConnectionContext() {
80   assert(IsHandshakeComplete());
81   return D2DConnectionContextV1(to_connection_context(handle_));
82 }
83 
DecodeMessageFromPeer(std::string message,std::string associated_data)84 std::string rust::D2DConnectionContextV1::DecodeMessageFromPeer(
85     std::string message, std::string associated_data) {
86   CFFIByteArray messageRaw{
87       .handle = (uint8_t*)message.c_str(),
88       .len = message.length(),
89   };
90   CFFIByteArray associatedDataRaw{
91       .handle = (uint8_t*)associated_data.c_str(),
92       .len = associated_data.length(),
93   };
94   RustFFIByteArray array =
95       decode_message_from_peer(handle_, messageRaw, associatedDataRaw);
96   if (array.handle == nullptr) {
97     return "";
98   }
99   std::string ret = std::string((const char*)array.handle, array.len);
100   rust_dealloc_ffi_byte_array(array);
101   return ret;
102 }
103 
EncodeMessageToPeer(std::string message,std::string associated_data)104 std::string rust::D2DConnectionContextV1::EncodeMessageToPeer(
105     std::string message, std::string associated_data) {
106   CFFIByteArray messageRaw{
107       .handle = (uint8_t*)message.c_str(),
108       .len = message.length(),
109   };
110   CFFIByteArray associatedDataRaw{
111       .handle = (uint8_t*)associated_data.c_str(),
112       .len = associated_data.length(),
113   };
114   RustFFIByteArray array =
115       encode_message_to_peer(handle_, messageRaw, associatedDataRaw);
116   std::string ret = std::string((const char*)array.handle, array.len);
117   rust_dealloc_ffi_byte_array(array);
118   return ret;
119 }
120 
GetSessionUnique()121 std::string rust::D2DConnectionContextV1::GetSessionUnique() {
122   RustFFIByteArray array = get_session_unique(handle_);
123   std::string ret = std::string((const char*)array.handle, array.len);
124   rust_dealloc_ffi_byte_array(array);
125   return ret;
126 }
127 
GetSequenceNumberForEncoding()128 int rust::D2DConnectionContextV1::GetSequenceNumberForEncoding() {
129   return get_sequence_number_for_encoding(handle_);
130 }
131 
GetSequenceNumberForDecoding()132 int rust::D2DConnectionContextV1::GetSequenceNumberForDecoding() {
133   return get_sequence_number_for_decoding(handle_);
134 }
135 
SaveSession()136 std::string rust::D2DConnectionContextV1::SaveSession() {
137   RustFFIByteArray array = save_session(handle_);
138   std::string ret = std::string((const char*)array.handle, array.len);
139   rust_dealloc_ffi_byte_array(array);
140   return ret;
141 }
142 
143 rust::D2DRestoreConnectionContextV1Result
FromSavedSession(std::string data)144 rust::D2DConnectionContextV1::FromSavedSession(std::string data) {
145   CFFIByteArray arr{
146       .handle = (uint8_t*)data.c_str(),
147       .len = data.length(),
148   };
149   auto result = from_saved_session(arr);
150   return {
151       D2DConnectionContextV1(result.handle),
152       result.status,
153   };
154 }
155