1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Keys and Key Storage
17#
18# -----------------------------------------------------------------------------
19
20# -----------------------------------------------------------------------------
21# Imports
22# -----------------------------------------------------------------------------
23from __future__ import annotations
24import asyncio
25import logging
26import os
27import json
28from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
29from typing_extensions import Self
30
31from .colors import color
32from .hci import Address
33
34if TYPE_CHECKING:
35    from .device import Device
36
37
38# -----------------------------------------------------------------------------
39# Logging
40# -----------------------------------------------------------------------------
# Logger for this module, named after the module via __name__.
logger = logging.getLogger(__name__)
42
43
44# -----------------------------------------------------------------------------
class PairingKeys:
    """Set of pairing keys associated with one peer.

    Every key slot is either None or a PairingKeys.Key. Instances round-trip
    through plain (JSON-friendly) dictionaries via to_dict() and from_dict().
    """

    class Key:
        """A single key value together with its optional metadata."""

        def __init__(self, value, authenticated=False, ediv=None, rand=None):
            self.value = value
            self.authenticated = authenticated
            self.ediv = ediv
            self.rand = rand

        @classmethod
        def from_dict(cls, key_dict):
            """Build a key from a dict shaped like the output of to_dict().

            'value' is required and hex-decoded. 'authenticated' defaults to
            False. 'ediv' is taken as-is. 'rand', when present, is hex-decoded.
            """
            rand_hex = key_dict.get('rand')
            return cls(
                bytes.fromhex(key_dict['value']),
                key_dict.get('authenticated', False),
                key_dict.get('ediv'),
                None if rand_hex is None else bytes.fromhex(rand_hex),
            )

        def to_dict(self):
            """Serialize to a dict, hex-encoding byte fields.

            'ediv' and 'rand' are only included when they are not None.
            """
            serialized = dict(value=self.value.hex(), authenticated=self.authenticated)
            if self.ediv is not None:
                serialized.update(ediv=self.ediv)
            if self.rand is not None:
                serialized.update(rand=self.rand.hex())
            return serialized

    # Names of the key slots, in the order they are serialized.
    _KEY_ATTRIBUTES = (
        'ltk',
        'ltk_central',
        'ltk_peripheral',
        'irk',
        'csrk',
        'link_key',
    )

    def __init__(self):
        self.address_type = None
        # LE keys
        self.ltk = None
        self.ltk_central = None
        self.ltk_peripheral = None
        self.irk = None
        self.csrk = None
        # Classic
        self.link_key = None

    @staticmethod
    def key_from_dict(keys_dict, key_name):
        """Decode keys_dict[key_name] into a Key, or return None if absent/None."""
        entry = keys_dict.get(key_name)
        return None if entry is None else PairingKeys.Key.from_dict(entry)

    @staticmethod
    def from_dict(keys_dict):
        """Build a PairingKeys from a dict shaped like the output of to_dict()."""
        keys = PairingKeys()
        keys.address_type = keys_dict.get('address_type')
        for attribute in PairingKeys._KEY_ATTRIBUTES:
            setattr(keys, attribute, PairingKeys.key_from_dict(keys_dict, attribute))
        return keys

    def to_dict(self):
        """Serialize to a dict, omitting every field that is None."""
        serialized = {}
        if self.address_type is not None:
            serialized['address_type'] = self.address_type
        for attribute in self._KEY_ATTRIBUTES:
            key = getattr(self, attribute)
            if key is not None:
                serialized[attribute] = key.to_dict()
        return serialized

    def print(self, prefix=''):
        """Print the serialized keys, one property per line, under `prefix`."""
        for name, entry in self.to_dict().items():
            header = f'{prefix}{color(name, "cyan")}:'
            if not isinstance(entry, dict):
                print(f'{header} {entry}')
                continue
            print(header)
            for field_name, field_value in entry.items():
                print(f'{prefix}  {color(field_name, "green")}: {field_value}')
139
140
141# -----------------------------------------------------------------------------
class KeyStore:
    """Base class for pairing-key storage.

    Entries are indexed by a string name; get_resolving_keys() interprets that
    name as a peer address. This base implementation stores nothing: update()
    and delete() are no-ops, get() returns None and get_all() returns an empty
    list. Subclasses override these to provide real storage.
    """

    async def delete(self, name: str) -> None:
        """Delete the entry stored under `name`. No-op in the base class."""

    async def update(self, name: str, keys: PairingKeys) -> None:
        """Store or update the keys for `name`. No-op in the base class."""

    async def get(self, _name: str) -> Optional[PairingKeys]:
        """Return the keys stored under `_name`, or None. Always None here."""
        return None

    async def get_all(self) -> List[Tuple[str, PairingKeys]]:
        """Return every stored (name, keys) pair. Always empty here."""
        return []

    async def delete_all(self) -> None:
        """Delete every entry by calling delete() for each name from get_all()."""
        all_keys = await self.get_all()
        await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))

    async def get_resolving_keys(self) -> List[Tuple[bytes, Address]]:
        """Return (irk_value, Address) pairs for every entry that has an IRK.

        When an entry has no stored address_type, the address is assumed to be
        a random device address.
        """
        resolving_keys = []
        for name, keys in await self.get_all():
            if keys.irk is None:
                continue
            address_type = (
                Address.RANDOM_DEVICE_ADDRESS
                if keys.address_type is None
                else keys.address_type
            )
            resolving_keys.append((keys.irk.value, Address(name, address_type)))

        return resolving_keys

    async def print(self, prefix=''):
        """Print every entry, separated by blank lines, with its keys indented."""
        entries = await self.get_all()
        separator = ''
        for name, keys in entries:
            print(separator + prefix + color(name, 'yellow'))
            keys.print(prefix=prefix + '  ')
            separator = '\n'

    @staticmethod
    def create_for_device(device: Device) -> KeyStore:
        """Create the key store selected by `device.config.keystore`.

        The config value has the form "<type>[:<params>]". "JsonKeyStore" selects
        JsonKeyStore.from_device(device). None, or any other type, selects an
        in-memory MemoryKeyStore.
        """
        keystore_config = device.config.keystore
        if keystore_config is None:
            return MemoryKeyStore()

        keystore_type = keystore_config.split(':', 1)[0]
        if keystore_type == 'JsonKeyStore':
            return JsonKeyStore.from_device(device)

        # An unrecognized type used to fall back silently. Warn, so that a typo
        # in the configuration doesn't quietly lose keys across restarts.
        logger.warning(
            f'unknown keystore type "{keystore_type}", falling back to MemoryKeyStore'
        )
        return MemoryKeyStore()
190
191
192# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
    """
    KeyStore implementation that is backed by a JSON file.

    This implementation supports storing a hierarchy of key sets in a single file.
    A key set is a representation of a PairingKeys object. Each key set is stored
    in a map, with the address of the paired peer as the key. Maps are themselves
    grouped into namespaces, grouping pairing keys by controller addresses.
    The JSON object model looks like:
    {
        "<namespace>": {
            "peer-address": {
                "address_type": <n>,
                "irk" : {
                    "authenticated": <true/false>,
                    "value": "hex-encoded-key"
                },
                ... other keys ...
            },
            ... other peers ...
        },
        ... other namespaces ...
    }

    A namespace is typically the BD_ADDR of a controller, since that is a convenient
    unique identifier, but it may be something else.
    A special namespace, called the "default" namespace, is used when instantiating this
    class without a namespace. With the default namespace, reading from a file will
    load the existing namespace if there is exactly one, which may be convenient for
    reading from a file with a single key set and for which the namespace isn't known.
    If the file does not include any existing key set, or if there is more than one and
    none has the default name, a new one will be created with the name "__DEFAULT__".
    """

    APP_NAME = 'Bumble'
    APP_AUTHOR = 'Google'
    KEYS_DIR = 'Pairing'
    DEFAULT_NAMESPACE = '__DEFAULT__'
    DEFAULT_BASE_NAME = "keys"

    def __init__(self, namespace: Optional[str], filename: Optional[str] = None) -> None:
        """Create a store for `namespace`, backed by the JSON file `filename`.

        Args:
            namespace: namespace to read and write. None selects DEFAULT_NAMESPACE.
            filename: path of the JSON file. When None, a per-user file is used
                in the `appdirs` user data directory (under KEYS_DIR). It is
                named after the namespace, or DEFAULT_BASE_NAME when `namespace`
                is None.
        """
        self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE

        if filename is None:
            # Use a default for the current user

            # Import here because this may not exist on all platforms
            # pylint: disable=import-outside-toplevel
            import appdirs

            self.directory_name = os.path.join(
                appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
            )
            base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
            # Make the name filesystem-friendly: lowercase, ':' -> '-', and
            # '/p' -> '-p' (presumably the '/P' suffix of public address strings).
            json_filename = (
                f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
            )
            self.filename = os.path.join(self.directory_name, json_filename)
        else:
            self.filename = filename
            self.directory_name = os.path.dirname(os.path.abspath(self.filename))

        logger.debug(f'JSON keystore: {self.filename}')

    @classmethod
    def from_device(
        cls: Type[Self], device: Device, filename: Optional[str] = None
    ) -> Self:
        """Create a store for `device`.

        If `filename` is not given, it is taken from the part of
        `device.config.keystore` after the first ':', if there is one. The
        namespace is the device's public address when it is set (not ANY or
        ANY_RANDOM). Otherwise it is the random address when that is not
        ANY_RANDOM. Failing both, it is DEFAULT_NAMESPACE.
        """
        if not filename:
            # Extract the filename from the config if there is one
            if device.config.keystore is not None:
                params = device.config.keystore.split(':', 1)[1:]
                if params:
                    filename = params[0]

        # Use a namespace based on the device address
        if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
            namespace = str(device.public_address)
        elif device.random_address != Address.ANY_RANDOM:
            namespace = str(device.random_address)
        else:
            namespace = cls.DEFAULT_NAMESPACE

        return cls(namespace, filename)

    async def load(self) -> Tuple[dict, dict]:
        """Read the whole database and select this store's namespace.

        Returns:
            A `(db, key_map)` tuple. `db` is the full decoded JSON object and
            `key_map` is the mutable entry for the selected namespace inside
            `db`. Callers can therefore mutate `key_map` and pass `db` to save().
            A missing file yields an empty database.
        """
        # Try to open the file, without failing. If the file does not exist, it
        # will be created upon saving.
        try:
            with open(self.filename, 'r', encoding='utf-8') as json_file:
                db = json.load(json_file)
        except FileNotFoundError:
            db = {}

        # First, look for a namespace match
        if self.namespace in db:
            return (db, db[self.namespace])

        # Then, if the namespace is the default namespace, and there's
        # only one entry in the db, use that. Return the whole db (not the
        # (name, key_map) item pair), so that callers saving `db` write back
        # the full file rather than a bare namespace string.
        if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1:
            return (db, next(iter(db.values())))

        # Finally, just create an empty key map for the namespace
        key_map = {}
        db[self.namespace] = key_map
        return (db, key_map)

    async def save(self, db) -> None:
        """Write `db` to the file, via a temporary file that replaces the old one."""
        # Create the directory if it doesn't exist
        os.makedirs(self.directory_name, exist_ok=True)

        # Save to a temporary file
        temp_filename = self.filename + '.tmp'
        try:
            with open(temp_filename, 'w', encoding='utf-8') as output:
                json.dump(db, output, sort_keys=True, indent=4)

            # Atomically replace the previous file
            os.replace(temp_filename, self.filename)
        except Exception:
            # Don't leave a partially written temporary file behind (best effort).
            try:
                os.remove(temp_filename)
            except OSError:
                pass
            raise

    async def delete(self, name: str) -> None:
        """Delete the entry for `name` in this namespace.

        Raises:
            KeyError: if `name` has no entry in this namespace.
        """
        db, key_map = await self.load()
        del key_map[name]
        await self.save(db)

    async def update(self, name: str, keys: PairingKeys) -> None:
        """Merge `keys.to_dict()` into the entry for `name`, creating it if needed.

        Fields already stored but absent from `keys.to_dict()` are retained.
        """
        db, key_map = await self.load()
        key_map.setdefault(name, {}).update(keys.to_dict())
        await self.save(db)

    async def get_all(self) -> List[Tuple[str, PairingKeys]]:
        """Return every (name, PairingKeys) pair in this namespace."""
        _, key_map = await self.load()
        return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]

    async def delete_all(self) -> None:
        """Remove every entry in this namespace, with a single load and save."""
        db, key_map = await self.load()
        key_map.clear()
        await self.save(db)

    async def get(self, name: str) -> Optional[PairingKeys]:
        """Return the keys stored for `name`, or None if there is no entry."""
        _, key_map = await self.load()
        if name not in key_map:
            return None

        return PairingKeys.from_dict(key_map[name])
339
340
341# -----------------------------------------------------------------------------
class MemoryKeyStore(KeyStore):
    """KeyStore implementation that keeps pairing keys in an in-memory dict.

    Nothing is written to disk. Entries live only as long as this object does.
    """

    # Stored PairingKeys objects, indexed by entry name.
    all_keys: Dict[str, PairingKeys]

    def __init__(self) -> None:
        self.all_keys = {}

    async def delete(self, name: str) -> None:
        """Remove the entry for `name`; a missing name is silently ignored."""
        self.all_keys.pop(name, None)

    async def update(self, name: str, keys: PairingKeys) -> None:
        """Store `keys` under `name`, replacing any previous entry."""
        self.all_keys[name] = keys

    async def get(self, name: str) -> Optional[PairingKeys]:
        """Return the keys stored under `name`, or None if there are none."""
        return self.all_keys[name] if name in self.all_keys else None

    async def get_all(self) -> List[Tuple[str, PairingKeys]]:
        """Return a snapshot list of every (name, keys) pair."""
        return [(name, keys) for name, keys in self.all_keys.items()]
360