1# Copyright 2021-2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18from __future__ import annotations
19import enum
20import struct
21from typing import Dict, Type, Union, Tuple
22
23from bumble import core
24from bumble.utils import OpenIntEnum
25
26
27# -----------------------------------------------------------------------------
class Frame:
    """Base class for AV/C frames.

    A frame carries a subunit type, a subunit ID, an opcode and the raw
    opcode-specific operands. Concrete ``CommandFrame``/``ResponseFrame``
    subclasses register themselves with :meth:`Frame.subclass` so that
    :meth:`Frame.from_bytes` can return a specialized instance.
    """

    class SubunitType(enum.IntEnum):
        # AV/C Digital Interface Command Set General Specification Version 4.1
        # Table 7.4
        MONITOR = 0x00
        AUDIO = 0x01
        PRINTER = 0x02
        DISC = 0x03
        TAPE_RECORDER_OR_PLAYER = 0x04
        TUNER = 0x05
        CA = 0x06
        CAMERA = 0x07
        PANEL = 0x09
        BULLETIN_BOARD = 0x0A
        VENDOR_UNIQUE = 0x1C
        EXTENDED = 0x1E
        UNIT = 0x1F

    class OperationCode(OpenIntEnum):
        # 0x00 - 0x0F: Unit and subunit commands
        VENDOR_DEPENDENT = 0x00
        RESERVE = 0x01
        PLUG_INFO = 0x02

        # 0x10 - 0x3F: Unit commands
        DIGITAL_OUTPUT = 0x10
        DIGITAL_INPUT = 0x11
        CHANNEL_USAGE = 0x12
        OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
        INPUT_PLUG_SIGNAL_FORMAT = 0x19
        GENERAL_BUS_SETUP = 0x1F
        CONNECT_AV = 0x20
        DISCONNECT_AV = 0x21
        CONNECTIONS = 0x22
        CONNECT = 0x24
        DISCONNECT = 0x25
        UNIT_INFO = 0x30
        SUBUNIT_INFO = 0x31

        # 0x40 - 0x7F: Subunit commands
        PASS_THROUGH = 0x7C
        GUI_UPDATE = 0x7D
        PUSH_GUI_DATA = 0x7E
        USER_ACTION = 0x7F

        # 0xA0 - 0xBF: Unit and subunit commands
        VERSION = 0xB0
        POWER = 0xB2

    subunit_type: SubunitType
    subunit_id: int
    opcode: OperationCode
    operands: bytes

    @staticmethod
    def subclass(subclass):
        """Class decorator that registers a specialized command/response class.

        The opcode is inferred from the class name. The ``CommandFrame`` or
        ``ResponseFrame`` part is removed, and the remaining CamelCase name is
        converted to an UPPER_SNAKE_CASE ``OperationCode`` member name. For
        example, ``PassThroughCommandFrame`` becomes ``PASS_THROUGH``.

        Raises:
            core.InvalidArgumentError: if the class name ends with neither
                ``CommandFrame`` nor ``ResponseFrame``.
            KeyError: if the derived name is not an ``OperationCode`` member.
        """
        # Infer the opcode from the class name
        if subclass.__name__.endswith("CommandFrame"):
            short_name = subclass.__name__.replace("CommandFrame", "")
            category_class = CommandFrame
        elif subclass.__name__.endswith("ResponseFrame"):
            short_name = subclass.__name__.replace("ResponseFrame", "")
            category_class = ResponseFrame
        else:
            raise core.InvalidArgumentError(
                f"invalid subclass name {subclass.__name__}"
            )

        # Split the CamelCase name into words at each uppercase letter.
        uppercase_indexes = [
            i for i in range(len(short_name)) if short_name[i].isupper()
        ]
        uppercase_indexes.append(len(short_name))
        words = [
            short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
            for i in range(len(uppercase_indexes) - 1)
        ]
        opcode_name = "_".join(words)
        opcode = Frame.OperationCode[opcode_name]
        category_class.subclasses[opcode] = subclass
        return subclass

    @staticmethod
    def from_bytes(data: bytes) -> Frame:
        """Parse a serialized AV/C frame.

        Returns an instance of the registered ``CommandFrame``/``ResponseFrame``
        subclass for the opcode when there is one. Otherwise returns a generic
        ``CommandFrame`` or ``ResponseFrame``.

        Raises:
            core.InvalidPacketError: if the upper 4 bits of the first byte are
                not zero, or the subunit ID uses a reserved value.
            NotImplementedError: for the EXTENDED subunit type.
        """
        if data[0] >> 4 != 0:
            raise core.InvalidPacketError("first 4 bits must be 0s")

        # The low nibble is a command type (< 8) or a response code (>= 8).
        ctype_or_response = data[0] & 0xF
        subunit_type = Frame.SubunitType(data[1] >> 3)
        subunit_id = data[1] & 7

        if subunit_type == Frame.SubunitType.EXTENDED:
            # Not supported
            raise NotImplementedError("extended subunit types not supported")

        if subunit_id < 5 or subunit_id == 7:
            # IDs 0-4 are plain instance numbers. 7 is not an extension marker
            # either (AV/C uses it as the "ignore" value, e.g. the subunit byte
            # 0xFF addresses the unit), so the opcode immediately follows.
            opcode_offset = 2
        elif subunit_id == 5:
            # Extended to the next byte
            extension = data[2]
            if extension == 0:
                raise core.InvalidPacketError("extended subunit ID value reserved")
            if extension == 0xFF:
                subunit_id = 5 + 254 + data[3]
                opcode_offset = 4
            else:
                subunit_id = 5 + extension
                opcode_offset = 3
        else:
            # subunit_id == 6
            raise core.InvalidPacketError("reserved subunit ID")

        opcode = Frame.OperationCode(data[opcode_offset])
        operands = data[opcode_offset + 1 :]

        # Look for a registered subclass
        if ctype_or_response < 8:
            # Command
            ctype = CommandFrame.CommandType(ctype_or_response)
            if c_subclass := CommandFrame.subclasses.get(opcode):
                return c_subclass(
                    ctype,
                    subunit_type,
                    subunit_id,
                    *c_subclass.parse_operands(operands),
                )
            return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
        else:
            # Response
            response = ResponseFrame.ResponseCode(ctype_or_response)
            if r_subclass := ResponseFrame.subclasses.get(opcode):
                return r_subclass(
                    response,
                    subunit_type,
                    subunit_id,
                    *r_subclass.parse_operands(operands),
                )
            return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)

    def to_bytes(
        self,
        ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
    ) -> bytes:
        """Serialize the frame, with ``ctype_or_response`` as the first byte.

        Only the compact encoding is produced. The subunit type and ID are
        packed into a single byte, so extended subunit types and IDs are not
        encoded.
        """
        # TODO: support extended subunit types and ids.
        return (
            bytes(
                [
                    ctype_or_response,
                    self.subunit_type << 3 | self.subunit_id,
                    self.opcode,
                ]
            )
            + self.operands
        )

    def to_string(self, extra: str) -> str:
        """Return a human-readable description of the frame.

        ``extra`` is inserted verbatim right after the opening parenthesis.
        Subclasses use it for their own leading field, e.g. ``"ctype=..., "``.
        """
        return (
            f"{self.__class__.__name__}({extra}"
            f"subunit_type={self.subunit_type.name}, "
            f"subunit_id=0x{self.subunit_id:02X}, "
            f"opcode={self.opcode.name}, "
            f"operands={self.operands.hex()})"
        )

    def __init__(
        self,
        subunit_type: SubunitType,
        subunit_id: int,
        opcode: OperationCode,
        operands: bytes,
    ) -> None:
        self.subunit_type = subunit_type
        self.subunit_id = subunit_id
        self.opcode = opcode
        self.operands = operands
202
203
204# -----------------------------------------------------------------------------
class CommandFrame(Frame):
    """An AV/C command: a ``Frame`` tagged with a command type (ctype)."""

    class CommandType(OpenIntEnum):
        # AV/C Digital Interface Command Set General Specification Version 4.1
        # Table 7.1
        CONTROL = 0x00
        STATUS = 0x01
        SPECIFIC_INQUIRY = 0x02
        NOTIFY = 0x03
        GENERAL_INQUIRY = 0x04

    # Opcode -> specialized command class; populated by ``Frame.subclass``.
    subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
    ctype: CommandType

    def __init__(
        self,
        ctype: CommandType,
        subunit_type: Frame.SubunitType,
        subunit_id: int,
        opcode: Frame.OperationCode,
        operands: bytes,
    ) -> None:
        super().__init__(subunit_type, subunit_id, opcode, operands)
        self.ctype = ctype

    @staticmethod
    def parse_operands(operands: bytes) -> Tuple:
        """Decode operands into constructor arguments; overridden by subclasses."""
        raise NotImplementedError

    def __str__(self):
        leading_field = "ctype=" + self.ctype.name + ", "
        return self.to_string(leading_field)

    def __bytes__(self):
        first_byte = self.ctype
        return self.to_bytes(first_byte)
238
239
240# -----------------------------------------------------------------------------
class ResponseFrame(Frame):
    """An AV/C response: a ``Frame`` tagged with a response code."""

    class ResponseCode(OpenIntEnum):
        # AV/C Digital Interface Command Set General Specification Version 4.1
        # Table 7.2
        NOT_IMPLEMENTED = 0x08
        ACCEPTED = 0x09
        REJECTED = 0x0A
        IN_TRANSITION = 0x0B
        IMPLEMENTED_OR_STABLE = 0x0C
        CHANGED = 0x0D
        INTERIM = 0x0F

    # Opcode -> specialized response class; populated by ``Frame.subclass``.
    subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
    response: ResponseCode

    def __init__(
        self,
        response: ResponseCode,
        subunit_type: Frame.SubunitType,
        subunit_id: int,
        opcode: Frame.OperationCode,
        operands: bytes,
    ) -> None:
        super().__init__(subunit_type, subunit_id, opcode, operands)
        self.response = response

    @staticmethod
    def parse_operands(operands: bytes) -> Tuple:
        """Decode operands into constructor arguments; overridden by subclasses."""
        raise NotImplementedError

    def __str__(self):
        leading_field = "response=" + self.response.name + ", "
        return self.to_string(leading_field)

    def __bytes__(self):
        first_byte = self.response
        return self.to_bytes(first_byte)
276
277
278# -----------------------------------------------------------------------------
class VendorDependentFrame:
    """Mixin holding the fields of a VENDOR DEPENDENT frame.

    The operands are a 3-byte big-endian company ID followed by opaque
    vendor-dependent bytes.
    """

    company_id: int
    vendor_dependent_data: bytes

    def __init__(self, company_id: int, vendor_dependent_data: bytes):
        self.company_id = company_id
        self.vendor_dependent_data = vendor_dependent_data

    @staticmethod
    def parse_operands(operands: bytes) -> Tuple:
        # Left-pad the 3-byte company ID with one zero byte so it unpacks as
        # a big-endian 32-bit integer.
        padded_id = bytes(1) + operands[:3]
        (company_id,) = struct.unpack(">I", padded_id)
        payload = operands[3:]
        return (company_id, payload)

    def make_operands(self) -> bytes:
        # Keep only the low 3 bytes of the big-endian 32-bit encoding.
        id_bytes = struct.pack(">I", self.company_id)[-3:]
        return id_bytes + self.vendor_dependent_data
296
297
298# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
    """VENDOR DEPENDENT command frame (opcode inferred from the class name)."""

    def __init__(
        self,
        ctype: CommandFrame.CommandType,
        subunit_type: Frame.SubunitType,
        subunit_id: int,
        company_id: int,
        vendor_dependent_data: bytes,
    ) -> None:
        # The vendor fields must be set first: make_operands() reads them.
        VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
        encoded_operands = self.make_operands()
        CommandFrame.__init__(
            self,
            ctype,
            subunit_type,
            subunit_id,
            Frame.OperationCode.VENDOR_DEPENDENT,
            encoded_operands,
        )

    def __str__(self):
        pieces = (
            f"VendorDependentCommandFrame(ctype={self.ctype.name}, ",
            f"subunit_type={self.subunit_type.name}, ",
            f"subunit_id=0x{self.subunit_id:02X}, ",
            f"company_id=0x{self.company_id:06X}, ",
            f"vendor_dependent_data={self.vendor_dependent_data.hex()})",
        )
        return "".join(pieces)
327
328
329# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
    """VENDOR DEPENDENT response frame (opcode inferred from the class name)."""

    def __init__(
        self,
        response: ResponseFrame.ResponseCode,
        subunit_type: Frame.SubunitType,
        subunit_id: int,
        company_id: int,
        vendor_dependent_data: bytes,
    ) -> None:
        # The vendor fields must be set first: make_operands() reads them.
        VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
        encoded_operands = self.make_operands()
        ResponseFrame.__init__(
            self,
            response,
            subunit_type,
            subunit_id,
            Frame.OperationCode.VENDOR_DEPENDENT,
            encoded_operands,
        )

    def __str__(self):
        pieces = (
            f"VendorDependentResponseFrame(response={self.response.name}, ",
            f"subunit_type={self.subunit_type.name}, ",
            f"subunit_id=0x{self.subunit_id:02X}, ",
            f"company_id=0x{self.company_id:06X}, ",
            f"vendor_dependent_data={self.vendor_dependent_data.hex()})",
        )
        return "".join(pieces)
358
359
360# -----------------------------------------------------------------------------
class PassThroughFrame:
    """
    See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command

    Operand layout (as produced by ``make_operands``):
      byte 0:  state_flag (bit 7) | operation_id (bits 0-6)
      byte 1:  operation_data_field_length
      byte 2+: operation_data
    """

    class StateFlag(enum.IntEnum):
        PRESSED = 0
        RELEASED = 1

    class OperationId(OpenIntEnum):
        SELECT = 0x00
        UP = 0x01
        DOWN = 0x02
        LEFT = 0x03
        RIGHT = 0x04
        RIGHT_UP = 0x05
        RIGHT_DOWN = 0x06
        LEFT_UP = 0x07
        LEFT_DOWN = 0x08
        ROOT_MENU = 0x09
        SETUP_MENU = 0x0A
        CONTENTS_MENU = 0x0B
        FAVORITE_MENU = 0x0C
        EXIT = 0x0D
        NUMBER_0 = 0x20
        NUMBER_1 = 0x21
        NUMBER_2 = 0x22
        NUMBER_3 = 0x23
        NUMBER_4 = 0x24
        NUMBER_5 = 0x25
        NUMBER_6 = 0x26
        NUMBER_7 = 0x27
        NUMBER_8 = 0x28
        NUMBER_9 = 0x29
        DOT = 0x2A
        ENTER = 0x2B
        CLEAR = 0x2C
        CHANNEL_UP = 0x30
        CHANNEL_DOWN = 0x31
        PREVIOUS_CHANNEL = 0x32
        SOUND_SELECT = 0x33
        INPUT_SELECT = 0x34
        DISPLAY_INFORMATION = 0x35
        HELP = 0x36
        PAGE_UP = 0x37
        PAGE_DOWN = 0x38
        POWER = 0x40
        VOLUME_UP = 0x41
        VOLUME_DOWN = 0x42
        MUTE = 0x43
        PLAY = 0x44
        STOP = 0x45
        PAUSE = 0x46
        RECORD = 0x47
        REWIND = 0x48
        FAST_FORWARD = 0x49
        EJECT = 0x4A
        FORWARD = 0x4B
        BACKWARD = 0x4C
        ANGLE = 0x50
        SUBPICTURE = 0x51
        F1 = 0x71
        F2 = 0x72
        F3 = 0x73
        F4 = 0x74
        F5 = 0x75
        VENDOR_UNIQUE = 0x7E

    state_flag: StateFlag
    operation_id: OperationId
    operation_data: bytes

    @staticmethod
    def parse_operands(operands: bytes) -> Tuple:
        """Decode PASS THROUGH operands.

        Returns:
            A ``(state_flag, operation_id, operation_data)`` tuple, in the
            order the concrete frame constructors expect.
        """
        operation_data_length = operands[1]
        return (
            PassThroughFrame.StateFlag(operands[0] >> 7),
            PassThroughFrame.OperationId(operands[0] & 0x7F),
            # The data starts after the 2-byte header (id/flag + length), the
            # same offset at which make_operands() writes it.
            operands[2 : 2 + operation_data_length],
        )

    def make_operands(self) -> bytes:
        """Encode the state flag, operation ID, data length and data."""
        return (
            bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
            + self.operation_data
        )

    def __init__(
        self,
        state_flag: StateFlag,
        operation_id: OperationId,
        operation_data: bytes,
    ) -> None:
        # The data length is carried in a single byte.
        if len(operation_data) > 255:
            raise core.InvalidArgumentError("operation data must be <= 255 bytes")
        self.state_flag = state_flag
        self.operation_id = operation_id
        self.operation_data = operation_data
458
459
460# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
    """PASS THROUGH command frame (opcode inferred from the class name)."""

    def __init__(
        self,
        ctype: CommandFrame.CommandType,
        subunit_type: Frame.SubunitType,
        subunit_id: int,
        state_flag: PassThroughFrame.StateFlag,
        operation_id: PassThroughFrame.OperationId,
        operation_data: bytes,
    ) -> None:
        # Pass-through fields come first because make_operands() reads them.
        PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
        encoded_operands = self.make_operands()
        CommandFrame.__init__(
            self,
            ctype,
            subunit_type,
            subunit_id,
            Frame.OperationCode.PASS_THROUGH,
            encoded_operands,
        )

    def __str__(self):
        pieces = (
            f"PassThroughCommandFrame(ctype={self.ctype.name}, ",
            f"subunit_type={self.subunit_type.name}, ",
            f"subunit_id=0x{self.subunit_id:02X}, ",
            f"state_flag={self.state_flag.name}, ",
            f"operation_id={self.operation_id.name}, ",
            f"operation_data={self.operation_data.hex()})",
        )
        return "".join(pieces)
491
492
493# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
    """PASS THROUGH response frame (opcode inferred from the class name)."""

    def __init__(
        self,
        response: ResponseFrame.ResponseCode,
        subunit_type: Frame.SubunitType,
        subunit_id: int,
        state_flag: PassThroughFrame.StateFlag,
        operation_id: PassThroughFrame.OperationId,
        operation_data: bytes,
    ) -> None:
        # Pass-through fields come first because make_operands() reads them.
        PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
        encoded_operands = self.make_operands()
        ResponseFrame.__init__(
            self,
            response,
            subunit_type,
            subunit_id,
            Frame.OperationCode.PASS_THROUGH,
            encoded_operands,
        )

    def __str__(self):
        pieces = (
            f"PassThroughResponseFrame(response={self.response.name}, ",
            f"subunit_type={self.subunit_type.name}, ",
            f"subunit_id=0x{self.subunit_id:02X}, ",
            f"state_flag={self.state_flag.name}, ",
            f"operation_id={self.operation_id.name}, ",
            f"operation_data={self.operation_data.hex()})",
        )
        return "".join(pieces)
524