1 //! This module handles "arbitration" of ATT packets, to determine whether they
2 //! should be handled by the primary stack or by the Rust stack
3 
4 use pdl_runtime::Packet;
5 use std::sync::{Arc, Mutex};
6 
7 use log::{error, trace, warn};
8 use std::sync::RwLock;
9 
10 use crate::{do_in_rust_thread, packets::att};
11 
12 use super::{
13     ffi::{InterceptAction, StoreCallbacksFromRust},
14     ids::{AdvertiserId, TransportIndex},
15     mtu::MtuEvent,
16     opcode_types::{classify_opcode, OperationType},
17     server::isolation_manager::IsolationManager,
18 };
19 
/// Process-wide slot for the shared [`IsolationManager`].
///
/// Holds `None` until `initialize_arbiter` stores a manager, and holds `None`
/// again once `clean_arbiter` has cleared it.
static ARBITER: RwLock<Option<Arc<Mutex<IsolationManager>>>> = RwLock::new(None);
21 
22 /// Initialize the Arbiter
initialize_arbiter() -> Arc<Mutex<IsolationManager>>23 pub fn initialize_arbiter() -> Arc<Mutex<IsolationManager>> {
24     let arbiter = Arc::new(Mutex::new(IsolationManager::new()));
25     let mut lock = ARBITER.write().unwrap();
26     assert!(lock.is_none(), "Rust stack should only start up once");
27     *lock = Some(arbiter.clone());
28 
29     StoreCallbacksFromRust(
30         on_le_connect,
31         on_le_disconnect,
32         intercept_packet,
33         |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest),
34         |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)),
35         |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)),
36     );
37 
38     arbiter
39 }
40 
41 /// Clean the Arbiter
clean_arbiter()42 pub fn clean_arbiter() {
43     let mut lock = ARBITER.write().unwrap();
44     *lock = None
45 }
46 
47 /// Acquire the mutex holding the Arbiter and provide a mutable reference to the
48 /// supplied closure
with_arbiter<T>(f: impl FnOnce(&mut IsolationManager) -> T) -> T49 pub fn with_arbiter<T>(f: impl FnOnce(&mut IsolationManager) -> T) -> T {
50     f(ARBITER.read().unwrap().as_ref().expect("Rust stack is not started").lock().as_mut().unwrap())
51 }
52 
53 /// Check if the Arbiter is initialized.
has_arbiter() -> bool54 pub fn has_arbiter() -> bool {
55     ARBITER.read().unwrap().is_some()
56 }
57 
58 /// Test to see if a buffer contains a valid ATT packet with an opcode we
59 /// are interested in intercepting (those intended for servers that are isolated)
try_parse_att_server_packet( isolation_manager: &IsolationManager, tcb_idx: TransportIndex, packet: &[u8], ) -> Option<att::Att>60 fn try_parse_att_server_packet(
61     isolation_manager: &IsolationManager,
62     tcb_idx: TransportIndex,
63     packet: &[u8],
64 ) -> Option<att::Att> {
65     isolation_manager.get_server_id(tcb_idx)?;
66 
67     let att = att::Att::decode_full(packet).ok()?;
68 
69     if att.opcode == att::AttOpcode::ExchangeMtuRequest {
70         // special case: this server opcode is handled by legacy stack, and we snoop
71         // on its handling, since the MTU is shared between the client + server
72         return None;
73     }
74 
75     match classify_opcode(att.opcode) {
76         OperationType::Command | OperationType::Request | OperationType::Confirmation => Some(att),
77         _ => None,
78     }
79 }
80 
on_le_connect(tcb_idx: u8, advertiser: u8)81 fn on_le_connect(tcb_idx: u8, advertiser: u8) {
82     let tcb_idx = TransportIndex(tcb_idx);
83     let advertiser = AdvertiserId(advertiser);
84     let is_isolated = with_arbiter(|arbiter| arbiter.is_advertiser_isolated(advertiser));
85     if is_isolated {
86         do_in_rust_thread(move |modules| {
87             if let Err(err) = modules.gatt_module.on_le_connect(tcb_idx, Some(advertiser)) {
88                 error!("{err:?}")
89             }
90         })
91     }
92 }
93 
on_le_disconnect(tcb_idx: u8)94 fn on_le_disconnect(tcb_idx: u8) {
95     // Events may be received after a FactoryReset
96     // is initiated for Bluetooth and the rust arbiter is taken
97     // down.
98     if !has_arbiter() {
99         warn!("arbiter is not yet initialized");
100         return;
101     }
102 
103     let tcb_idx = TransportIndex(tcb_idx);
104     let was_isolated = with_arbiter(|arbiter| arbiter.is_connection_isolated(tcb_idx));
105     if was_isolated {
106         do_in_rust_thread(move |modules| {
107             if let Err(err) = modules.gatt_module.on_le_disconnect(tcb_idx) {
108                 error!("{err:?}")
109             }
110         })
111     }
112 }
113 
intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction114 fn intercept_packet(tcb_idx: u8, packet: Vec<u8>) -> InterceptAction {
115     // Events may be received after a FactoryReset
116     // is initiated for Bluetooth and the rust arbiter is taken
117     // down.
118     if !has_arbiter() {
119         warn!("arbiter is not yet initialized");
120         return InterceptAction::Drop;
121     }
122 
123     let tcb_idx = TransportIndex(tcb_idx);
124     if let Some(att) =
125         with_arbiter(|arbiter| try_parse_att_server_packet(arbiter, tcb_idx, &packet))
126     {
127         do_in_rust_thread(move |modules| {
128             trace!("pushing packet to GATT");
129             if let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) {
130                 bearer.handle_packet(att)
131             } else {
132                 error!("Bearer for {tcb_idx:?} not found");
133             }
134         });
135         InterceptAction::Drop
136     } else {
137         InterceptAction::Forward
138     }
139 }
140 
on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent)141 fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) {
142     if with_arbiter(|arbiter| arbiter.is_connection_isolated(tcb_idx)) {
143         do_in_rust_thread(move |modules| {
144             let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) else {
145                 error!("Bearer for {tcb_idx:?} not found");
146                 return;
147             };
148             if let Err(err) = bearer.handle_mtu_event(event) {
149                 error!("{err:?}")
150             }
151         });
152     }
153 }
154 
#[cfg(test)]
mod test {
    use super::*;

    use crate::{
        gatt::ids::{AttHandle, ServerId},
        packets::att,
    };

    const TCB_IDX: TransportIndex = TransportIndex(1);
    const ADVERTISER_ID: AdvertiserId = AdvertiserId(3);
    const SERVER_ID: ServerId = ServerId(4);

    /// Build a manager in which `server_id` is associated with
    /// `ADVERTISER_ID` and `tcb_idx` has connected through that advertiser.
    fn create_manager_with_isolated_connection(
        tcb_idx: TransportIndex,
        server_id: ServerId,
    ) -> IsolationManager {
        let mut manager = IsolationManager::new();
        manager.associate_server_with_advertiser(server_id, ADVERTISER_ID);
        manager.on_le_connect(tcb_idx, Some(ADVERTISER_ID));
        manager
    }

    /// A read request on an isolated connection must be captured.
    #[test]
    fn test_packet_capture_when_isolated() {
        let manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID);
        let request = att::AttReadRequest { attribute_handle: AttHandle(1).into() };
        let bytes = request.encode_to_vec().unwrap();

        let parsed = try_parse_att_server_packet(&manager, TCB_IDX, &bytes);

        assert!(parsed.is_some());
    }

    /// An error response on an isolated connection must not be captured.
    #[test]
    fn test_packet_bypass_when_isolated() {
        let manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID);
        let response = att::AttErrorResponse {
            opcode_in_error: att::AttOpcode::ReadResponse,
            handle_in_error: AttHandle(1).into(),
            error_code: att::AttErrorCode::InvalidHandle,
        };
        let bytes = response.encode_to_vec().unwrap();

        let parsed = try_parse_att_server_packet(&manager, TCB_IDX, &bytes);

        assert!(parsed.is_none());
    }

    /// An MTU exchange request must not be captured, even when isolated.
    #[test]
    fn test_mtu_bypass() {
        let manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID);
        let request = att::AttExchangeMtuRequest { mtu: 64 };
        let bytes = request.encode_to_vec().unwrap();

        let parsed = try_parse_att_server_packet(&manager, TCB_IDX, &bytes);

        assert!(parsed.is_none());
    }

    /// A read request on a connection the manager knows nothing about must
    /// not be captured.
    #[test]
    fn test_packet_bypass_when_not_isolated() {
        let manager = IsolationManager::new();
        let request = att::AttReadRequest { attribute_handle: AttHandle(1).into() };
        let bytes = request.encode_to_vec().unwrap();

        let parsed = try_parse_att_server_packet(&manager, TCB_IDX, &bytes);

        assert!(parsed.is_none());
    }
}
238