1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Provides a sample ukey2 shell app which can be run from the command line
16 
17 #![allow(clippy::expect_used)]
18 //TODO: remove this and fix instances of unwrap
19 #![allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
20 
use std::convert::TryFrom;
use std::io::{Read, Write};
use std::process::exit;

use clap::Parser;

use crypto_provider_rustcrypto::RustCrypto;
use ukey2_connections::{
    D2DConnectionContextV1, D2DHandshakeContext, HandshakeImplementation,
    InitiatorD2DHandshakeContext, NextProtocol, ServerD2DHandshakeContext,
};
31 
// Accepted values for the `mode` argument; `main` exits with status 1 for any
// other value.
const MODE_INITIATOR: &str = "initiator";
const MODE_RESPONDER: &str = "responder";

// Command-line arguments. Only plain `//` comments are added here on purpose:
// clap's derive turns `///` doc comments on this struct and its fields into
// `--help` text, so the existing field docs are runtime strings.
#[derive(Parser, Debug)]
struct Ukey2Cli {
    /// initiator or responder mode
    #[arg(short, long)]
    mode: String,
    // Declared as `i32`, so clap accepts negative values; they have to be
    // rejected (or otherwise handled) wherever this is turned into a length.
    /// length of auth string/next proto secret
    #[arg(short, long, default_value_t = 32)]
    verification_string_length: i32,
}
44 
// Framing functions. Each frame is a 4-byte big-endian length followed by that many bytes.
46 /*
47 // Writes |message| to stdout in the frame format.
48 void WriteFrame(const string& message) {
  // Write length of |message| in big-endian (most significant byte first).
50   const uint32_t length = message.length();
51   fputc((length >> (3 * 8)) & 0xFF, stdout);
52   fputc((length >> (2 * 8)) & 0xFF, stdout);
53   fputc((length >> (1 * 8)) & 0xFF, stdout);
54   fputc((length >> (0 * 8)) & 0xFF, stdout);
55 
56   // Write message to stdout.
57   CHECK_EQ(message.length(),
58            fwrite(message.c_str(), 1, message.length(), stdout));
59   CHECK_EQ(0, fflush(stdout));
60 }
61 */
/// Writes `message` to stdout as one frame and flushes it.
///
/// A frame is the message length as a 4-byte big-endian `u32`, followed by
/// the message bytes. The explicit flush matters: stdout is buffered, so
/// without it the peer may never see the frame.
///
/// # Panics
///
/// Panics if `message` is longer than `u32::MAX` bytes (the length prefix
/// cannot represent it), or if writing to or flushing stdout fails.
fn write_frame(message: Vec<u8>) {
    // A plain `as u32` cast would silently truncate the length and desync the peer.
    let length = u32::try_from(message.len()).expect("frame too large for a u32 length prefix");
    // Lock once instead of re-acquiring the stdout handle for every write.
    let stdout = std::io::stdout();
    let mut out = stdout.lock();
    out.write_all(&length.to_be_bytes()).expect("failed to write frame length");
    out.write_all(&message).expect("failed to write message");
    // A failed flush means the frame was never delivered; don't ignore it.
    out.flush().expect("failed to flush stdout");
}
69 
70 /*
71 // Returns a message read from stdin after parsing it from the frame format.
72 string ReadFrame() {
73   // Read length of the frame from the stream.
74   uint8_t length_data[sizeof(uint32_t)];
75   CHECK_EQ(sizeof(uint32_t), fread(&length_data, 1, sizeof(uint32_t), stdin));
76 
77   uint32_t length = 0;
78   length |= static_cast<uint32_t>(length_data[0]) << (3 * 8);
79   length |= static_cast<uint32_t>(length_data[1]) << (2 * 8);
80   length |= static_cast<uint32_t>(length_data[2]) << (1 * 8);
81   length |= static_cast<uint32_t>(length_data[3]) << (0 * 8);
82 
83   // Read |length| bytes from the stream.
84   absl::FixedArray<char> buffer(length);
85   CHECK_EQ(length, fread(buffer.data(), 1, length, stdin));
86 
87   return string(buffer.data(), length);
88 }
89 
90  */
/// Size in bytes of a frame's length prefix (a `u32`).
const LENGTH: usize = std::mem::size_of::<u32>();

/// Reads one frame from stdin and returns its payload.
///
/// A frame is a 4-byte big-endian length followed by exactly that many
/// payload bytes.
///
/// # Panics
///
/// Panics if stdin ends or fails before a complete frame has been read.
fn read_frame() -> Vec<u8> {
    let stdin = std::io::stdin();
    let mut input = stdin.lock();

    // `read_exact`, not `read`: a single `read` may legally return fewer
    // bytes than requested even when more are on the way.
    let mut length_buf = [0u8; LENGTH];
    input.read_exact(&mut length_buf).expect("failed to read frame length");
    let length = u32::from_be_bytes(length_buf);

    let length = usize::try_from(length).expect("frame length does not fit in usize");
    let mut buffer = vec![0u8; length];
    input.read_exact(&mut buffer).expect("failed to read frame");
    buffer
}
101 
102 struct Ukey2Shell {
103     verification_string_length: usize,
104 }
105 
106 impl Ukey2Shell {
new(verification_string_length: i32) -> Self107     fn new(verification_string_length: i32) -> Self {
108         Self { verification_string_length: verification_string_length as usize }
109     }
110 
run_secure_connection_loop(connection_ctx: &mut D2DConnectionContextV1) -> bool111     fn run_secure_connection_loop(connection_ctx: &mut D2DConnectionContextV1) -> bool {
112         loop {
113             let input = read_frame();
114             let idx = input.iter().enumerate().find(|(_index, &byte)| byte == 0x20).unwrap().0;
115             let (cmd, payload) = (&input[0..idx], &input[idx + 1..]);
116             if cmd == b"encrypt" {
117                 let result =
118                     connection_ctx.encode_message_to_peer::<RustCrypto, &[u8]>(payload, None);
119                 write_frame(result);
120             } else if cmd == b"decrypt" {
121                 let result =
122                     connection_ctx.decode_message_from_peer::<RustCrypto, &[u8]>(payload, None);
123                 if result.is_err() {
124                     println!("failed to decode payload");
125                     return false;
126                 }
127                 write_frame(result.unwrap());
128             } else if cmd == b"session_unique" {
129                 let result = connection_ctx.get_session_unique::<RustCrypto>();
130                 write_frame(result);
131             } else {
132                 println!("unknown command");
133                 return false;
134             }
135         }
136     }
137 
run_as_initiator(&self) -> bool138     fn run_as_initiator(&self) -> bool {
139         let mut initiator_ctx = InitiatorD2DHandshakeContext::<RustCrypto, _>::new(
140             HandshakeImplementation::PublicKeyInProtobuf,
141             vec![NextProtocol::Aes256CbcHmacSha256, NextProtocol::Aes256GcmSiv],
142         );
143         write_frame(initiator_ctx.get_next_handshake_message().unwrap());
144         let server_init_msg = read_frame();
145         initiator_ctx
146             .handle_handshake_message(server_init_msg.as_slice())
147             .expect("Failed to handle message");
148         write_frame(initiator_ctx.get_next_handshake_message().unwrap_or_default());
149         // confirm auth str
150         let auth_str = initiator_ctx
151             .to_completed_handshake()
152             .ok()
153             .and_then(|h| h.auth_string::<RustCrypto>().derive_vec(self.verification_string_length))
154             .unwrap_or_else(|| vec![0; self.verification_string_length]);
155         write_frame(auth_str);
156         let ack = read_frame();
157         if ack != "ok".to_string().into_bytes() {
158             println!("handshake failed");
159             return false;
160         }
161         // upgrade to connection context
162         let mut initiator_conn_ctx = initiator_ctx.to_connection_context().unwrap();
163         Self::run_secure_connection_loop(&mut initiator_conn_ctx)
164     }
165 
run_as_responder(&self) -> bool166     fn run_as_responder(&self) -> bool {
167         let mut server_ctx = ServerD2DHandshakeContext::<RustCrypto, _>::new(
168             HandshakeImplementation::PublicKeyInProtobuf,
169             &[NextProtocol::Aes256GcmSiv, NextProtocol::Aes256CbcHmacSha256],
170         );
171         let initiator_init_msg = read_frame();
172         server_ctx.handle_handshake_message(initiator_init_msg.as_slice()).unwrap();
173         let server_next_msg = server_ctx.get_next_handshake_message().unwrap();
174         write_frame(server_next_msg);
175         let initiator_finish_msg = read_frame();
176         server_ctx
177             .handle_handshake_message(initiator_finish_msg.as_slice())
178             .expect("Failed to handle message");
179         // confirm auth str
180         let auth_str = server_ctx
181             .to_completed_handshake()
182             .ok()
183             .and_then(|h| h.auth_string::<RustCrypto>().derive_vec(self.verification_string_length))
184             .unwrap_or_else(|| vec![0; self.verification_string_length]);
185         write_frame(auth_str);
186         let ack = read_frame();
187         if ack != "ok".to_string().into_bytes() {
188             println!("handshake failed");
189             return false;
190         }
191         // upgrade to connection context
192         let mut server_conn_ctx = server_ctx.to_connection_context().unwrap();
193         Self::run_secure_connection_loop(&mut server_conn_ctx)
194     }
195 }
196 
main()197 fn main() {
198     let args = Ukey2Cli::parse();
199     let shell = Ukey2Shell::new(args.verification_string_length);
200     if args.mode == MODE_INITIATOR {
201         let _ = shell.run_as_initiator();
202     } else if args.mode == MODE_RESPONDER {
203         let _ = shell.run_as_responder();
204     } else {
205         exit(1);
206     }
207     exit(0)
208 }
209