1# Copyright 2021-2022 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# ----------------------------------------------------------------------------- 16# Keys and Key Storage 17# 18# ----------------------------------------------------------------------------- 19 20# ----------------------------------------------------------------------------- 21# Imports 22# ----------------------------------------------------------------------------- 23from __future__ import annotations 24import asyncio 25import logging 26import os 27import json 28from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type 29from typing_extensions import Self 30 31from .colors import color 32from .hci import Address 33 34if TYPE_CHECKING: 35 from .device import Device 36 37 38# ----------------------------------------------------------------------------- 39# Logging 40# ----------------------------------------------------------------------------- 41logger = logging.getLogger(__name__) 42 43 44# ----------------------------------------------------------------------------- 45class PairingKeys: 46 class Key: 47 def __init__(self, value, authenticated=False, ediv=None, rand=None): 48 self.value = value 49 self.authenticated = authenticated 50 self.ediv = ediv 51 self.rand = rand 52 53 @classmethod 54 def from_dict(cls, key_dict): 55 value = bytes.fromhex(key_dict['value']) 56 authenticated = key_dict.get('authenticated', False) 57 ediv = key_dict.get('ediv') 58 rand = key_dict.get('rand') 59 if rand is not None: 60 rand = bytes.fromhex(rand) 61 62 return cls(value, authenticated, ediv, rand) 63 64 def to_dict(self): 65 key_dict = {'value': self.value.hex(), 'authenticated': self.authenticated} 66 if self.ediv is not None: 67 key_dict['ediv'] = self.ediv 68 if self.rand is not None: 69 key_dict['rand'] = self.rand.hex() 70 71 return key_dict 72 73 def __init__(self): 74 self.address_type = None 75 self.ltk = None 76 self.ltk_central = None 77 self.ltk_peripheral = None 78 self.irk = None 79 self.csrk = None 80 self.link_key = None # Classic 81 82 @staticmethod 83 def key_from_dict(keys_dict, key_name): 84 key_dict = keys_dict.get(key_name) 85 if key_dict is None: 86 return None 87 88 return PairingKeys.Key.from_dict(key_dict) 89 90 @staticmethod 91 def from_dict(keys_dict): 92 keys = PairingKeys() 93 94 keys.address_type = keys_dict.get('address_type') 95 keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') 96 keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central') 97 keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral') 98 keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk') 99 keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk') 100 keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key') 101 102 return keys 103 104 def to_dict(self): 105 keys = {} 106 107 if self.address_type is not None: 108 keys['address_type'] = self.address_type 109 110 if self.ltk is not None: 111 keys['ltk'] = self.ltk.to_dict() 112 113 if self.ltk_central is not None: 114 keys['ltk_central'] = self.ltk_central.to_dict() 115 116 if self.ltk_peripheral is not None: 117 keys['ltk_peripheral'] = self.ltk_peripheral.to_dict() 118 119 if self.irk is not None: 120 keys['irk'] = self.irk.to_dict() 121 122 if self.csrk is not None: 123 keys['csrk'] = self.csrk.to_dict() 124 125 if self.link_key is not None: 126 keys['link_key'] = self.link_key.to_dict() 127 128 return keys 129 130 def print(self, prefix=''): 131 keys_dict = self.to_dict() 132 for container_property, value in keys_dict.items(): 133 if isinstance(value, dict): 134 print(f'{prefix}{color(container_property, "cyan")}:') 135 for key_property, key_value in value.items(): 136 print(f'{prefix} {color(key_property, "green")}: {key_value}') 137 else: 138 print(f'{prefix}{color(container_property, "cyan")}: {value}') 139 140 141# ----------------------------------------------------------------------------- 142class KeyStore: 143 async def delete(self, name: str): 144 pass 145 146 async def update(self, name: str, keys: PairingKeys) -> None: 147 pass 148 149 async def get(self, _name: str) -> Optional[PairingKeys]: 150 return None 151 152 async def get_all(self) -> List[Tuple[str, PairingKeys]]: 153 return [] 154 155 async def delete_all(self) -> None: 156 all_keys = await self.get_all() 157 await asyncio.gather(*(self.delete(name) for (name, _) in all_keys)) 158 159 async def get_resolving_keys(self): 160 all_keys = await self.get_all() 161 resolving_keys = [] 162 for name, keys in all_keys: 163 if keys.irk is not None: 164 if keys.address_type is None: 165 address_type = Address.RANDOM_DEVICE_ADDRESS 166 else: 167 address_type = keys.address_type 168 resolving_keys.append((keys.irk.value, Address(name, address_type))) 169 170 return resolving_keys 171 172 async def print(self, prefix=''): 173 entries = await self.get_all() 174 separator = '' 175 for name, keys in entries: 176 print(separator + prefix + color(name, 'yellow')) 177 keys.print(prefix=prefix + ' ') 178 separator = '\n' 179 180 @staticmethod 181 def create_for_device(device: Device) -> KeyStore: 182 if device.config.keystore is None: 183 return MemoryKeyStore() 184 185 keystore_type = device.config.keystore.split(':', 1)[0] 186 if keystore_type == 'JsonKeyStore': 187 return JsonKeyStore.from_device(device) 188 189 return MemoryKeyStore() 190 191 192# ----------------------------------------------------------------------------- 193class JsonKeyStore(KeyStore): 194 """ 195 KeyStore implementation that is backed by a JSON file. 196 197 This implementation supports storing a hierarchy of key sets in a single file. 198 A key set is a representation of a PairingKeys object. Each key set is stored 199 in a map, with the address of paired peer as the key. Maps are themselves grouped 200 into namespaces, grouping pairing keys by controller addresses. 201 The JSON object model looks like: 202 { 203 "<namespace>": { 204 "peer-address": { 205 "address_type": <n>, 206 "irk" : { 207 "authenticated": <true/false>, 208 "value": "hex-encoded-key" 209 }, 210 ... other keys ... 211 }, 212 ... other peers ... 213 } 214 ... other namespaces ... 215 } 216 217 A namespace is typically the BD_ADDR of a controller, since that is a convenient 218 unique identifier, but it may be something else. 219 A special namespace, called the "default" namespace, is used when instantiating this 220 class without a namespace. With the default namespace, reading from a file will 221 load an existing namespace if there is only one, which may be convenient for reading 222 from a file with a single key set and for which the namespace isn't known. If the 223 file does not include any existing key set, or if there are more than one and none 224 has the default name, a new one will be created with the name "__DEFAULT__". 225 """ 226 227 APP_NAME = 'Bumble' 228 APP_AUTHOR = 'Google' 229 KEYS_DIR = 'Pairing' 230 DEFAULT_NAMESPACE = '__DEFAULT__' 231 DEFAULT_BASE_NAME = "keys" 232 233 def __init__(self, namespace, filename=None): 234 self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE 235 236 if filename is None: 237 # Use a default for the current user 238 239 # Import here because this may not exist on all platforms 240 # pylint: disable=import-outside-toplevel 241 import appdirs 242 243 self.directory_name = os.path.join( 244 appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR 245 ) 246 base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace 247 json_filename = ( 248 f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p') 249 ) 250 self.filename = os.path.join(self.directory_name, json_filename) 251 else: 252 self.filename = filename 253 self.directory_name = os.path.dirname(os.path.abspath(self.filename)) 254 255 logger.debug(f'JSON keystore: {self.filename}') 256 257 @classmethod 258 def from_device( 259 cls: Type[Self], device: Device, filename: Optional[str] = None 260 ) -> Self: 261 if not filename: 262 # Extract the filename from the config if there is one 263 if device.config.keystore is not None: 264 params = device.config.keystore.split(':', 1)[1:] 265 if params: 266 filename = params[0] 267 268 # Use a namespace based on the device address 269 if device.public_address not in (Address.ANY, Address.ANY_RANDOM): 270 namespace = str(device.public_address) 271 elif device.random_address != Address.ANY_RANDOM: 272 namespace = str(device.random_address) 273 else: 274 namespace = JsonKeyStore.DEFAULT_NAMESPACE 275 276 return cls(namespace, filename) 277 278 async def load(self): 279 # Try to open the file, without failing. If the file does not exist, it 280 # will be created upon saving. 281 try: 282 with open(self.filename, 'r', encoding='utf-8') as json_file: 283 db = json.load(json_file) 284 except FileNotFoundError: 285 db = {} 286 287 # First, look for a namespace match 288 if self.namespace in db: 289 return (db, db[self.namespace]) 290 291 # Then, if the namespace is the default namespace, and there's 292 # only one entry in the db, use that 293 if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1: 294 return next(iter(db.items())) 295 296 # Finally, just create an empty key map for the namespace 297 key_map = {} 298 db[self.namespace] = key_map 299 return (db, key_map) 300 301 async def save(self, db): 302 # Create the directory if it doesn't exist 303 if not os.path.exists(self.directory_name): 304 os.makedirs(self.directory_name, exist_ok=True) 305 306 # Save to a temporary file 307 temp_filename = self.filename + '.tmp' 308 with open(temp_filename, 'w', encoding='utf-8') as output: 309 json.dump(db, output, sort_keys=True, indent=4) 310 311 # Atomically replace the previous file 312 os.replace(temp_filename, self.filename) 313 314 async def delete(self, name: str) -> None: 315 db, key_map = await self.load() 316 del key_map[name] 317 await self.save(db) 318 319 async def update(self, name, keys): 320 db, key_map = await self.load() 321 key_map.setdefault(name, {}).update(keys.to_dict()) 322 await self.save(db) 323 324 async def get_all(self): 325 _, key_map = await self.load() 326 return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()] 327 328 async def delete_all(self): 329 db, key_map = await self.load() 330 key_map.clear() 331 await self.save(db) 332 333 async def get(self, name: str) -> Optional[PairingKeys]: 334 _, key_map = await self.load() 335 if name not in key_map: 336 return None 337 338 return PairingKeys.from_dict(key_map[name]) 339 340 341# ----------------------------------------------------------------------------- 342class MemoryKeyStore(KeyStore): 343 all_keys: Dict[str, PairingKeys] 344 345 def __init__(self) -> None: 346 self.all_keys = {} 347 348 async def delete(self, name: str) -> None: 349 if name in self.all_keys: 350 del self.all_keys[name] 351 352 async def update(self, name: str, keys: PairingKeys) -> None: 353 self.all_keys[name] = keys 354 355 async def get(self, name: str) -> Optional[PairingKeys]: 356 return self.all_keys.get(name) 357 358 async def get_all(self) -> List[Tuple[str, PairingKeys]]: 359 return list(self.all_keys.items()) 360