1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import asyncio
20from dataclasses import dataclass
21import enum
22import logging
23import struct
24from typing import (
25    AsyncIterator,
26    Awaitable,
27    Callable,
28    cast,
29    Dict,
30    Iterable,
31    List,
32    Optional,
33    Sequence,
34    SupportsBytes,
35    Tuple,
36    Type,
37    TypeVar,
38    Union,
39)
40
41import pyee
42
43from bumble.colors import color
44from bumble.device import Device, Connection
45from bumble.sdp import (
46    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
47    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
48    SDP_PUBLIC_BROWSE_ROOT,
49    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
50    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
51    SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
52    SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
53    DataElement,
54    ServiceAttribute,
55)
56from bumble.utils import AsyncRunner, OpenIntEnum
57from bumble.core import (
58    InvalidArgumentError,
59    ProtocolError,
60    BT_L2CAP_PROTOCOL_ID,
61    BT_AVCTP_PROTOCOL_ID,
62    BT_AV_REMOTE_CONTROL_SERVICE,
63    BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE,
64    BT_AV_REMOTE_CONTROL_TARGET_SERVICE,
65)
66from bumble import l2cap
67from bumble import avc
68from bumble import avctp
69from bumble import utils
70
71
72# -----------------------------------------------------------------------------
73# Logging
74# -----------------------------------------------------------------------------
75logger = logging.getLogger(__name__)
76
77
78# -----------------------------------------------------------------------------
79# Constants
80# -----------------------------------------------------------------------------
81AVRCP_PID = 0x110E
82AVRCP_BLUETOOTH_SIG_COMPANY_ID = 0x001958
83
84
85# -----------------------------------------------------------------------------
def make_controller_service_sdp_records(
    service_record_handle: int,
    avctp_version: Tuple[int, int] = (1, 4),
    avrcp_version: Tuple[int, int] = (1, 6),
    supported_features: int = 1,
) -> List[ServiceAttribute]:
    """Build the SDP record attributes advertising an AVRCP Controller service.

    Args:
        service_record_handle: record handle, stored as a 32-bit unsigned int.
        avctp_version: AVCTP (major, minor), packed as ``major << 8 | minor``.
        avrcp_version: AVRCP (major, minor), packed the same way.
        supported_features: raw SupportedFeatures value (16-bit unsigned int).

    Returns:
        The list of ServiceAttribute entries making up the record.
    """
    # TODO: support a way to compute the supported features from a feature list
    packed_avctp_version = (avctp_version[0] << 8) | avctp_version[1]
    packed_avrcp_version = (avrcp_version[0] << 8) | avrcp_version[1]

    # The record declares both the generic A/V Remote Control class and the
    # Controller-specific class.
    service_classes = DataElement.sequence(
        [
            DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE),
            DataElement.uuid(BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE),
        ]
    )
    # Protocol stack: L2CAP on the AVCTP PSM, then AVCTP with its version.
    l2cap_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
            DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
        ]
    )
    avctp_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_AVCTP_PROTOCOL_ID),
            DataElement.unsigned_integer_16(packed_avctp_version),
        ]
    )
    profile_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE),
            DataElement.unsigned_integer_16(packed_avrcp_version),
        ]
    )

    attribute_values = (
        (
            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
            DataElement.unsigned_integer_32(service_record_handle),
        ),
        (
            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
        ),
        (SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, service_classes),
        (
            SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            DataElement.sequence([l2cap_descriptor, avctp_descriptor]),
        ),
        (
            SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            DataElement.sequence([profile_descriptor]),
        ),
        (
            SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
            DataElement.unsigned_integer_16(supported_features),
        ),
    )
    return [
        ServiceAttribute(attribute_id, value)
        for attribute_id, value in attribute_values
    ]
151
152
153# -----------------------------------------------------------------------------
def make_target_service_sdp_records(
    service_record_handle: int,
    avctp_version: Tuple[int, int] = (1, 4),
    avrcp_version: Tuple[int, int] = (1, 6),
    supported_features: int = 0x23,
) -> List[ServiceAttribute]:
    """Build the SDP record attributes advertising an AVRCP Target service.

    Args:
        service_record_handle: record handle, stored as a 32-bit unsigned int.
        avctp_version: AVCTP (major, minor), packed as ``major << 8 | minor``.
        avrcp_version: AVRCP (major, minor), packed the same way.
        supported_features: raw SupportedFeatures value (16-bit unsigned int).

    Returns:
        The list of ServiceAttribute entries making up the record.
    """
    # TODO: support a way to compute the supported features from a feature list
    packed_avctp_version = (avctp_version[0] << 8) | avctp_version[1]
    packed_avrcp_version = (avrcp_version[0] << 8) | avrcp_version[1]

    # Unlike the controller record, only the Target class is listed.
    service_classes = DataElement.sequence(
        [
            DataElement.uuid(BT_AV_REMOTE_CONTROL_TARGET_SERVICE),
        ]
    )
    # Protocol stack: L2CAP on the AVCTP PSM, then AVCTP with its version.
    l2cap_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
            DataElement.unsigned_integer_16(avctp.AVCTP_PSM),
        ]
    )
    avctp_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_AVCTP_PROTOCOL_ID),
            DataElement.unsigned_integer_16(packed_avctp_version),
        ]
    )
    profile_descriptor = DataElement.sequence(
        [
            DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE),
            DataElement.unsigned_integer_16(packed_avrcp_version),
        ]
    )

    attribute_values = (
        (
            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
            DataElement.unsigned_integer_32(service_record_handle),
        ),
        (
            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
        ),
        (SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, service_classes),
        (
            SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            DataElement.sequence([l2cap_descriptor, avctp_descriptor]),
        ),
        (
            SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
            DataElement.sequence([profile_descriptor]),
        ),
        (
            SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
            DataElement.unsigned_integer_16(supported_features),
        ),
    )
    return [
        ServiceAttribute(attribute_id, value)
        for attribute_id, value in attribute_values
    ]
218
219
220# -----------------------------------------------------------------------------
def _decode_attribute_value(value: bytes, character_set: CharacterSetId) -> str:
    """Decode a media attribute value according to its character set.

    UTF-8 values are decoded as UTF-8; every other character set falls back
    to ASCII. Undecodable bytes are logged and yield an empty string.
    """
    codec = "utf-8" if character_set == CharacterSetId.UTF_8 else "ascii"
    try:
        return value.decode(codec)
    except UnicodeDecodeError:
        logger.warning(f"cannot decode string with bytes: {value.hex()}")
        return ""
229
230
231# -----------------------------------------------------------------------------
class PduAssembler:
    """
    PDU Assembler to support fragmented PDUs, as defined in:
    Audio/Video Remote Control / Profile Specification
    6.3.1 AVRCP specific AV/C commands

    Fragments are fed one at a time to on_pdu(). When a SINGLE or END
    fragment completes a PDU, the callback receives the PDU ID and the
    concatenated parameter bytes of all its fragments.
    """

    pdu_id: Optional[Protocol.PduId]
    # Parameter bytes accumulated so far for the PDU being assembled.
    parameter: bytes

    def __init__(self, callback: Callable[[Protocol.PduId, bytes], None]) -> None:
        self.callback = callback
        self.reset()

    def reset(self) -> None:
        """Drop any partially assembled PDU."""
        self.pdu_id = None
        self.parameter = b''

    def on_pdu(self, pdu: bytes) -> None:
        """Process one PDU fragment.

        The fragment header is: PDU ID (1 octet), packet type (low 2 bits of
        the second octet), parameter length (2 octets, big-endian). Malformed
        or out-of-sequence fragments are logged and discard the current PDU.
        """
        if len(pdu) < 4:
            # Not even a complete header: indexing/unpacking below would raise.
            logger.warning("pdu too short")
            self.reset()
            return

        pdu_id = Protocol.PduId(pdu[0])
        packet_type = Protocol.PacketType(pdu[1] & 3)
        parameter_length = struct.unpack_from('>H', pdu, 2)[0]
        parameter = pdu[4 : 4 + parameter_length]
        if len(parameter) != parameter_length:
            logger.warning("parameter length exceeds pdu size")
            self.reset()
            return

        if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
            if self.pdu_id is not None:
                # We are already in a PDU
                logger.warning("received START or SINGLE fragment while in pdu")
                self.reset()

        if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
            # Continuation fragments must belong to the PDU already in progress.
            if pdu_id != self.pdu_id:
                logger.warning("PID does not match")
                self.reset()
                return
        else:
            self.pdu_id = pdu_id

        self.parameter += parameter

        if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
            self.on_pdu_complete()

    def on_pdu_complete(self) -> None:
        """Deliver the assembled PDU to the callback, then reset.

        Exceptions raised by the callback are logged, not propagated.
        """
        assert self.pdu_id is not None
        try:
            self.callback(self.pdu_id, self.parameter)
        except Exception as error:
            logger.exception(color(f'!!! exception in callback: {error}', 'red'))

        self.reset()
287
288
289# -----------------------------------------------------------------------------
@dataclass
class Command:
    """Base class for AVRCP command PDUs: a PDU ID plus raw parameter bytes."""

    pdu_id: Protocol.PduId
    parameter: bytes

    def to_string(self, properties: Dict[str, str]) -> str:
        """Render as ``Command[<PDU name>](key=value,...)``."""
        rendered = ",".join(f"{key}={val}" for key, val in properties.items())
        return f"Command[{self.pdu_id.name}]({rendered})"

    def __str__(self) -> str:
        return self.to_string({"parameters": self.parameter.hex()})

    def __repr__(self) -> str:
        # Same text as str(); subclasses that override __str__ affect both.
        return str(self)
306
307
308# -----------------------------------------------------------------------------
class GetCapabilitiesCommand(Command):
    """GetCapabilities command: asks the peer for one class of capability.

    The parameter is a single octet holding the capability ID.
    """

    class CapabilityId(OpenIntEnum):
        COMPANY_ID = 0x02
        EVENTS_SUPPORTED = 0x03

    capability_id: CapabilityId

    @classmethod
    def from_bytes(cls, pdu: bytes) -> GetCapabilitiesCommand:
        requested = cls.CapabilityId(pdu[0])
        return cls(requested)

    def __init__(self, capability_id: CapabilityId) -> None:
        super().__init__(Protocol.PduId.GET_CAPABILITIES, bytes((capability_id,)))
        self.capability_id = capability_id

    def __str__(self) -> str:
        fields = {"capability_id": self.capability_id.name}
        return self.to_string(fields)
326
327
328# -----------------------------------------------------------------------------
class GetPlayStatusCommand(Command):
    """GetPlayStatus command; it carries an empty parameter."""

    @classmethod
    def from_bytes(cls, _: bytes) -> GetPlayStatusCommand:
        # Nothing to parse: the parameter is always empty.
        return cls()

    def __init__(self) -> None:
        super().__init__(Protocol.PduId.GET_PLAY_STATUS, bytes())
336
337
338# -----------------------------------------------------------------------------
class GetElementAttributesCommand(Command):
    """GetElementAttributes command: request media attributes of an element.

    Parameter layout (big-endian): Identifier (8 octets), NumAttributes
    (1 octet), then NumAttributes AttributeIDs of 4 octets each.
    """

    identifier: int
    attribute_ids: List[MediaAttributeId]

    @classmethod
    def from_bytes(cls, pdu: bytes) -> GetElementAttributesCommand:
        identifier, num_attributes = struct.unpack_from(">QB", pdu)
        # Each attribute ID is 4 octets, exactly as __init__ encodes them.
        # Reading single octets (as before) produced wrong IDs and did not
        # round-trip with __init__.
        attribute_ids = [
            MediaAttributeId(attribute_id)
            for attribute_id in struct.unpack_from(f">{num_attributes}I", pdu, 9)
        ]
        return cls(identifier, attribute_ids)

    def __init__(
        self, identifier: int, attribute_ids: Sequence[MediaAttributeId]
    ) -> None:
        parameter = struct.pack(">QB", identifier, len(attribute_ids)) + b''.join(
            [struct.pack(">I", int(attribute_id)) for attribute_id in attribute_ids]
        )
        super().__init__(Protocol.PduId.GET_ELEMENT_ATTRIBUTES, parameter)
        self.identifier = identifier
        self.attribute_ids = list(attribute_ids)
359
360
361# -----------------------------------------------------------------------------
class SetAbsoluteVolumeCommand(Command):
    """SetAbsoluteVolume command; the parameter is a single volume octet.

    MAXIMUM_VOLUME (0x7F) is the largest volume value.
    """

    MAXIMUM_VOLUME = 0x7F

    volume: int

    @classmethod
    def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeCommand:
        requested_volume = pdu[0]
        return cls(requested_volume)

    def __init__(self, volume: int) -> None:
        super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes((volume,)))
        self.volume = volume

    def __str__(self) -> str:
        return self.to_string(dict(volume=str(self.volume)))
377
378
379# -----------------------------------------------------------------------------
class RegisterNotificationCommand(Command):
    """RegisterNotification command: subscribe to one event type.

    Parameter layout (big-endian): EventID (1 octet), playback interval
    (4 octets).
    """

    event_id: EventId
    playback_interval: int

    @classmethod
    def from_bytes(cls, pdu: bytes) -> RegisterNotificationCommand:
        return cls(EventId(pdu[0]), struct.unpack_from(">I", pdu, 1)[0])

    def __init__(self, event_id: EventId, playback_interval: int) -> None:
        encoded = struct.pack(">BI", int(event_id), playback_interval)
        super().__init__(Protocol.PduId.REGISTER_NOTIFICATION, encoded)
        self.event_id = event_id
        self.playback_interval = playback_interval

    def __str__(self) -> str:
        fields = {
            "event_id": self.event_id.name,
            "playback_interval": str(self.playback_interval),
        }
        return self.to_string(fields)
405
406
407# -----------------------------------------------------------------------------
@dataclass
class Response:
    """Base class for AVRCP response PDUs: a PDU ID plus raw parameter bytes."""

    pdu_id: Protocol.PduId
    parameter: bytes

    def to_string(self, properties: Dict[str, str]) -> str:
        """Render as ``Response[<PDU name>](key=value,...)``."""
        rendered = ",".join(f"{key}={val}" for key, val in properties.items())
        return f"Response[{self.pdu_id.name}]({rendered})"

    def __str__(self) -> str:
        return self.to_string({"parameter": self.parameter.hex()})

    def __repr__(self) -> str:
        # Same text as str(); subclasses that override __str__ affect both.
        return str(self)
424
425
426# -----------------------------------------------------------------------------
class RejectedResponse(Response):
    """Response to a rejected command; the parameter is a 1-octet status code."""

    status_code: Protocol.StatusCode

    @classmethod
    def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> RejectedResponse:
        status = Protocol.StatusCode(pdu[0])
        return cls(pdu_id, status)

    def __init__(
        self, pdu_id: Protocol.PduId, status_code: Protocol.StatusCode
    ) -> None:
        super().__init__(pdu_id, bytes((int(status_code),)))
        self.status_code = status_code

    def __str__(self) -> str:
        return self.to_string(dict(status_code=self.status_code.name))
446
447
448# -----------------------------------------------------------------------------
class NotImplementedResponse(Response):
    """Response indicating the peer does not implement the requested PDU."""

    @classmethod
    def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> NotImplementedResponse:
        # NOTE(review): the first octet is dropped and the remainder kept as
        # the parameter; confirm what that leading octet represents.
        remainder = pdu[1:]
        return cls(pdu_id, remainder)
453
454
455# -----------------------------------------------------------------------------
class GetCapabilitiesResponse(Response):
    """GetCapabilities response: a capability ID and its list of capabilities.

    Wire layout: CapabilityID (1 octet), CapabilityCount (1 octet), then the
    capabilities. Supported events are one octet each; other capability
    types are split evenly across the remaining payload.
    """

    capability_id: GetCapabilitiesCommand.CapabilityId
    capabilities: List[Union[SupportsBytes, bytes]]

    @classmethod
    def from_bytes(cls, pdu: bytes) -> GetCapabilitiesResponse:
        if len(pdu) < 2:
            # Possibly a reject response.
            return cls(GetCapabilitiesCommand.CapabilityId(0), [])

        # Assume that the payloads all follow the same pattern:
        #  <CapabilityID><CapabilityCount><Capability*>
        capability_id = GetCapabilitiesCommand.CapabilityId(pdu[0])
        capability_count = pdu[1]
        payload = pdu[2:]
        if len(payload) < capability_count:
            logger.warning("capability count exceeds payload size")

        capabilities: List[Union[SupportsBytes, bytes]]
        if capability_id == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED:
            # One octet per event ID; a truncated payload yields fewer entries
            # instead of raising IndexError.
            capabilities = [EventId(event_id) for event_id in payload[:capability_count]]
        else:
            # Other capabilities (e.g. company IDs) are assumed to share one
            # size, derived from the payload length. A zero count (previously
            # ZeroDivisionError) or a payload shorter than the count
            # (previously a zero slicing step) yields an empty list, and at
            # most capability_count entries are returned.
            capability_size = (
                len(payload) // capability_count if capability_count else 0
            )
            if capability_size:
                capabilities = [
                    payload[i * capability_size : (i + 1) * capability_size]
                    for i in range(capability_count)
                ]
            else:
                capabilities = []

        return cls(capability_id, capabilities)

    def __init__(
        self,
        capability_id: GetCapabilitiesCommand.CapabilityId,
        capabilities: Sequence[Union[SupportsBytes, bytes]],
    ) -> None:
        super().__init__(
            Protocol.PduId.GET_CAPABILITIES,
            bytes([capability_id, len(capabilities)])
            + b''.join(bytes(capability) for capability in capabilities),
        )
        self.capability_id = capability_id
        self.capabilities = list(capabilities)

    def __str__(self) -> str:
        return self.to_string(
            {
                "capability_id": self.capability_id.name,
                "capabilities": str(self.capabilities),
            }
        )
503
504
505# -----------------------------------------------------------------------------
class GetPlayStatusResponse(Response):
    """GetPlayStatus response: song length, song position and play status.

    Encoded big-endian as two 4-octet integers followed by a 1-octet status.
    """

    song_length: int
    song_position: int
    play_status: PlayStatus

    @classmethod
    def from_bytes(cls, pdu: bytes) -> GetPlayStatusResponse:
        length, position = struct.unpack_from(">II", pdu, 0)
        return cls(length, position, PlayStatus(pdu[8]))

    def __init__(
        self,
        song_length: int,
        song_position: int,
        play_status: PlayStatus,
    ) -> None:
        encoded = struct.pack(">IIB", song_length, song_position, int(play_status))
        super().__init__(Protocol.PduId.GET_PLAY_STATUS, encoded)
        self.song_length = song_length
        self.song_position = song_position
        self.play_status = play_status

    def __str__(self) -> str:
        fields = {
            "song_length": str(self.song_length),
            "song_position": str(self.song_position),
            "play_status": self.play_status.name,
        }
        return self.to_string(fields)
540
541
542# -----------------------------------------------------------------------------
class GetElementAttributesResponse(Response):
    """GetElementAttributes response carrying a list of media attributes.

    Layout: attribute count (1 octet), then for each attribute, big-endian:
    AttributeID (4 octets), CharacterSetID (2 octets), value length
    (2 octets), value bytes. Values are always encoded as UTF-8 here.
    """

    attributes: List[MediaAttribute]

    @classmethod
    def from_bytes(cls, pdu: bytes) -> GetElementAttributesResponse:
        cursor = 1
        parsed: List[MediaAttribute] = []
        for _ in range(pdu[0]):
            raw_id, raw_charset, value_length = struct.unpack_from(
                ">IHH", pdu, cursor
            )
            value_start = cursor + 8
            value_end = value_start + value_length
            raw_value = pdu[value_start:value_end]
            media_id = MediaAttributeId(raw_id)
            charset = CharacterSetId(raw_charset)
            text = _decode_attribute_value(raw_value, charset)
            parsed.append(MediaAttribute(media_id, charset, text))
            cursor = value_end

        return cls(parsed)

    def __init__(self, attributes: Sequence[MediaAttribute]) -> None:
        chunks = [bytes([len(attributes)])]
        for attribute in attributes:
            encoded_value = attribute.attribute_value.encode("utf-8")
            chunks.append(
                struct.pack(
                    ">IHH",
                    int(attribute.attribute_id),
                    int(CharacterSetId.UTF_8),
                    len(encoded_value),
                )
            )
            chunks.append(encoded_value)
        super().__init__(Protocol.PduId.GET_ELEMENT_ATTRIBUTES, b''.join(chunks))
        self.attributes = list(attributes)

    def __str__(self) -> str:
        joined = ", ".join(map(str, self.attributes))
        return self.to_string({"attributes": "[" + joined + "]"})
598
599
600# -----------------------------------------------------------------------------
class SetAbsoluteVolumeResponse(Response):
    """SetAbsoluteVolume response; the parameter is the single volume octet."""

    volume: int

    @classmethod
    def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeResponse:
        applied_volume = pdu[0]
        return cls(applied_volume)

    def __init__(self, volume: int) -> None:
        super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes((volume,)))
        self.volume = volume

    def __str__(self) -> str:
        return self.to_string(dict(volume=str(self.volume)))
614
615
616# -----------------------------------------------------------------------------
class RegisterNotificationResponse(Response):
    """RegisterNotification response; the parameter is a serialized Event."""

    event: Event

    @classmethod
    def from_bytes(cls, pdu: bytes) -> RegisterNotificationResponse:
        parsed_event = Event.from_bytes(pdu)
        return cls(parsed_event)

    def __init__(self, event: Event) -> None:
        encoded = bytes(event)
        super().__init__(Protocol.PduId.REGISTER_NOTIFICATION, encoded)
        self.event = event

    def __str__(self) -> str:
        return self.to_string(dict(event=str(self.event)))
637
638
639# -----------------------------------------------------------------------------
class EventId(OpenIntEnum):
    """Identifiers of the events that can be registered for notification."""

    PLAYBACK_STATUS_CHANGED = 0x01
    TRACK_CHANGED = 0x02
    TRACK_REACHED_END = 0x03
    TRACK_REACHED_START = 0x04
    PLAYBACK_POS_CHANGED = 0x05
    BATT_STATUS_CHANGED = 0x06
    SYSTEM_STATUS_CHANGED = 0x07
    PLAYER_APPLICATION_SETTING_CHANGED = 0x08
    NOW_PLAYING_CONTENT_CHANGED = 0x09
    AVAILABLE_PLAYERS_CHANGED = 0x0A
    ADDRESSED_PLAYER_CHANGED = 0x0B
    UIDS_CHANGED = 0x0C
    VOLUME_CHANGED = 0x0D

    def __bytes__(self) -> bytes:
        # Serialized as a single octet.
        return bytes((int(self),))
657
658
659# -----------------------------------------------------------------------------
class CharacterSetId(OpenIntEnum):
    """Character set of a media attribute value.

    Values are IANA character-set MIBenum numbers, as used by AVRCP.
    """

    # IANA MIBenum 106 (0x006A) is UTF-8. The previous value 0x06 is the
    # MIBenum of ISO-8859-3, so UTF-8 strings received from peers fell back to
    # ASCII decoding, and UTF-8 strings encoded here were mislabelled.
    UTF_8 = 0x6A
662
663
664# -----------------------------------------------------------------------------
class MediaAttributeId(OpenIntEnum):
    """Identifiers of media element attributes (title, artist, ...)."""

    TITLE = 0x01
    ARTIST_NAME = 0x02
    ALBUM_NAME = 0x03
    TRACK_NUMBER = 0x04
    TOTAL_NUMBER_OF_TRACKS = 0x05
    GENRE = 0x06
    PLAYING_TIME = 0x07
    DEFAULT_COVER_ART = 0x08
674
675
676# -----------------------------------------------------------------------------
@dataclass
class MediaAttribute:
    """One media element attribute: its ID, character set and decoded text."""

    attribute_id: MediaAttributeId
    character_set_id: CharacterSetId
    attribute_value: str
682
683
684# -----------------------------------------------------------------------------
class PlayStatus(OpenIntEnum):
    """Play status values, as carried by GetPlayStatus responses and
    playback-status-changed events."""

    STOPPED = 0x00
    PLAYING = 0x01
    PAUSED = 0x02
    FWD_SEEK = 0x03
    REV_SEEK = 0x04
    ERROR = 0xFF
692
693
694# -----------------------------------------------------------------------------
@dataclass
class SongAndPlayStatus:
    """Song length, current position and play status bundled together."""

    song_length: int
    song_position: int
    play_status: PlayStatus
700
701
702# -----------------------------------------------------------------------------
class ApplicationSetting:
    """Namespace for player application setting attribute IDs and values."""

    class AttributeId(OpenIntEnum):
        """Which player application setting is being described."""

        EQUALIZER_ON_OFF = 0x01
        REPEAT_MODE = 0x02
        SHUFFLE_ON_OFF = 0x03
        SCAN_ON_OFF = 0x04

    class EqualizerOnOffStatus(OpenIntEnum):
        """Values of the EQUALIZER_ON_OFF setting."""

        OFF = 0x01
        ON = 0x02

    class RepeatModeStatus(OpenIntEnum):
        """Values of the REPEAT_MODE setting."""

        OFF = 0x01
        SINGLE_TRACK_REPEAT = 0x02
        ALL_TRACK_REPEAT = 0x03
        GROUP_REPEAT = 0x04

    class ShuffleOnOffStatus(OpenIntEnum):
        """Values of the SHUFFLE_ON_OFF setting."""

        OFF = 0x01
        ALL_TRACKS_SHUFFLE = 0x02
        GROUP_SHUFFLE = 0x03

    class ScanOnOffStatus(OpenIntEnum):
        """Values of the SCAN_ON_OFF setting."""

        OFF = 0x01
        ALL_TRACKS_SCAN = 0x02
        GROUP_SCAN = 0x03

    class GenericValue(OpenIntEnum):
        """Fallback holder for values of settings with no dedicated enum."""
732
733
734# -----------------------------------------------------------------------------
@dataclass
class Event:
    """Base class for notification events.

    The serialized form starts with the 1-octet event ID; subclasses append
    their own payload.
    """

    event_id: EventId

    @classmethod
    def from_bytes(cls, pdu: bytes) -> Event:
        """Parse an event, dispatching on its ID to the registered subclass.

        IDs without a registered subclass are parsed as GenericEvent.
        """
        event_class = EVENT_SUBCLASSES.get(EventId(pdu[0]), GenericEvent)
        return event_class.from_bytes(pdu)

    def __bytes__(self) -> bytes:
        return bytes((self.event_id,))
747
748
749# -----------------------------------------------------------------------------
@dataclass
class GenericEvent(Event):
    """Event whose payload after the ID octet is kept as raw bytes."""

    data: bytes

    @classmethod
    def from_bytes(cls, pdu: bytes) -> GenericEvent:
        event_id, payload = EventId(pdu[0]), pdu[1:]
        return cls(event_id=event_id, data=payload)

    def __bytes__(self) -> bytes:
        return bytes((self.event_id,)) + self.data
760
761
762# -----------------------------------------------------------------------------
@dataclass
class PlaybackStatusChangedEvent(Event):
    """PLAYBACK_STATUS_CHANGED event carrying the new 1-octet play status."""

    play_status: PlayStatus

    @classmethod
    def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent:
        status = PlayStatus(pdu[1])
        return cls(play_status=status)

    def __init__(self, play_status: PlayStatus) -> None:
        super().__init__(EventId.PLAYBACK_STATUS_CHANGED)
        self.play_status = play_status

    def __bytes__(self) -> bytes:
        return bytes((self.event_id, self.play_status))
777
778
779# -----------------------------------------------------------------------------
@dataclass
class PlaybackPositionChangedEvent(Event):
    """PLAYBACK_POS_CHANGED event carrying a 4-octet big-endian position."""

    playback_position: int

    @classmethod
    def from_bytes(cls, pdu: bytes) -> PlaybackPositionChangedEvent:
        (position,) = struct.unpack_from(">I", pdu, 1)
        return cls(playback_position=position)

    def __init__(self, playback_position: int) -> None:
        super().__init__(EventId.PLAYBACK_POS_CHANGED)
        self.playback_position = playback_position

    def __bytes__(self) -> bytes:
        header = bytes((self.event_id,))
        return header + struct.pack(">I", self.playback_position)
794
795
796# -----------------------------------------------------------------------------
@dataclass
class TrackChangedEvent(Event):
    """TRACK_CHANGED event; everything after the ID octet is the identifier."""

    identifier: bytes

    @classmethod
    def from_bytes(cls, pdu: bytes) -> TrackChangedEvent:
        track_identifier = pdu[1:]
        return cls(identifier=track_identifier)

    def __init__(self, identifier: bytes) -> None:
        super().__init__(EventId.TRACK_CHANGED)
        self.identifier = identifier

    def __bytes__(self) -> bytes:
        return bytes((self.event_id,)) + self.identifier
811
812
813# -----------------------------------------------------------------------------
@dataclass
class PlayerApplicationSettingChangedEvent(Event):
    """PLAYER_APPLICATION_SETTING_CHANGED event.

    Payload: setting count (1 octet), then (attribute ID, value ID) octet
    pairs. Each value is typed by the enum matching its attribute, falling
    back to ApplicationSetting.GenericValue.
    """

    @dataclass
    class Setting:
        attribute_id: ApplicationSetting.AttributeId
        value_id: OpenIntEnum

    player_application_settings: List[Setting]

    @classmethod
    def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent:
        value_types = {
            ApplicationSetting.AttributeId.EQUALIZER_ON_OFF: (
                ApplicationSetting.EqualizerOnOffStatus
            ),
            ApplicationSetting.AttributeId.REPEAT_MODE: (
                ApplicationSetting.RepeatModeStatus
            ),
            ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: (
                ApplicationSetting.ShuffleOnOffStatus
            ),
            ApplicationSetting.AttributeId.SCAN_ON_OFF: (
                ApplicationSetting.ScanOnOffStatus
            ),
        }

        def decode_setting(raw_attribute: int, raw_value: int):
            attribute = ApplicationSetting.AttributeId(raw_attribute)
            value_type = value_types.get(attribute, ApplicationSetting.GenericValue)
            return cls.Setting(attribute, value_type(raw_value))

        count = pdu[1]
        decoded = [
            decode_setting(pdu[2 + 2 * index], pdu[3 + 2 * index])
            for index in range(count)
        ]
        return cls(player_application_settings=decoded)

    def __init__(self, player_application_settings: Sequence[Setting]) -> None:
        super().__init__(EventId.PLAYER_APPLICATION_SETTING_CHANGED)
        self.player_application_settings = list(player_application_settings)

    def __bytes__(self) -> bytes:
        header = bytes((self.event_id, len(self.player_application_settings)))
        pairs = b''.join(
            bytes((entry.attribute_id, entry.value_id))
            for entry in self.player_application_settings
        )
        return header + pairs
861
862
863# -----------------------------------------------------------------------------
@dataclass
class NowPlayingContentChangedEvent(Event):
    """NOW_PLAYING_CONTENT_CHANGED event; it has no payload."""

    @classmethod
    def from_bytes(cls, pdu: bytes) -> NowPlayingContentChangedEvent:
        # No payload beyond the event ID, so pdu is not inspected.
        return cls()

    def __init__(self) -> None:
        super().__init__(EventId.NOW_PLAYING_CONTENT_CHANGED)
872
873
874# -----------------------------------------------------------------------------
@dataclass
class AvailablePlayersChangedEvent(Event):
    """AVAILABLE_PLAYERS_CHANGED event; it has no payload."""

    @classmethod
    def from_bytes(cls, pdu: bytes) -> AvailablePlayersChangedEvent:
        # No payload beyond the event ID, so pdu is not inspected.
        return cls()

    def __init__(self) -> None:
        super().__init__(EventId.AVAILABLE_PLAYERS_CHANGED)
883
884
885# -----------------------------------------------------------------------------
@dataclass
class AddressedPlayerChangedEvent(Event):
    """ADDRESSED_PLAYER_CHANGED event identifying the newly addressed player.

    Payload (big-endian): Player ID (2 octets), UID Counter (2 octets).
    """

    @dataclass
    class Player:
        player_id: int
        uid_counter: int

    # Declared as a dataclass field so the generated __eq__/__repr__ include
    # it (previously only event_id was compared and shown).
    player: Player

    @classmethod
    def from_bytes(cls, pdu: bytes) -> AddressedPlayerChangedEvent:
        # Big-endian, matching __bytes__. The previous "<HH" byte-swapped both
        # values, so parsing an event serialized here did not round-trip.
        player_id, uid_counter = struct.unpack_from(">HH", pdu, 1)
        return cls(cls.Player(player_id, uid_counter))

    def __init__(self, player: Player) -> None:
        super().__init__(EventId.ADDRESSED_PLAYER_CHANGED)
        self.player = player

    def __bytes__(self) -> bytes:
        return bytes([self.event_id]) + struct.pack(
            ">HH", self.player.player_id, self.player.uid_counter
        )
906
907
908# -----------------------------------------------------------------------------
@dataclass
class UidsChangedEvent(Event):
    """The peer's UID counter changed."""

    uid_counter: int

    def __init__(self, uid_counter: int) -> None:
        super().__init__(EventId.UIDS_CHANGED)
        self.uid_counter = uid_counter

    @classmethod
    def from_bytes(cls, pdu: bytes) -> UidsChangedEvent:
        # A big-endian 16-bit counter follows the one-byte event ID.
        (counter,) = struct.unpack_from(">H", pdu, 1)
        return cls(uid_counter=counter)

    def __bytes__(self) -> bytes:
        # Event ID byte followed by the big-endian 16-bit counter.
        return struct.pack(">BH", self.event_id, self.uid_counter)
923
924
925# -----------------------------------------------------------------------------
@dataclass
class VolumeChangedEvent(Event):
    """The peer's absolute volume changed."""

    volume: int

    def __init__(self, volume: int) -> None:
        super().__init__(EventId.VOLUME_CHANGED)
        self.volume = volume

    @classmethod
    def from_bytes(cls, pdu: bytes) -> VolumeChangedEvent:
        # The volume is the single byte right after the event ID.
        return cls(volume=pdu[1])

    def __bytes__(self) -> bytes:
        # Two bytes: event ID, then volume.
        return bytes((self.event_id, self.volume))
940
941
942# -----------------------------------------------------------------------------
# Maps each notification EventId to the Event subclass that represents it
# (presumably used to pick the matching from_bytes decoder — confirm in Event).
EVENT_SUBCLASSES: Dict[EventId, Type[Event]] = {
    EventId.PLAYBACK_STATUS_CHANGED: PlaybackStatusChangedEvent,
    EventId.PLAYBACK_POS_CHANGED: PlaybackPositionChangedEvent,
    EventId.TRACK_CHANGED: TrackChangedEvent,
    EventId.PLAYER_APPLICATION_SETTING_CHANGED: PlayerApplicationSettingChangedEvent,
    EventId.NOW_PLAYING_CONTENT_CHANGED: NowPlayingContentChangedEvent,
    EventId.AVAILABLE_PLAYERS_CHANGED: AvailablePlayersChangedEvent,
    EventId.ADDRESSED_PLAYER_CHANGED: AddressedPlayerChangedEvent,
    EventId.UIDS_CHANGED: UidsChangedEvent,
    EventId.VOLUME_CHANGED: VolumeChangedEvent,
}
954
955
956# -----------------------------------------------------------------------------
class Delegate:
    """
    Base class for AVRCP delegates.

    All the methods are async, even if they don't always need to be, so that
    delegates that do need to wait for an async result may do so.
    """

    class Error(Exception):
        """The delegate method failed, with a specified status code."""

        def __init__(self, status_code: Protocol.StatusCode) -> None:
            # Forward the status code to Exception so args/str()/pickling work.
            super().__init__(status_code)
            self.status_code = status_code

    # Event IDs reported by get_supported_events().
    supported_events: List[EventId]
    # Last absolute volume set through set_absolute_volume() (initially 0).
    volume: int

    def __init__(self, supported_events: Iterable[EventId] = ()) -> None:
        """
        Args:
            supported_events: event IDs this delegate supports; copied into a list.
        """
        self.supported_events = list(supported_events)
        self.volume = 0

    async def get_supported_events(self) -> List[EventId]:
        """Return the list of supported event IDs."""
        return self.supported_events

    async def set_absolute_volume(self, volume: int) -> None:
        """
        Set the absolute volume.

        The base implementation logs the request and stores the value so that
        get_absolute_volume() returns it. Nothing is returned.
        """
        logger.debug(f"@@@ set_absolute_volume: volume={volume}")
        self.volume = volume

    async def get_absolute_volume(self) -> int:
        """Return the last volume stored by set_absolute_volume() (0 initially)."""
        return self.volume

    # TODO add other delegate methods
994
995
996# -----------------------------------------------------------------------------
997class Protocol(pyee.EventEmitter):
998    """AVRCP Controller and Target protocol."""
999
1000    class PacketType(enum.IntEnum):
1001        SINGLE = 0b00
1002        START = 0b01
1003        CONTINUE = 0b10
1004        END = 0b11
1005
    class PduId(OpenIntEnum):
        """AVRCP PDU identifiers known to this module."""

        GET_CAPABILITIES = 0x10
        LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11
        LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12
        GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE = 0x13
        SET_PLAYER_APPLICATION_SETTING_VALUE = 0x14
        GET_PLAYER_APPLICATION_SETTING_ATTRIBUTE_TEXT = 0x15
        GET_PLAYER_APPLICATION_SETTING_VALUE_TEXT = 0x16
        INFORM_DISPLAYABLE_CHARACTER_SET = 0x17
        INFORM_BATTERY_STATUS_OF_CT = 0x18
        GET_ELEMENT_ATTRIBUTES = 0x20
        GET_PLAY_STATUS = 0x30
        REGISTER_NOTIFICATION = 0x31
        REQUEST_CONTINUING_RESPONSE = 0x40
        ABORT_CONTINUING_RESPONSE = 0x41
        SET_ABSOLUTE_VOLUME = 0x50
        SET_ADDRESSED_PLAYER = 0x60
        SET_BROWSED_PLAYER = 0x70
        GET_FOLDER_ITEMS = 0x71
        GET_TOTAL_NUMBER_OF_ITEMS = 0x75
1026
    class StatusCode(OpenIntEnum):
        """AVRCP status codes, e.g. sent in REJECTED responses."""

        INVALID_COMMAND = 0x00
        INVALID_PARAMETER = 0x01
        PARAMETER_CONTENT_ERROR = 0x02
        INTERNAL_ERROR = 0x03
        OPERATION_COMPLETED = 0x04
        UID_CHANGED = 0x05
        INVALID_DIRECTION = 0x07
        NOT_A_DIRECTORY = 0x08
        DOES_NOT_EXIST = 0x09
        INVALID_SCOPE = 0x0A
        RANGE_OUT_OF_BOUNDS = 0x0B
        FOLDER_ITEM_IS_NOT_PLAYABLE = 0x0C
        MEDIA_IN_USE = 0x0D
        NOW_PLAYING_LIST_FULL = 0x0E
        SEARCH_NOT_SUPPORTED = 0x0F
        SEARCH_IN_PROGRESS = 0x10
        INVALID_PLAYER_ID = 0x11
        PLAYER_NOT_BROWSABLE = 0x12
        PLAYER_NOT_ADDRESSED = 0x13
        NO_VALID_SEARCH_RESULTS = 0x14
        NO_AVAILABLE_PLAYERS = 0x15
        ADDRESSED_PLAYER_CHANGED = 0x16
1050
    class InvalidPidError(Exception):
        """
        A response frame with ipid==1 was received, i.e. the PID used in the
        request was reported as invalid. Set on the pending command's future.
        """

    class NotPendingError(Exception):
        """There is no pending command for a transaction label."""
1056
1057    class MismatchedResponseError(Exception):
1058        """The response type does not corresponding to the request type."""
1059
1060        def __init__(self, response: Response) -> None:
1061            self.response = response
1062
1063    class UnexpectedResponseTypeError(Exception):
1064        """The response type is not the expected one."""
1065
1066        def __init__(self, response: Protocol.ResponseContext) -> None:
1067            self.response = response
1068
1069    class UnexpectedResponseCodeError(Exception):
1070        """The response code was not the expected one."""
1071
1072        def __init__(
1073            self, response_code: avc.ResponseFrame.ResponseCode, response: Response
1074        ) -> None:
1075            self.response_code = response_code
1076            self.response = response
1077
1078    class PendingCommand:
1079        response: asyncio.Future
1080
1081        def __init__(self, transaction_label: int) -> None:
1082            self.transaction_label = transaction_label
1083            self.reset()
1084
1085        def reset(self):
1086            self.response = asyncio.get_running_loop().create_future()
1087
    @dataclass
    class ReceiveCommandState:
        """Label and command type of the command PDU being reassembled."""

        transaction_label: int
        command_type: avc.CommandFrame.CommandType

    @dataclass
    class ReceiveResponseState:
        """Label and response code of the response PDU being reassembled."""

        transaction_label: int
        response_code: avc.ResponseFrame.ResponseCode

    @dataclass
    class ResponseContext:
        """A decoded response and the transaction label it arrived on."""

        transaction_label: int
        response: Response

    @dataclass
    class FinalResponse(ResponseContext):
        """A response that completes a command, with its AV/C response code."""

        response_code: avc.ResponseFrame.ResponseCode

    @dataclass
    class InterimResponse(ResponseContext):
        """An interim response; ``final`` is awaited for the final response."""

        final: Awaitable[Protocol.FinalResponse]

    @dataclass
    class NotificationListener:
        """A peer's registered notification, kept until the event is notified."""

        transaction_label: int
        register_notification_command: RegisterNotificationCommand
1115
    # Handles incoming commands (a base Delegate when none is given).
    delegate: Delegate
    # NOTE(review): declared here but not assigned in __init__.
    send_transaction_label: int
    # Assembler for incoming command PDUs; calls _on_command_pdu.
    command_pdu_assembler: PduAssembler
    # Label/type of the command PDU being received, or None when idle.
    receive_command_state: Optional[ReceiveCommandState]
    # Assembler for incoming response PDUs; calls _on_response_pdu.
    response_pdu_assembler: PduAssembler
    # Label/code of the response PDU being received, or None when idle.
    receive_response_state: Optional[ReceiveResponseState]
    # Active AVCTP protocol, or None when no channel is open.
    avctp_protocol: Optional[avctp.Protocol]
    # PendingCommand objects whose transaction labels are currently free.
    free_commands: asyncio.Queue
    pending_commands: Dict[int, PendingCommand]  # Pending commands, by label
    # Notification listener registered by the peer, per event ID.
    notification_listeners: Dict[EventId, NotificationListener]
1126
1127    @staticmethod
1128    def _check_vendor_dependent_frame(
1129        frame: Union[avc.VendorDependentCommandFrame, avc.VendorDependentResponseFrame]
1130    ) -> bool:
1131        if frame.company_id != AVRCP_BLUETOOTH_SIG_COMPANY_ID:
1132            logger.debug("unsupported company id, ignoring")
1133            return False
1134
1135        if frame.subunit_type != avc.Frame.SubunitType.PANEL or frame.subunit_id != 0:
1136            logger.debug("unsupported subunit")
1137            return False
1138
1139        return True
1140
1141    def __init__(self, delegate: Optional[Delegate] = None) -> None:
1142        super().__init__()
1143        self.delegate = delegate if delegate else Delegate()
1144        self.command_pdu_assembler = PduAssembler(self._on_command_pdu)
1145        self.receive_command_state = None
1146        self.response_pdu_assembler = PduAssembler(self._on_response_pdu)
1147        self.receive_response_state = None
1148        self.avctp_protocol = None
1149        self.notification_listeners = {}
1150
1151        # Create an initial pool of free commands
1152        self.pending_commands = {}
1153        self.free_commands = asyncio.Queue()
1154        for transaction_label in range(16):
1155            self.free_commands.put_nowait(self.PendingCommand(transaction_label))
1156
1157    def listen(self, device: Device) -> None:
1158        """
1159        Listen for incoming connections.
1160
1161        A 'connection' event will be emitted when a connection is made, and a 'start'
1162        event will be emitted when the protocol is ready to be used on that connection.
1163        """
1164        device.register_l2cap_server(avctp.AVCTP_PSM, self._on_avctp_connection)
1165
1166    async def connect(self, connection: Connection) -> None:
1167        """
1168        Connect to a peer.
1169        """
1170        avctp_channel = await connection.create_l2cap_channel(
1171            l2cap.ClassicChannelSpec(psm=avctp.AVCTP_PSM)
1172        )
1173        self._on_avctp_channel_open(avctp_channel)
1174
1175    async def _obtain_pending_command(self) -> PendingCommand:
1176        pending_command = await self.free_commands.get()
1177        self.pending_commands[pending_command.transaction_label] = pending_command
1178        return pending_command
1179
1180    def recycle_pending_command(self, pending_command: PendingCommand) -> None:
1181        pending_command.reset()
1182        del self.pending_commands[pending_command.transaction_label]
1183        self.free_commands.put_nowait(pending_command)
1184        logger.debug(f"recycled pending command, {self.free_commands.qsize()} free")
1185
1186    _R = TypeVar('_R')
1187
1188    @staticmethod
1189    def _check_response(
1190        response_context: ResponseContext, expected_type: Type[_R]
1191    ) -> _R:
1192        if isinstance(response_context, Protocol.FinalResponse):
1193            if (
1194                response_context.response_code
1195                != avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE
1196            ):
1197                raise Protocol.UnexpectedResponseCodeError(
1198                    response_context.response_code, response_context.response
1199                )
1200
1201            if not (isinstance(response_context.response, expected_type)):
1202                raise Protocol.MismatchedResponseError(response_context.response)
1203
1204            return response_context.response
1205
1206        raise Protocol.UnexpectedResponseTypeError(response_context)
1207
1208    def _delegate_command(
1209        self, transaction_label: int, command: Command, method: Awaitable
1210    ) -> None:
1211        async def call():
1212            try:
1213                await method
1214            except Delegate.Error as error:
1215                self.send_rejected_avrcp_response(
1216                    transaction_label,
1217                    command.pdu_id,
1218                    error.status_code,
1219                )
1220            except Exception:
1221                logger.exception("delegate method raised exception")
1222                self.send_rejected_avrcp_response(
1223                    transaction_label,
1224                    command.pdu_id,
1225                    Protocol.StatusCode.INTERNAL_ERROR,
1226                )
1227
1228        utils.AsyncRunner.spawn(call())
1229
1230    async def get_supported_events(self) -> List[EventId]:
1231        """Get the list of events supported by the connected peer."""
1232        response_context = await self.send_avrcp_command(
1233            avc.CommandFrame.CommandType.STATUS,
1234            GetCapabilitiesCommand(
1235                GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
1236            ),
1237        )
1238        response = self._check_response(response_context, GetCapabilitiesResponse)
1239        return cast(List[EventId], response.capabilities)
1240
1241    async def get_play_status(self) -> SongAndPlayStatus:
1242        """Get the play status of the connected peer."""
1243        response_context = await self.send_avrcp_command(
1244            avc.CommandFrame.CommandType.STATUS, GetPlayStatusCommand()
1245        )
1246        response = self._check_response(response_context, GetPlayStatusResponse)
1247        return SongAndPlayStatus(
1248            response.song_length, response.song_position, response.play_status
1249        )
1250
1251    async def get_element_attributes(
1252        self, element_identifier: int, attribute_ids: Sequence[MediaAttributeId]
1253    ) -> List[MediaAttribute]:
1254        """Get element attributes from the connected peer."""
1255        response_context = await self.send_avrcp_command(
1256            avc.CommandFrame.CommandType.STATUS,
1257            GetElementAttributesCommand(element_identifier, attribute_ids),
1258        )
1259        response = self._check_response(response_context, GetElementAttributesResponse)
1260        return response.attributes
1261
1262    async def monitor_events(
1263        self, event_id: EventId, playback_interval: int = 0
1264    ) -> AsyncIterator[Event]:
1265        """
1266        Monitor events emitted from a peer.
1267
1268        This generator yields Event objects.
1269        """
1270
1271        def check_response(response) -> Event:
1272            if not isinstance(response, RegisterNotificationResponse):
1273                raise self.MismatchedResponseError(response)
1274
1275            return response.event
1276
1277        while True:
1278            response = await self.send_avrcp_command(
1279                avc.CommandFrame.CommandType.NOTIFY,
1280                RegisterNotificationCommand(event_id, playback_interval),
1281            )
1282
1283            if isinstance(response, self.InterimResponse):
1284                logger.debug(f"interim: {response}")
1285                yield check_response(response.response)
1286
1287                logger.debug("waiting for final response")
1288                response = await response.final
1289
1290            if not isinstance(response, self.FinalResponse):
1291                raise self.UnexpectedResponseTypeError(response)
1292
1293            logger.debug(f"final: {response}")
1294            if response.response_code != avc.ResponseFrame.ResponseCode.CHANGED:
1295                raise self.UnexpectedResponseCodeError(
1296                    response.response_code, response.response
1297                )
1298
1299            yield check_response(response.response)
1300
1301    async def monitor_playback_status(
1302        self,
1303    ) -> AsyncIterator[PlayStatus]:
1304        """Monitor Playback Status changes from the connected peer."""
1305        async for event in self.monitor_events(EventId.PLAYBACK_STATUS_CHANGED, 0):
1306            if not isinstance(event, PlaybackStatusChangedEvent):
1307                logger.warning("unexpected event class")
1308                continue
1309            yield event.play_status
1310
1311    async def monitor_track_changed(
1312        self,
1313    ) -> AsyncIterator[bytes]:
1314        """Monitor Track changes from the connected peer."""
1315        async for event in self.monitor_events(EventId.TRACK_CHANGED, 0):
1316            if not isinstance(event, TrackChangedEvent):
1317                logger.warning("unexpected event class")
1318                continue
1319            yield event.identifier
1320
1321    async def monitor_playback_position(
1322        self, playback_interval: int
1323    ) -> AsyncIterator[int]:
1324        """Monitor Playback Position changes from the connected peer."""
1325        async for event in self.monitor_events(
1326            EventId.PLAYBACK_POS_CHANGED, playback_interval
1327        ):
1328            if not isinstance(event, PlaybackPositionChangedEvent):
1329                logger.warning("unexpected event class")
1330                continue
1331            yield event.playback_position
1332
1333    async def monitor_player_application_settings(
1334        self,
1335    ) -> AsyncIterator[List[PlayerApplicationSettingChangedEvent.Setting]]:
1336        """Monitor Player Application Setting changes from the connected peer."""
1337        async for event in self.monitor_events(
1338            EventId.PLAYER_APPLICATION_SETTING_CHANGED, 0
1339        ):
1340            if not isinstance(event, PlayerApplicationSettingChangedEvent):
1341                logger.warning("unexpected event class")
1342                continue
1343            yield event.player_application_settings
1344
1345    async def monitor_now_playing_content(self) -> AsyncIterator[None]:
1346        """Monitor Now Playing changes from the connected peer."""
1347        async for event in self.monitor_events(EventId.NOW_PLAYING_CONTENT_CHANGED, 0):
1348            if not isinstance(event, NowPlayingContentChangedEvent):
1349                logger.warning("unexpected event class")
1350                continue
1351            yield None
1352
1353    async def monitor_available_players(self) -> AsyncIterator[None]:
1354        """Monitor Available Players changes from the connected peer."""
1355        async for event in self.monitor_events(EventId.AVAILABLE_PLAYERS_CHANGED, 0):
1356            if not isinstance(event, AvailablePlayersChangedEvent):
1357                logger.warning("unexpected event class")
1358                continue
1359            yield None
1360
1361    async def monitor_addressed_player(
1362        self,
1363    ) -> AsyncIterator[AddressedPlayerChangedEvent.Player]:
1364        """Monitor Addressed Player changes from the connected peer."""
1365        async for event in self.monitor_events(EventId.ADDRESSED_PLAYER_CHANGED, 0):
1366            if not isinstance(event, AddressedPlayerChangedEvent):
1367                logger.warning("unexpected event class")
1368                continue
1369            yield event.player
1370
1371    async def monitor_uids(
1372        self,
1373    ) -> AsyncIterator[int]:
1374        """Monitor UID changes from the connected peer."""
1375        async for event in self.monitor_events(EventId.UIDS_CHANGED, 0):
1376            if not isinstance(event, UidsChangedEvent):
1377                logger.warning("unexpected event class")
1378                continue
1379            yield event.uid_counter
1380
1381    async def monitor_volume(
1382        self,
1383    ) -> AsyncIterator[int]:
1384        """Monitor Volume changes from the connected peer."""
1385        async for event in self.monitor_events(EventId.VOLUME_CHANGED, 0):
1386            if not isinstance(event, VolumeChangedEvent):
1387                logger.warning("unexpected event class")
1388                continue
1389            yield event.volume
1390
1391    def notify_event(self, event: Event):
1392        """Notify an event to the connected peer."""
1393        if (listener := self.notification_listeners.get(event.event_id)) is None:
1394            logger.debug(f"no listener for {event.event_id.name}")
1395            return
1396
1397        # Emit the notification.
1398        notification = RegisterNotificationResponse(event)
1399        self.send_avrcp_response(
1400            listener.transaction_label,
1401            avc.ResponseFrame.ResponseCode.CHANGED,
1402            notification,
1403        )
1404
1405        # Remove the listener (they will need to re-register).
1406        del self.notification_listeners[event.event_id]
1407
1408    def notify_playback_status_changed(self, status: PlayStatus) -> None:
1409        """Notify the connected peer of a Playback Status change."""
1410        self.notify_event(PlaybackStatusChangedEvent(status))
1411
1412    def notify_track_changed(self, identifier: bytes) -> None:
1413        """Notify the connected peer of a Track change."""
1414        if len(identifier) != 8:
1415            raise InvalidArgumentError("identifier must be 8 bytes")
1416        self.notify_event(TrackChangedEvent(identifier))
1417
1418    def notify_playback_position_changed(self, position: int) -> None:
1419        """Notify the connected peer of a Position change."""
1420        self.notify_event(PlaybackPositionChangedEvent(position))
1421
1422    def notify_player_application_settings_changed(
1423        self, settings: Sequence[PlayerApplicationSettingChangedEvent.Setting]
1424    ) -> None:
1425        """Notify the connected peer of an Player Application Setting change."""
1426        self.notify_event(
1427            PlayerApplicationSettingChangedEvent(settings),
1428        )
1429
1430    def notify_now_playing_content_changed(self) -> None:
1431        """Notify the connected peer of a Now Playing change."""
1432        self.notify_event(NowPlayingContentChangedEvent())
1433
1434    def notify_available_players_changed(self) -> None:
1435        """Notify the connected peer of an Available Players change."""
1436        self.notify_event(AvailablePlayersChangedEvent())
1437
1438    def notify_addressed_player_changed(
1439        self, player: AddressedPlayerChangedEvent.Player
1440    ) -> None:
1441        """Notify the connected peer of an Addressed Player change."""
1442        self.notify_event(AddressedPlayerChangedEvent(player))
1443
1444    def notify_uids_changed(self, uid_counter: int) -> None:
1445        """Notify the connected peer of a UID change."""
1446        self.notify_event(UidsChangedEvent(uid_counter))
1447
1448    def notify_volume_changed(self, volume: int) -> None:
1449        """Notify the connected peer of a Volume change."""
1450        self.notify_event(VolumeChangedEvent(volume))
1451
1452    def _register_notification_listener(
1453        self, transaction_label: int, command: RegisterNotificationCommand
1454    ) -> None:
1455        listener = self.NotificationListener(transaction_label, command)
1456        self.notification_listeners[command.event_id] = listener
1457
1458    def _on_avctp_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
1459        logger.debug("AVCTP connection established")
1460        l2cap_channel.on("open", lambda: self._on_avctp_channel_open(l2cap_channel))
1461
1462        self.emit("connection")
1463
1464    def _on_avctp_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
1465        logger.debug("AVCTP channel open")
1466        if self.avctp_protocol is not None:
1467            # TODO: find a better strategy instead of just closing
1468            logger.warning("AVCTP protocol already active, closing connection")
1469            AsyncRunner.spawn(l2cap_channel.disconnect())
1470            return
1471
1472        self.avctp_protocol = avctp.Protocol(l2cap_channel)
1473        self.avctp_protocol.register_command_handler(AVRCP_PID, self._on_avctp_command)
1474        self.avctp_protocol.register_response_handler(
1475            AVRCP_PID, self._on_avctp_response
1476        )
1477        l2cap_channel.on("close", self._on_avctp_channel_close)
1478
1479        self.emit("start")
1480
1481    def _on_avctp_channel_close(self) -> None:
1482        logger.debug("AVCTP channel closed")
1483        self.avctp_protocol = None
1484
1485        self.emit("stop")
1486
    def _on_avctp_command(
        self, transaction_label: int, command: avc.CommandFrame
    ) -> None:
        """
        Handle an AV/C command frame received over AVCTP for the AVRCP PID.

        - Frames not addressed to PANEL subunit 0 get a NOT IMPLEMENTED response.
        - Vendor-dependent frames are fed to the command PDU assembler, which
          calls _on_command_pdu once a PDU is complete.
        - Pass-through frames are always answered with ACCEPTED.
        - Any other frame type gets a NOT IMPLEMENTED response.

        Args:
            transaction_label: AVCTP transaction label of the command.
            command: the decoded AV/C command frame.
        """
        logger.debug(
            f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}"
        )

        # Only the PANEL subunit type with subunit ID 0 is supported in this profile.
        if (
            command.subunit_type != avc.Frame.SubunitType.PANEL
            or command.subunit_id != 0
        ):
            logger.debug("subunit not supported")
            self.send_not_implemented_response(transaction_label, command)
            return

        if isinstance(command, avc.VendorDependentCommandFrame):
            # Frames that fail the company ID / subunit check are dropped
            # without any response.
            # NOTE(review): confirm whether a NOT IMPLEMENTED response is expected.
            if not self._check_vendor_dependent_frame(command):
                return

            if self.receive_command_state is None:
                # First fragment (or single packet): remember whose PDU this is.
                self.receive_command_state = self.ReceiveCommandState(
                    transaction_label=transaction_label, command_type=command.ctype
                )
            elif (
                self.receive_command_state.transaction_label != transaction_label
                or self.receive_command_state.command_type != command.ctype
            ):
                # We're in the middle of some other PDU
                # Both the partial PDU and this frame are discarded.
                logger.warning("received interleaved PDU, resetting state")
                self.command_pdu_assembler.reset()
                self.receive_command_state = None
                return
            else:
                # Same label and type as the PDU in progress.
                self.receive_command_state.command_type = command.ctype
                self.receive_command_state.transaction_label = transaction_label

            self.command_pdu_assembler.on_pdu(command.vendor_dependent_data)
            return

        if isinstance(command, avc.PassThroughCommandFrame):
            # TODO: delegate
            # For now every pass-through operation is accepted, echoing back
            # its state flag, operation ID and operation data.
            response = avc.PassThroughResponseFrame(
                avc.ResponseFrame.ResponseCode.ACCEPTED,
                avc.Frame.SubunitType.PANEL,
                0,
                command.state_flag,
                command.operation_id,
                command.operation_data,
            )
            self.send_response(transaction_label, response)
            return

        # TODO handle other types
        self.send_not_implemented_response(transaction_label, command)
1542
1543    def _on_avctp_response(
1544        self, transaction_label: int, response: Optional[avc.ResponseFrame]
1545    ) -> None:
1546        logger.debug(
1547            f"<<< AVCTP Response, transaction_label={transaction_label}: {response}"
1548        )
1549
1550        # Check that we have a pending command that matches this response.
1551        if not (pending_command := self.pending_commands.get(transaction_label)):
1552            logger.warning("no pending command with this transaction label")
1553            return
1554
1555        # A None response means an invalid PID was used in the request.
1556        if response is None:
1557            pending_command.response.set_exception(self.InvalidPidError())
1558
1559        if isinstance(response, avc.VendorDependentResponseFrame):
1560            if not self._check_vendor_dependent_frame(response):
1561                return
1562
1563            if self.receive_response_state is None:
1564                self.receive_response_state = self.ReceiveResponseState(
1565                    transaction_label=transaction_label, response_code=response.response
1566                )
1567            elif (
1568                self.receive_response_state.transaction_label != transaction_label
1569                or self.receive_response_state.response_code != response.response
1570            ):
1571                # We're in the middle of some other PDU
1572                logger.warning("received interleaved PDU, resetting state")
1573                self.response_pdu_assembler.reset()
1574                self.receive_response_state = None
1575                return
1576            else:
1577                self.receive_response_state.response_code = response.response
1578                self.receive_response_state.transaction_label = transaction_label
1579
1580            self.response_pdu_assembler.on_pdu(response.vendor_dependent_data)
1581            return
1582
1583        if isinstance(response, avc.PassThroughResponseFrame):
1584            pending_command.response.set_result(response)
1585
1586        # TODO handle other types
1587
1588        self.recycle_pending_command(pending_command)
1589
1590    def _on_command_pdu(self, pdu_id: PduId, pdu: bytes) -> None:
1591        logger.debug(f"<<< AVRCP command PDU [pdu_id={pdu_id.name}]: {pdu.hex()}")
1592
1593        assert self.receive_command_state is not None
1594        transaction_label = self.receive_command_state.transaction_label
1595
1596        # Dispatch the command.
1597        # NOTE: with a small number of supported commands, a manual dispatch like this
1598        # is Ok, but if/when more commands are supported, a lookup dispatch mechanism
1599        # would be more appropriate.
1600        # TODO: switch on ctype
1601        if self.receive_command_state.command_type in (
1602            avc.CommandFrame.CommandType.CONTROL,
1603            avc.CommandFrame.CommandType.STATUS,
1604            avc.CommandFrame.CommandType.NOTIFY,
1605        ):
1606            # TODO: catch exceptions from delegates
1607            if pdu_id == self.PduId.GET_CAPABILITIES:
1608                self._on_get_capabilities_command(
1609                    transaction_label, GetCapabilitiesCommand.from_bytes(pdu)
1610                )
1611            elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME:
1612                self._on_set_absolute_volume_command(
1613                    transaction_label, SetAbsoluteVolumeCommand.from_bytes(pdu)
1614                )
1615            elif pdu_id == self.PduId.REGISTER_NOTIFICATION:
1616                self._on_register_notification_command(
1617                    transaction_label, RegisterNotificationCommand.from_bytes(pdu)
1618                )
1619            else:
1620                # Not supported.
1621                # TODO: check that this is the right way to respond in this case.
1622                logger.debug("unsupported PDU ID")
1623                self.send_rejected_avrcp_response(
1624                    transaction_label, pdu_id, self.StatusCode.INVALID_PARAMETER
1625                )
1626        else:
1627            logger.debug("unsupported command type")
1628            self.send_rejected_avrcp_response(
1629                transaction_label, pdu_id, self.StatusCode.INVALID_COMMAND
1630            )
1631
1632        self.receive_command_state = None
1633
1634    def _on_response_pdu(self, pdu_id: PduId, pdu: bytes) -> None:
1635        logger.debug(f"<<< AVRCP response PDU [pdu_id={pdu_id.name}]: {pdu.hex()}")
1636
1637        assert self.receive_response_state is not None
1638
1639        transaction_label = self.receive_response_state.transaction_label
1640        response_code = self.receive_response_state.response_code
1641        self.receive_response_state = None
1642
1643        # Check that we have a pending command that matches this response.
1644        if not (pending_command := self.pending_commands.get(transaction_label)):
1645            logger.warning("no pending command with this transaction label")
1646            return
1647
1648        # Convert the PDU bytes into a response object.
1649        # NOTE: with a small number of supported responses, a manual switch like this
1650        # is Ok, but if/when more responses are supported, a lookup mechanism would be
1651        # more appropriate.
1652        response: Optional[Response] = None
1653        if response_code == avc.ResponseFrame.ResponseCode.REJECTED:
1654            response = RejectedResponse.from_bytes(pdu_id, pdu)
1655        elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
1656            response = NotImplementedResponse.from_bytes(pdu_id, pdu)
1657        elif response_code in (
1658            avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
1659            avc.ResponseFrame.ResponseCode.INTERIM,
1660            avc.ResponseFrame.ResponseCode.CHANGED,
1661            avc.ResponseFrame.ResponseCode.ACCEPTED,
1662        ):
1663            if pdu_id == self.PduId.GET_CAPABILITIES:
1664                response = GetCapabilitiesResponse.from_bytes(pdu)
1665            elif pdu_id == self.PduId.GET_PLAY_STATUS:
1666                response = GetPlayStatusResponse.from_bytes(pdu)
1667            elif pdu_id == self.PduId.GET_ELEMENT_ATTRIBUTES:
1668                response = GetElementAttributesResponse.from_bytes(pdu)
1669            elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME:
1670                response = SetAbsoluteVolumeResponse.from_bytes(pdu)
1671            elif pdu_id == self.PduId.REGISTER_NOTIFICATION:
1672                response = RegisterNotificationResponse.from_bytes(pdu)
1673            else:
1674                logger.debug("unexpected PDU ID")
1675                pending_command.response.set_exception(
1676                    ProtocolError(
1677                        error_code=None,
1678                        error_namespace="avrcp",
1679                        details="unexpected PDU ID",
1680                    )
1681                )
1682        else:
1683            logger.debug("unexpected response code")
1684            pending_command.response.set_exception(
1685                ProtocolError(
1686                    error_code=None,
1687                    error_namespace="avrcp",
1688                    details="unexpected response code",
1689                )
1690            )
1691
1692        if response is None:
1693            self.recycle_pending_command(pending_command)
1694            return
1695
1696        logger.debug(f"<<< AVRCP response: {response}")
1697
1698        # Make the response available to the waiter.
1699        if response_code == avc.ResponseFrame.ResponseCode.INTERIM:
1700            pending_interim_response = pending_command.response
1701            pending_command.reset()
1702            pending_interim_response.set_result(
1703                self.InterimResponse(
1704                    pending_command.transaction_label,
1705                    response,
1706                    pending_command.response,
1707                )
1708            )
1709        else:
1710            pending_command.response.set_result(
1711                self.FinalResponse(
1712                    pending_command.transaction_label,
1713                    response,
1714                    response_code,
1715                )
1716            )
1717            self.recycle_pending_command(pending_command)
1718
1719    def send_command(self, transaction_label: int, command: avc.CommandFrame) -> None:
1720        logger.debug(f">>> AVRCP command: {command}")
1721
1722        if self.avctp_protocol is None:
1723            logger.warning("trying to send command while avctp_protocol is None")
1724            return
1725
1726        self.avctp_protocol.send_command(transaction_label, AVRCP_PID, bytes(command))
1727
1728    async def send_passthrough_command(
1729        self, command: avc.PassThroughCommandFrame
1730    ) -> avc.PassThroughResponseFrame:
1731        # Wait for a free command slot.
1732        pending_command = await self._obtain_pending_command()
1733
1734        # Send the command.
1735        self.send_command(pending_command.transaction_label, command)
1736
1737        # Wait for the response.
1738        return await pending_command.response
1739
1740    async def send_key_event(
1741        self, key: avc.PassThroughCommandFrame.OperationId, pressed: bool
1742    ) -> avc.PassThroughResponseFrame:
1743        """Send a key event to the connected peer."""
1744        return await self.send_passthrough_command(
1745            avc.PassThroughCommandFrame(
1746                avc.CommandFrame.CommandType.CONTROL,
1747                avc.Frame.SubunitType.PANEL,
1748                0,
1749                (
1750                    avc.PassThroughFrame.StateFlag.PRESSED
1751                    if pressed
1752                    else avc.PassThroughFrame.StateFlag.RELEASED
1753                ),
1754                key,
1755                b'',
1756            )
1757        )
1758
1759    async def send_avrcp_command(
1760        self, command_type: avc.CommandFrame.CommandType, command: Command
1761    ) -> ResponseContext:
1762        # Wait for a free command slot.
1763        pending_command = await self._obtain_pending_command()
1764
1765        # TODO: fragmentation
1766        # Send the command.
1767        logger.debug(f">>> AVRCP command PDU: {command}")
1768        pdu = (
1769            struct.pack(">BBH", command.pdu_id, 0, len(command.parameter))
1770            + command.parameter
1771        )
1772        command_frame = avc.VendorDependentCommandFrame(
1773            command_type,
1774            avc.Frame.SubunitType.PANEL,
1775            0,
1776            AVRCP_BLUETOOTH_SIG_COMPANY_ID,
1777            pdu,
1778        )
1779        self.send_command(pending_command.transaction_label, command_frame)
1780
1781        # Wait for the response.
1782        return await pending_command.response
1783
1784    def send_response(
1785        self, transaction_label: int, response: avc.ResponseFrame
1786    ) -> None:
1787        assert self.avctp_protocol is not None
1788        logger.debug(f">>> AVRCP response: {response}")
1789        self.avctp_protocol.send_response(transaction_label, AVRCP_PID, bytes(response))
1790
1791    def send_passthrough_response(
1792        self,
1793        transaction_label: int,
1794        command: avc.PassThroughCommandFrame,
1795        response_code: avc.ResponseFrame.ResponseCode,
1796    ):
1797        response = avc.PassThroughResponseFrame(
1798            response_code,
1799            avc.Frame.SubunitType.PANEL,
1800            0,
1801            command.state_flag,
1802            command.operation_id,
1803            command.operation_data,
1804        )
1805        self.send_response(transaction_label, response)
1806
1807    def send_avrcp_response(
1808        self,
1809        transaction_label: int,
1810        response_code: avc.ResponseFrame.ResponseCode,
1811        response: Response,
1812    ) -> None:
1813        # TODO: fragmentation
1814        logger.debug(f">>> AVRCP response PDU: {response}")
1815        pdu = (
1816            struct.pack(">BBH", response.pdu_id, 0, len(response.parameter))
1817            + response.parameter
1818        )
1819        response_frame = avc.VendorDependentResponseFrame(
1820            response_code,
1821            avc.Frame.SubunitType.PANEL,
1822            0,
1823            AVRCP_BLUETOOTH_SIG_COMPANY_ID,
1824            pdu,
1825        )
1826        self.send_response(transaction_label, response_frame)
1827
1828    def send_not_implemented_response(
1829        self, transaction_label: int, command: avc.CommandFrame
1830    ) -> None:
1831        response = avc.ResponseFrame(
1832            avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED,
1833            command.subunit_type,
1834            command.subunit_id,
1835            command.opcode,
1836            command.operands,
1837        )
1838        self.send_response(transaction_label, response)
1839
1840    def send_rejected_avrcp_response(
1841        self, transaction_label: int, pdu_id: Protocol.PduId, status_code: StatusCode
1842    ) -> None:
1843        self.send_avrcp_response(
1844            transaction_label,
1845            avc.ResponseFrame.ResponseCode.REJECTED,
1846            RejectedResponse(pdu_id, status_code),
1847        )
1848
1849    def _on_get_capabilities_command(
1850        self, transaction_label: int, command: GetCapabilitiesCommand
1851    ) -> None:
1852        logger.debug(f"<<< AVRCP command PDU: {command}")
1853
1854        async def get_supported_events():
1855            if (
1856                command.capability_id
1857                != GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
1858            ):
1859                raise Protocol.InvalidParameterError
1860
1861            supported_events = await self.delegate.get_supported_events()
1862            self.send_avrcp_response(
1863                transaction_label,
1864                avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
1865                GetCapabilitiesResponse(command.capability_id, supported_events),
1866            )
1867
1868        self._delegate_command(transaction_label, command, get_supported_events())
1869
1870    def _on_set_absolute_volume_command(
1871        self, transaction_label: int, command: SetAbsoluteVolumeCommand
1872    ) -> None:
1873        logger.debug(f"<<< AVRCP command PDU: {command}")
1874
1875        async def set_absolute_volume():
1876            await self.delegate.set_absolute_volume(command.volume)
1877            effective_volume = await self.delegate.get_absolute_volume()
1878            self.send_avrcp_response(
1879                transaction_label,
1880                avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
1881                SetAbsoluteVolumeResponse(effective_volume),
1882            )
1883
1884        self._delegate_command(transaction_label, command, set_absolute_volume())
1885
1886    def _on_register_notification_command(
1887        self, transaction_label: int, command: RegisterNotificationCommand
1888    ) -> None:
1889        logger.debug(f"<<< AVRCP command PDU: {command}")
1890
1891        async def register_notification():
1892            # Check if the event is supported.
1893            supported_events = await self.delegate.get_supported_events()
1894            if command.event_id in supported_events:
1895                if command.event_id == EventId.VOLUME_CHANGED:
1896                    volume = await self.delegate.get_absolute_volume()
1897                    response = RegisterNotificationResponse(VolumeChangedEvent(volume))
1898                    self.send_avrcp_response(
1899                        transaction_label,
1900                        avc.ResponseFrame.ResponseCode.INTERIM,
1901                        response,
1902                    )
1903                    self._register_notification_listener(transaction_label, command)
1904                    return
1905
1906                if command.event_id == EventId.PLAYBACK_STATUS_CHANGED:
1907                    # TODO: testing only, use delegate
1908                    response = RegisterNotificationResponse(
1909                        PlaybackStatusChangedEvent(play_status=PlayStatus.PLAYING)
1910                    )
1911                    self.send_avrcp_response(
1912                        transaction_label,
1913                        avc.ResponseFrame.ResponseCode.INTERIM,
1914                        response,
1915                    )
1916                    self._register_notification_listener(transaction_label, command)
1917                    return
1918
1919        self._delegate_command(transaction_label, command, register_notification())
1920