1 /*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 
17 use alloc::vec::Vec;
18 use log::{error, info};
19 use tipc::{
20     ConnectResult, Deserialize, Handle, Manager, MessageResult, PortCfg, Serialize, Serializer,
21     Service, TipcError, Uuid,
22 };
23 
/// Name of the service port that clients connect to.
const HELLO_TRUSTY_PORT_NAME: &str = "com.android.trusty.hello";
/// Maximum size, in bytes, of one message on the port (header included).
const HELLO_TRUSTY_MAX_MSG_SIZE: usize = 100;
/// Number of header bytes (1-byte tag + 1-byte length) preceding the payload.
const MESSAGE_OFFSET: usize = 2;
/// Wire value of the tag marking a string payload.
const TAG_STRING: u8 = 1;
/// Wire value of the tag marking an error.
const TAG_ERROR: u8 = 2;
30 
/// The stateless Hello World service; it implements the `Service` trait so the
/// `Manager` can dispatch connections and messages to it.
struct HelloWorldService;
33 
/// Identifies what a message carries: either a string, or an error.
///
/// When serialized, `TagString` is written as `TAG_STRING` and `TagError` as
/// `TAG_ERROR`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Tag {
    /// The payload is a string.
    TagString,
    /// The message signals an error.
    TagError,
}
39 
40 // We need to define a struct that implements Deserialize and Serialize traits
41 // Messages sent have a 1 byte tag, 1 byte length, and the actual message itself.
42 struct HelloWorldMessage {
43     tag: Tag,
44     length: u8,
45     message: Vec<u8>,
46 }
47 
48 // Providing a convenient way to translate u8 to a Tag
49 // We use this when deserializing
50 impl From<u8> for Tag {
from(item: u8) -> Self51     fn from(item: u8) -> Self {
52         match item {
53             TAG_STRING => Tag::TagString,
54             _ => Tag::TagError,
55         }
56     }
57 }
58 
59 // Providing a Deserialize implementation is necessary for the
60 // the manager to automatically deserialize messages in on_message()
61 impl Deserialize for HelloWorldMessage {
62     type Error = TipcError;
63     const MAX_SERIALIZED_SIZE: usize = HELLO_TRUSTY_MAX_MSG_SIZE;
64     // The deserialization creates a HelloWorldMessage to be sent back to the service to handle.
deserialize(bytes: &[u8], _handles: &mut [Option<Handle>]) -> Result<Self, TipcError>65     fn deserialize(bytes: &[u8], _handles: &mut [Option<Handle>]) -> Result<Self, TipcError> {
66         if bytes.len() < 2 {
67             log::error!("The message is too short!");
68             return Err(TipcError::InvalidData);
69         } else if bytes.len() - MESSAGE_OFFSET != bytes[1].into() {
70             log::error!("The serialized length does not match the actual length!");
71             return Err(TipcError::InvalidData);
72         }
73         // The first 2 bytes are tag and length, so extracting the message requires an offset.
74         let deserializedmessage = HelloWorldMessage {
75             tag: Tag::from(bytes[0]),
76             length: bytes[1],
77             message: bytes[MESSAGE_OFFSET..].to_vec(),
78         };
79         Ok(deserializedmessage)
80     }
81 }
82 
83 // Providing a serialize implementation is necessary for the
84 // the handle to send messages in on_message()
85 
86 impl<'s> Serialize<'s> for HelloWorldMessage {
87     // The serialize converts a HelloWorldMessage into a slice for the service to send.
88     // The first two bytes represent the tag and length, which we serialize first. The
89     // message is serialized with a different method.
serialize<'a: 's, S: Serializer<'s>>( &'a self, serializer: &mut S, ) -> Result<S::Ok, S::Error>90     fn serialize<'a: 's, S: Serializer<'s>>(
91         &'a self,
92         serializer: &mut S,
93     ) -> Result<S::Ok, S::Error> {
94         unsafe {
95             serializer.serialize_as_bytes(match self.tag {
96                 Tag::TagString => &TAG_STRING,
97                 Tag::TagError => &TAG_ERROR,
98             })?;
99             serializer.serialize_as_bytes(&self.length)?;
100         }
101         serializer.serialize_bytes(&self.message.as_slice())
102     }
103 }
104 
105 // An implementation of the Service trait for a struct, included in the instantiation
106 // of the Manager. The implementation of the Service essentially tells the Manager
107 // how to handle incoming connections and messages.
108 impl Service for HelloWorldService {
109     // Associates the Connection with a specific struct
110     type Connection = ();
111     // Associates the Message with a specific struct
112     type Message = HelloWorldMessage;
113 
114     // This method is called whenever a client connects.
115     // It should return Ok(Some(Connection)) if the connection is to be accepted.
on_connect( &self, _port: &PortCfg, _handle: &Handle, _peer: &Uuid, ) -> tipc::Result<ConnectResult<Self::Connection>>116     fn on_connect(
117         &self,
118         _port: &PortCfg,
119         _handle: &Handle,
120         _peer: &Uuid,
121     ) -> tipc::Result<ConnectResult<Self::Connection>> {
122         info!("Connection to the Rust service!");
123         Ok(ConnectResult::Accept(()))
124     }
125     // This method is called when the service receives a message.
126     // The manager handles the deserialization into msg, which is passed to this callback.
on_message( &self, _connection: &Self::Connection, handle: &Handle, msg: Self::Message, ) -> tipc::Result<MessageResult>127     fn on_message(
128         &self,
129         _connection: &Self::Connection,
130         handle: &Handle,
131         msg: Self::Message,
132     ) -> tipc::Result<MessageResult> {
133         // msg holds our deserialized HelloWorldMessage struct. Here, we want to get
134         // actual message out to create a response.
135         let inputmsg = std::str::from_utf8(&msg.message).map_err(|e| {
136             error!("Failed to convert message to valid UTF-8: {:?}", e);
137             TipcError::InvalidData
138         })?;
139         // Creating the response string based on the input string.
140         let outputmsg = format!("Hello, {}!", inputmsg);
141 
142         // We send the message via handle, the client connection to the service.
143         // Handle contains the method send, which will automatically serialize given
144         // a serialization implementation, and send the message.
145         handle.send(&HelloWorldMessage {
146             tag: Tag::TagString,
147             length: outputmsg.len().try_into().map_err(|e| {
148                 error!("Length of message is too long!: {:?}", e);
149                 TipcError::InvalidData
150             })?,
151             message: outputmsg.as_bytes().to_vec(),
152         })?;
153         // We keep the connection open by returning Ok(MaintainConnection).
154         // Returning an Ok(CloseConnection) or an Err(_) will close the connection.
155         Ok(MessageResult::MaintainConnection)
156     }
157 }
158 
159 // Essentially the main function, it sets up the port and manager.
160 // It immediately starts the service event loop afterwards.
init_and_start_loop() -> Result<(), TipcError>161 pub fn init_and_start_loop() -> Result<(), TipcError> {
162     // Allows the use of logging macros such as info!, debug!, error!
163     trusty_log::init();
164     info!("Hello from the Rust Hello World TA");
165 
166     // Instantiates a new port configuration. It describes a service port path, among other
167     // options. Here, we're allowing other secure (Trusty) clients to connect, as well as
168     // setting the maximumm message length for this port.
169     let cfg = PortCfg::new(HELLO_TRUSTY_PORT_NAME)
170         .map_err(|e| {
171             error!("Could not create port config: {:?}", e);
172             TipcError::UnknownError
173         })?
174         .allow_ta_connect()
175         .msg_max_size(HELLO_TRUSTY_MAX_MSG_SIZE as u32);
176 
177     // Instantiates our service. Services handle IPC messages for specific ports.
178     let service = HelloWorldService {};
179     // Incoming bytes from the manager will be received here.
180     let buffer = [0u8; HELLO_TRUSTY_MAX_MSG_SIZE];
181 
182     // The manager handles the IPC event loop. Given our port configuration, buffer, and
183     // service, it forwards incoming connections & messages to the service.
184     //
185     // <_, _, 1, 1> means we define a Manager with a single port, and max one connection.
186     Manager::<_, _, 1, 1>::new(service, cfg, buffer)
187         .map_err(|e| {
188             error!("Could not create Rust Hello World Service manager: {:?}", e);
189             TipcError::UnknownError
190         })?
191         // The event loop waits for connections/messages,
192         // then dispatches them to the service to be handled
193         .run_event_loop()
194 }
195 
// The test suite consists of a simple connection test and one that mirrors the
// original Hello World test-app TA by exercising the actual TA.
#[cfg(test)]
mod tests {
    use super::*;
    use test::{expect, expect_eq};
    use trusty_std::ffi::{CString, FallibleCString};
    // This line is needed in order to run the unit tests in Trusty.
    test::init!();

    // Checks that a client can open a connection to the TA port.
    #[test]
    fn connection_test() {
        let port = CString::try_new(HELLO_TRUSTY_PORT_NAME).unwrap();
        let _session = Handle::connect(port.as_c_str()).unwrap();
    }

    // Like the original test-app TA: sends "World" and expects "Hello, World!" back.
    #[test]
    fn hello_world_test() {
        // Set up the connection to the TA.
        let port = CString::try_new(HELLO_TRUSTY_PORT_NAME).unwrap();
        let session = Handle::connect(port.as_c_str()).unwrap();

        // Build the request and send it to the TA.
        let inputstring = "World";
        let test_message = HelloWorldMessage {
            tag: Tag::TagString,
            length: inputstring.len().try_into().unwrap(),
            message: inputstring.as_bytes().to_vec(),
        };
        session.send(&test_message).unwrap();

        // Receive the reply. `recv` deserializes it with
        // `HelloWorldMessage::deserialize`, which validates the header, so the
        // decoded fields are checked instead of re-parsing the raw buffer.
        let buf = &mut [0; HELLO_TRUSTY_MAX_MSG_SIZE];
        let response: HelloWorldMessage = match session.recv(buf) {
            Ok(response) => response,
            Err(_) => {
                expect!(false, "The response should be received and deserialized.");
                return;
            }
        };
        expect!(
            matches!(response.tag, Tag::TagString),
            "The response should be tagged as a string."
        );
        let deserializedstr =
            std::str::from_utf8(&response.message).expect("Not a valid UTF-8 string!");
        expect_eq!(
            deserializedstr,
            "Hello, World!",
            "Testing that {} would return Hello, {}!",
            inputstring,
            inputstring
        );
    }
}
243