1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import annotations
16import contextlib
17import functools
18import grpc
19import inspect
20import logging
21
22from bumble.device import Device
23from bumble.hci import Address
24from google.protobuf.message import Message  # pytype: disable=pyi-error
25from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
26
# Maps the name of an address field on a request message (the same string is
# passed to ``getattr`` in ``address_from_request``) to the Bumble address type
# constant used when building the ``Address``.
ADDRESS_TYPES: Dict[str, int] = dict(
    public=Address.PUBLIC_DEVICE_ADDRESS,
    random=Address.RANDOM_DEVICE_ADDRESS,
    public_identity=Address.PUBLIC_IDENTITY_ADDRESS,
    random_static_identity=Address.RANDOM_IDENTITY_ADDRESS,
)
33
34
def address_from_request(request: Message, field: Optional[str]) -> Address:
    """Build a Bumble ``Address`` from one address field of ``request``.

    Args:
      request: protobuf message carrying the address bytes.
      field: name of the attribute on ``request`` that holds the address; it
        must also be a key of ``ADDRESS_TYPES``. ``None`` means "no specific
        address".

    Returns:
      ``Address.ANY`` when ``field`` is ``None``. Otherwise, an ``Address``
      whose bytes are those of ``request.<field>`` in reversed order, typed
      with ``ADDRESS_TYPES[field]``.

    Raises:
      AttributeError: if ``request`` has no attribute named ``field``.
      KeyError: if ``field`` is not a key of ``ADDRESS_TYPES``.
    """
    if field is not None:
        # Byte order is flipped before the bytes are handed to ``Address``.
        addr_bytes = bytes(reversed(getattr(request, field)))
        return Address(addr_bytes, ADDRESS_TYPES[field])
    return Address.ANY
39
40
class BumbleServerLoggerAdapter(logging.LoggerAdapter):  # type: ignore
    """Logger adapter that tags each message with a service name and device.

    ``self.extra`` must provide ``'service_name'`` (a ``str``) and ``'device'``
    (a bumble ``Device``). Each message is prefixed with
    ``[bumble.<service_name>:<addr>]``.

    ``<addr>`` is derived from the device's public address: its bytes are
    reversed, and those from index 4 onward are rendered as colon-separated,
    two-digit upper-case hex.
    """

    def process(
        self, msg: str, kwargs: MutableMapping[str, Any]
    ) -> Tuple[str, MutableMapping[str, Any]]:
        """Return ``msg`` with the service/device prefix and ``kwargs`` unchanged."""
        extra = self.extra
        assert extra
        service_name = extra['service_name']
        assert isinstance(service_name, str)
        device = extra['device']
        assert isinstance(device, Device)
        tail = bytes(device.public_address)[::-1][4:]  # pytype: disable=attribute-error
        addr = ':'.join(format(octet, '02X') for octet in tail)
        prefixed = f'[bumble.{service_name}:{addr}] {msg}'
        return prefixed, kwargs
57
58
@contextlib.contextmanager
def exception_to_rpc_error(
    context: grpc.ServicerContext,
) -> Generator[None, None, None]:
    """Report selected exceptions from the managed block as a gRPC status.

    Within the ``with`` body, the following exceptions are suppressed:

    * ``NotImplementedError`` -> ``UNIMPLEMENTED``
    * ``ValueError`` -> ``INVALID_ARGUMENT``
    * ``RuntimeError`` -> ``ABORTED``

    For each, the status code is set on ``context`` and ``str(exc)`` becomes
    the status details. Any other exception propagates unchanged.

    Args:
      context: servicer context that receives the code and details.
    """
    try:
        yield None
    except (NotImplementedError, ValueError, RuntimeError) as error:
        # NotImplementedError subclasses RuntimeError, so it is tested first.
        if isinstance(error, NotImplementedError):
            status = grpc.StatusCode.UNIMPLEMENTED
        elif isinstance(error, ValueError):
            status = grpc.StatusCode.INVALID_ARGUMENT
        else:
            status = grpc.StatusCode.ABORTED
        context.set_code(status)  # type: ignore
        context.set_details(str(error))  # type: ignore
74
75
# Decorate an RPC servicer method with a wrapper that transforms exceptions to gRPC errors.
def rpc(func: Any) -> Any:
    """Wrap a servicer method so its exceptions go through ``exception_to_rpc_error``.

    The returned wrapper has the same calling convention as ``func``:

    * async generator function -> async generator wrapper
    * coroutine function -> coroutine wrapper
    * generator function -> generator wrapper
    * anything else -> plain function wrapper

    For the (async) generator kinds, the whole iteration runs inside
    ``exception_to_rpc_error``. Errors raised while producing items are
    therefore translated too, not only errors raised by the call itself.

    Args:
      func: the servicer method, invoked as ``func(self, request, context)``.

    Returns:
      The wrapper, with ``func``'s metadata copied by ``functools.wraps``.
    """
    if inspect.isasyncgenfunction(func):

        @functools.wraps(func)
        async def asyncgen_wrapper(
            self: Any, request: Any, context: grpc.ServicerContext
        ) -> Any:
            with exception_to_rpc_error(context):
                async for v in func(self, request, context):
                    yield v

        return asyncgen_wrapper

    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def async_wrapper(
            self: Any, request: Any, context: grpc.ServicerContext
        ) -> Any:
            with exception_to_rpc_error(context):
                return await func(self, request, context)

        return async_wrapper

    # ``isgeneratorfunction`` (not ``isgenerator``, which only matches generator
    # *objects*) is needed here. Otherwise a sync generator method would fall
    # through to the plain wrapper. That wrapper returns the generator before
    # it runs, so errors raised during iteration would escape translation.
    if inspect.isgeneratorfunction(func):

        @functools.wraps(func)
        def gen_wrapper(
            self: Any, request: Any, context: grpc.ServicerContext
        ) -> Any:
            with exception_to_rpc_error(context):
                for v in func(self, request, context):
                    yield v

        return gen_wrapper

    @functools.wraps(func)
    def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
        with exception_to_rpc_error(context):
            return func(self, request, context)

    return wrapper
114