1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
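//! A sample UKEY2 shell app that can be run from the command line.
//!
//! It speaks a length-prefixed frame protocol over stdin/stdout: it runs one side of a UKEY2
//! handshake (`--mode initiator` or `--mode responder`) and then serves `encrypt`, `decrypt`
//! and `session_unique` commands on the resulting secure connection.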
16
17 #![allow(clippy::expect_used)]
18 //TODO: remove this and fix instances of unwrap
19 #![allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
20
21 use std::io::{Read, Write};
22 use std::process::exit;
23
24 use clap::Parser;
25
26 use crypto_provider_rustcrypto::RustCrypto;
27 use ukey2_connections::{
28 D2DConnectionContextV1, D2DHandshakeContext, HandshakeImplementation,
29 InitiatorD2DHandshakeContext, NextProtocol, ServerD2DHandshakeContext,
30 };
31
32 const MODE_INITIATOR: &str = "initiator";
33 const MODE_RESPONDER: &str = "responder";
34
35 #[derive(Parser, Debug)]
36 struct Ukey2Cli {
37 /// initiator or responder mode
38 #[arg(short, long)]
39 mode: String,
40 /// length of auth string/next proto secret
41 #[arg(short, long, default_value_t = 32)]
42 verification_string_length: i32,
43 }
44
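// Framing functions: every message on stdin/stdout is a 4-byte big-endian length prefix
// followed by that many payload bytes. The commented-out C++ below appears to be the original
// implementation these functions were ported from.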
46 /*
47 // Writes |message| to stdout in the frame format.
48 void WriteFrame(const string& message) {
  // Write length of |message| in big-endian (most significant byte first).
50 const uint32_t length = message.length();
51 fputc((length >> (3 * 8)) & 0xFF, stdout);
52 fputc((length >> (2 * 8)) & 0xFF, stdout);
53 fputc((length >> (1 * 8)) & 0xFF, stdout);
54 fputc((length >> (0 * 8)) & 0xFF, stdout);
55
56 // Write message to stdout.
57 CHECK_EQ(message.length(),
58 fwrite(message.c_str(), 1, message.length(), stdout));
59 CHECK_EQ(0, fflush(stdout));
60 }
61 */
/// Writes `message` to `writer` as one frame: a 4-byte big-endian length prefix followed by
/// the message bytes, then flushes `writer` so the complete frame is handed on.
///
/// # Errors
///
/// Returns an error of kind [`std::io::ErrorKind::InvalidInput`] if `message` is longer than
/// `u32::MAX` bytes (its length cannot be encoded in the prefix), or any error reported by
/// `writer`.
fn write_frame_to<W: Write>(writer: &mut W, message: &[u8]) -> std::io::Result<()> {
    if message.len() as u64 > u64::from(u32::MAX) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "message is too long to frame: its length must fit in a u32",
        ));
    }
    // Lossless: the check above guarantees the length fits in a `u32`.
    let length = message.len() as u32;
    writer.write_all(&length.to_be_bytes())?;
    writer.write_all(message)?;
    writer.flush()
}

/// Writes `message` to stdout as one frame (see [`write_frame_to`]).
///
/// # Panics
///
/// Panics if `message` is longer than `u32::MAX` bytes or if stdout cannot be written or
/// flushed. A dropped or unflushed frame could leave the peer blocked waiting for it, so
/// failures are fatal, matching the `CHECK_EQ` checks in the commented-out C++ `WriteFrame`.
fn write_frame(message: Vec<u8>) {
    // Lock once so the prefix and payload go out through a single handle.
    let stdout = std::io::stdout();
    let mut handle = stdout.lock();
    write_frame_to(&mut handle, &message).expect("failed to write frame to stdout");
}
69
70 /*
71 // Returns a message read from stdin after parsing it from the frame format.
72 string ReadFrame() {
73 // Read length of the frame from the stream.
74 uint8_t length_data[sizeof(uint32_t)];
75 CHECK_EQ(sizeof(uint32_t), fread(&length_data, 1, sizeof(uint32_t), stdin));
76
77 uint32_t length = 0;
78 length |= static_cast<uint32_t>(length_data[0]) << (3 * 8);
79 length |= static_cast<uint32_t>(length_data[1]) << (2 * 8);
80 length |= static_cast<uint32_t>(length_data[2]) << (1 * 8);
81 length |= static_cast<uint32_t>(length_data[3]) << (0 * 8);
82
83 // Read |length| bytes from the stream.
84 absl::FixedArray<char> buffer(length);
85 CHECK_EQ(length, fread(buffer.data(), 1, length, stdin));
86
87 return string(buffer.data(), length);
88 }
89
90 */
/// Size in bytes of the big-endian `u32` length prefix that precedes every frame.
const LENGTH: usize = std::mem::size_of::<u32>();

/// Reads one frame from `reader` and returns its payload. A frame is a 4-byte big-endian
/// length prefix followed by that many payload bytes.
///
/// # Errors
///
/// Returns an error of kind [`std::io::ErrorKind::UnexpectedEof`] if the stream ends before
/// the full prefix or payload has arrived, or any other error reported by `reader`.
fn read_frame_from<R: Read>(reader: &mut R) -> std::io::Result<Vec<u8>> {
    let mut length_buf = [0u8; LENGTH];
    // `read_exact`, not a single `read`: one `read` call may legally return fewer bytes than
    // requested (e.g. when a pipe delivers the prefix in pieces).
    reader.read_exact(&mut length_buf)?;
    let length = u32::from_be_bytes(length_buf);
    // `u32 -> usize` is lossless on 32- and 64-bit targets.
    let mut payload = vec![0u8; length as usize];
    reader.read_exact(&mut payload)?;
    Ok(payload)
}

/// Reads one frame from stdin and returns its payload (see [`read_frame_from`]).
///
/// # Panics
///
/// Panics if stdin cannot be read or ends before a complete frame has arrived, for example
/// when the peer closes the pipe.
fn read_frame() -> Vec<u8> {
    read_frame_from(&mut std::io::stdin().lock()).expect("failed to read frame")
}
101
102 struct Ukey2Shell {
103 verification_string_length: usize,
104 }
105
106 impl Ukey2Shell {
new(verification_string_length: i32) -> Self107 fn new(verification_string_length: i32) -> Self {
108 Self { verification_string_length: verification_string_length as usize }
109 }
110
run_secure_connection_loop(connection_ctx: &mut D2DConnectionContextV1) -> bool111 fn run_secure_connection_loop(connection_ctx: &mut D2DConnectionContextV1) -> bool {
112 loop {
113 let input = read_frame();
114 let idx = input.iter().enumerate().find(|(_index, &byte)| byte == 0x20).unwrap().0;
115 let (cmd, payload) = (&input[0..idx], &input[idx + 1..]);
116 if cmd == b"encrypt" {
117 let result =
118 connection_ctx.encode_message_to_peer::<RustCrypto, &[u8]>(payload, None);
119 write_frame(result);
120 } else if cmd == b"decrypt" {
121 let result =
122 connection_ctx.decode_message_from_peer::<RustCrypto, &[u8]>(payload, None);
123 if result.is_err() {
124 println!("failed to decode payload");
125 return false;
126 }
127 write_frame(result.unwrap());
128 } else if cmd == b"session_unique" {
129 let result = connection_ctx.get_session_unique::<RustCrypto>();
130 write_frame(result);
131 } else {
132 println!("unknown command");
133 return false;
134 }
135 }
136 }
137
run_as_initiator(&self) -> bool138 fn run_as_initiator(&self) -> bool {
139 let mut initiator_ctx = InitiatorD2DHandshakeContext::<RustCrypto, _>::new(
140 HandshakeImplementation::PublicKeyInProtobuf,
141 vec![NextProtocol::Aes256CbcHmacSha256, NextProtocol::Aes256GcmSiv],
142 );
143 write_frame(initiator_ctx.get_next_handshake_message().unwrap());
144 let server_init_msg = read_frame();
145 initiator_ctx
146 .handle_handshake_message(server_init_msg.as_slice())
147 .expect("Failed to handle message");
148 write_frame(initiator_ctx.get_next_handshake_message().unwrap_or_default());
149 // confirm auth str
150 let auth_str = initiator_ctx
151 .to_completed_handshake()
152 .ok()
153 .and_then(|h| h.auth_string::<RustCrypto>().derive_vec(self.verification_string_length))
154 .unwrap_or_else(|| vec![0; self.verification_string_length]);
155 write_frame(auth_str);
156 let ack = read_frame();
157 if ack != "ok".to_string().into_bytes() {
158 println!("handshake failed");
159 return false;
160 }
161 // upgrade to connection context
162 let mut initiator_conn_ctx = initiator_ctx.to_connection_context().unwrap();
163 Self::run_secure_connection_loop(&mut initiator_conn_ctx)
164 }
165
run_as_responder(&self) -> bool166 fn run_as_responder(&self) -> bool {
167 let mut server_ctx = ServerD2DHandshakeContext::<RustCrypto, _>::new(
168 HandshakeImplementation::PublicKeyInProtobuf,
169 &[NextProtocol::Aes256GcmSiv, NextProtocol::Aes256CbcHmacSha256],
170 );
171 let initiator_init_msg = read_frame();
172 server_ctx.handle_handshake_message(initiator_init_msg.as_slice()).unwrap();
173 let server_next_msg = server_ctx.get_next_handshake_message().unwrap();
174 write_frame(server_next_msg);
175 let initiator_finish_msg = read_frame();
176 server_ctx
177 .handle_handshake_message(initiator_finish_msg.as_slice())
178 .expect("Failed to handle message");
179 // confirm auth str
180 let auth_str = server_ctx
181 .to_completed_handshake()
182 .ok()
183 .and_then(|h| h.auth_string::<RustCrypto>().derive_vec(self.verification_string_length))
184 .unwrap_or_else(|| vec![0; self.verification_string_length]);
185 write_frame(auth_str);
186 let ack = read_frame();
187 if ack != "ok".to_string().into_bytes() {
188 println!("handshake failed");
189 return false;
190 }
191 // upgrade to connection context
192 let mut server_conn_ctx = server_ctx.to_connection_context().unwrap();
193 Self::run_secure_connection_loop(&mut server_conn_ctx)
194 }
195 }
196
main()197 fn main() {
198 let args = Ukey2Cli::parse();
199 let shell = Ukey2Shell::new(args.verification_string_length);
200 if args.mode == MODE_INITIATOR {
201 let _ = shell.run_as_initiator();
202 } else if args.mode == MODE_RESPONDER {
203 let _ = shell.run_as_responder();
204 } else {
205 exit(1);
206 }
207 exit(0)
208 }
209