//! This module handles "arbitration" of ATT packets, to determine whether they //! should be handled by the primary stack or by the Rust stack use pdl_runtime::Packet; use std::sync::{Arc, Mutex}; use log::{error, trace, warn}; use std::sync::RwLock; use crate::{do_in_rust_thread, packets::att}; use super::{ ffi::{InterceptAction, StoreCallbacksFromRust}, ids::{AdvertiserId, TransportIndex}, mtu::MtuEvent, opcode_types::{classify_opcode, OperationType}, server::isolation_manager::IsolationManager, }; static ARBITER: RwLock>>> = RwLock::new(None); /// Initialize the Arbiter pub fn initialize_arbiter() -> Arc> { let arbiter = Arc::new(Mutex::new(IsolationManager::new())); let mut lock = ARBITER.write().unwrap(); assert!(lock.is_none(), "Rust stack should only start up once"); *lock = Some(arbiter.clone()); StoreCallbacksFromRust( on_le_connect, on_le_disconnect, intercept_packet, |tcb_idx| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::OutgoingRequest), |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingResponse(mtu)), |tcb_idx, mtu| on_mtu_event(TransportIndex(tcb_idx), MtuEvent::IncomingRequest(mtu)), ); arbiter } /// Clean the Arbiter pub fn clean_arbiter() { let mut lock = ARBITER.write().unwrap(); *lock = None } /// Acquire the mutex holding the Arbiter and provide a mutable reference to the /// supplied closure pub fn with_arbiter(f: impl FnOnce(&mut IsolationManager) -> T) -> T { f(ARBITER.read().unwrap().as_ref().expect("Rust stack is not started").lock().as_mut().unwrap()) } /// Check if the Arbiter is initialized. pub fn has_arbiter() -> bool { ARBITER.read().unwrap().is_some() } /// Test to see if a buffer contains a valid ATT packet with an opcode we /// are interested in intercepting (those intended for servers that are isolated) fn try_parse_att_server_packet( isolation_manager: &IsolationManager, tcb_idx: TransportIndex, packet: &[u8], ) -> Option { isolation_manager.get_server_id(tcb_idx)?; let att = att::Att::decode_full(packet).ok()?; if att.opcode == att::AttOpcode::ExchangeMtuRequest { // special case: this server opcode is handled by legacy stack, and we snoop // on its handling, since the MTU is shared between the client + server return None; } match classify_opcode(att.opcode) { OperationType::Command | OperationType::Request | OperationType::Confirmation => Some(att), _ => None, } } fn on_le_connect(tcb_idx: u8, advertiser: u8) { let tcb_idx = TransportIndex(tcb_idx); let advertiser = AdvertiserId(advertiser); let is_isolated = with_arbiter(|arbiter| arbiter.is_advertiser_isolated(advertiser)); if is_isolated { do_in_rust_thread(move |modules| { if let Err(err) = modules.gatt_module.on_le_connect(tcb_idx, Some(advertiser)) { error!("{err:?}") } }) } } fn on_le_disconnect(tcb_idx: u8) { // Events may be received after a FactoryReset // is initiated for Bluetooth and the rust arbiter is taken // down. if !has_arbiter() { warn!("arbiter is not yet initialized"); return; } let tcb_idx = TransportIndex(tcb_idx); let was_isolated = with_arbiter(|arbiter| arbiter.is_connection_isolated(tcb_idx)); if was_isolated { do_in_rust_thread(move |modules| { if let Err(err) = modules.gatt_module.on_le_disconnect(tcb_idx) { error!("{err:?}") } }) } } fn intercept_packet(tcb_idx: u8, packet: Vec) -> InterceptAction { // Events may be received after a FactoryReset // is initiated for Bluetooth and the rust arbiter is taken // down. if !has_arbiter() { warn!("arbiter is not yet initialized"); return InterceptAction::Drop; } let tcb_idx = TransportIndex(tcb_idx); if let Some(att) = with_arbiter(|arbiter| try_parse_att_server_packet(arbiter, tcb_idx, &packet)) { do_in_rust_thread(move |modules| { trace!("pushing packet to GATT"); if let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) { bearer.handle_packet(att) } else { error!("Bearer for {tcb_idx:?} not found"); } }); InterceptAction::Drop } else { InterceptAction::Forward } } fn on_mtu_event(tcb_idx: TransportIndex, event: MtuEvent) { if with_arbiter(|arbiter| arbiter.is_connection_isolated(tcb_idx)) { do_in_rust_thread(move |modules| { let Some(bearer) = modules.gatt_module.get_bearer(tcb_idx) else { error!("Bearer for {tcb_idx:?} not found"); return; }; if let Err(err) = bearer.handle_mtu_event(event) { error!("{err:?}") } }); } } #[cfg(test)] mod test { use super::*; use crate::{ gatt::ids::{AttHandle, ServerId}, packets::att, }; const TCB_IDX: TransportIndex = TransportIndex(1); const ADVERTISER_ID: AdvertiserId = AdvertiserId(3); const SERVER_ID: ServerId = ServerId(4); fn create_manager_with_isolated_connection( tcb_idx: TransportIndex, server_id: ServerId, ) -> IsolationManager { let mut isolation_manager = IsolationManager::new(); isolation_manager.associate_server_with_advertiser(server_id, ADVERTISER_ID); isolation_manager.on_le_connect(tcb_idx, Some(ADVERTISER_ID)); isolation_manager } #[test] fn test_packet_capture_when_isolated() { let isolation_manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID); let packet = att::AttReadRequest { attribute_handle: AttHandle(1).into() }; let out = try_parse_att_server_packet( &isolation_manager, TCB_IDX, &packet.encode_to_vec().unwrap(), ); assert!(out.is_some()); } #[test] fn test_packet_bypass_when_isolated() { let isolation_manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID); let packet = att::AttErrorResponse { opcode_in_error: att::AttOpcode::ReadResponse, handle_in_error: AttHandle(1).into(), error_code: att::AttErrorCode::InvalidHandle, }; let out = try_parse_att_server_packet( &isolation_manager, TCB_IDX, &packet.encode_to_vec().unwrap(), ); assert!(out.is_none()); } #[test] fn test_mtu_bypass() { let isolation_manager = create_manager_with_isolated_connection(TCB_IDX, SERVER_ID); let packet = att::AttExchangeMtuRequest { mtu: 64 }; let out = try_parse_att_server_packet( &isolation_manager, TCB_IDX, &packet.encode_to_vec().unwrap(), ); assert!(out.is_none()); } #[test] fn test_packet_bypass_when_not_isolated() { let isolation_manager = IsolationManager::new(); let packet = att::AttReadRequest { attribute_handle: AttHandle(1).into() }; let out = try_parse_att_server_packet( &isolation_manager, TCB_IDX, &packet.encode_to_vec().unwrap(), ); assert!(out.is_none()); } }