1import asyncio
2from typing import Union, Type, TypeVar
3
4from .packets import uci
5from .packets.uci import CommonPacketHeader, ControlPacketHeader, DataPacketHeader
6
# Type variable constrained to the two UCI packet families (data or control).
# NOTE(review): not referenced anywhere else in this file as shown; confirm
# whether other modules import it before removing.
UciPacket = TypeVar("UciPacket", uci.DataPacket, uci.ControlPacket)
8
9
class Host:
    """Client side of a UCI connection to a controller over a stream pair.

    A background task reads UCI packets from `reader`, re-assembles
    segmented packets, and routes the raw bytes to `data_queue` (data
    packets) or `control_queue` (every other message type). Packets are
    sent by writing their serialized bytes to `writer`.
    """

    def __init__(self, reader, writer, mac_address: bytes):
        """Wrap an already-open stream pair and start the reader task.

        Args:
            reader: stream UCI packets are read from; only its `read`
                coroutine is used (e.g. an `asyncio.StreamReader`).
            writer: stream packets are written to; its `write` and `close`
                methods are used (e.g. an `asyncio.StreamWriter`).
            mac_address: stored as `self.mac_address`; not used by this
                class itself.
        """
        self.reader = reader
        self.writer = writer
        # Each queue entry is one raw packet: 4-byte header + full payload.
        self.control_queue = asyncio.Queue()
        self.data_queue = asyncio.Queue()
        self.mac_address = mac_address

        loop = asyncio.get_event_loop()
        self.reader_task = loop.create_task(self._read_packets())

    @classmethod
    async def connect(cls, address: str, port: int, mac_address: bytes) -> "Host":
        """Open a TCP connection to `address:port` and return a connected host.

        A classmethod, so that subclasses get instances of themselves.
        """
        reader, writer = await asyncio.open_connection(address, port)
        return cls(reader, writer, mac_address)

    def disconnect(self):
        """Close the writer and cancel the background reader task.

        Does not wait for the underlying transport to finish closing.
        """
        self.writer.close()
        self.reader_task.cancel()

    async def _read_exact(self, expected_len: int) -> bytes:
        """Read exactly `expected_len` bytes from the socket.

        Raises:
            asyncio.IncompleteReadError: if the stream reaches EOF before
                `expected_len` bytes were received (the partial data is
                available on the exception).
        """
        received = bytearray()
        while len(received) < expected_len:
            chunk = await self.reader.read(expected_len - len(received))
            if not chunk:
                # read() returns b"" at EOF; without this check the loop
                # would spin forever on a closed connection.
                raise asyncio.IncompleteReadError(bytes(received), expected_len)
            received += chunk
        return bytes(received)

    async def _read_packet(self) -> bytes:
        """Read a single UCI packet from the socket.

        The packet is automatically re-assembled if segmented on
        the UCI transport: segments are read until one carries
        PacketBoundaryFlag.COMPLETE.

        Returns:
            The 4-byte header of the last segment followed by the
            concatenated payloads of all segments.
        """
        payload = bytearray()

        # Note on reassembly:
        # For each segment of a Control Message, the
        # header of the Control Packet SHALL contain the same MT, GID and OID
        # values. It is correct to keep only the last header of the
        # segmented packet.
        # NOTE(review): the payload-length field of that kept header still
        # describes only the last segment, not the re-assembled payload;
        # confirm the uci parsers accept this for segmented packets.
        while True:
            # Read the common packet header; MT and PBF are parsed from
            # its first byte.
            header_bytes = await self._read_exact(4)
            common_header: CommonPacketHeader = uci.CommonPacketHeader.parse_all(
                header_bytes[0:1]
            )  # type: ignore

            # Data and control headers encode the payload length differently
            # (two bytes vs one byte), so parse with the matching header type.
            if common_header.mt == uci.MessageType.DATA:
                header_cls = uci.DataPacketHeader
            else:
                header_cls = uci.ControlPacketHeader
            header = header_cls.parse_all(header_bytes)  # type: ignore
            payload += await self._read_exact(header.payload_length)

            # Check the Packet Boundary Flag; any other value (NOT_COMPLETE)
            # means more segments follow.
            if common_header.pbf == uci.PacketBoundaryFlag.COMPLETE:
                return header_bytes + bytes(payload)

    async def _read_packets(self):
        """Loop reading UCI packets from the socket until it fails.

        Data packets are added to `data_queue`; all other packets are
        added to `control_queue`. The loop ends when the peer closes the
        connection or any other error occurs, and the reason is printed.
        Task cancellation (see `disconnect`) is not caught.
        """
        try:
            while True:
                packet = await self._read_packet()
                header: CommonPacketHeader = uci.CommonPacketHeader.parse_all(packet[0:1])  # type: ignore
                if header.mt == uci.MessageType.DATA:
                    await self.data_queue.put(packet)
                else:
                    await self.control_queue.put(packet)
        except (asyncio.IncompleteReadError, ConnectionError):
            # The peer closed or reset the connection: the normal way to stop.
            print("reader task closed")
        except Exception as exn:
            # Best effort: report unexpected failures (e.g. a malformed
            # packet) instead of silently discarding them.
            print(f"reader task closed: {exn!r}")

    async def _recv_control(self) -> bytes:
        """Wait for and return the next raw control packet."""
        return await self.control_queue.get()

    async def _recv_data(self) -> bytes:
        """Wait for and return the next raw data packet."""
        return await self.data_queue.get()

    def send_control(self, packet: uci.ControlPacket):
        """Serialize and write a control packet.

        The payload-length byte (offset 3) is overwritten with the actual
        size of the serialized payload.

        Raises:
            ValueError: if the payload is larger than 255 bytes, since
                segmentation is not implemented.
        """
        # TODO packet fragmentation.
        raw = bytearray(packet.serialize())
        payload_length = len(raw) - 4
        if payload_length > 0xFF:
            raise ValueError(
                f"control packet payload of {payload_length} bytes does not fit"
                " in a single segment; fragmentation is not implemented"
            )
        raw[3] = payload_length
        self.writer.write(raw)

    def send_data(self, packet: uci.DataPacket):
        """Serialize and write a data packet.

        The payload-length field (bytes 2-3, little-endian) is overwritten
        with the actual size of the serialized payload.

        Raises:
            OverflowError: if the payload does not fit in 16 bits.
        """
        raw = bytearray(packet.serialize())
        raw[2:4] = (len(raw) - 4).to_bytes(2, byteorder="little")
        self.writer.write(raw)

    @staticmethod
    def _check_expected(received, expected, packet_cls: type) -> None:
        """Raise ValueError unless `received` matches `expected`.

        `expected` is either a packet class (checked with isinstance) or
        an instance of `packet_cls` (checked with equality).
        """
        if isinstance(expected, type) and not isinstance(received, expected):
            raise ValueError(
                f"received unexpected packet {received.__class__.__name__},"
                + f" expected {expected.__name__}"
            )

        if isinstance(expected, packet_cls) and received != expected:
            # Class names alone are often identical here, so include values.
            raise ValueError(
                f"received unexpected packet {received.__class__.__name__},"
                + f" expected {expected.__class__.__name__}"
                + f" (received {received!r}, expected {expected!r})"
            )

    async def expect_control(
        self,
        expected: Union[Type[uci.ControlPacket], uci.ControlPacket],
        timeout: float = 1.0,
    ) -> uci.ControlPacket:
        """Wait for a control packet being sent from the controller.

        Raises ValueError if the packet is not well formatted.
        Raises ValueError if the packet does not match the expected type or value.
        Raises asyncio.TimeoutError (the builtin TimeoutError on Python 3.11+)
        if no packet is received after `timeout` seconds.
        Returns the received packet on success.
        """
        packet = await asyncio.wait_for(self._recv_control(), timeout=timeout)
        received = uci.ControlPacket.parse_all(packet)
        self._check_expected(received, expected, uci.ControlPacket)
        return received

    async def expect_data(
        self,
        expected: Union[Type[uci.DataPacket], uci.DataPacket],
        timeout: float = 1.0,
    ) -> uci.DataPacket:
        """Wait for a data packet being sent from the controller.

        Raises ValueError if the packet is not well formatted.
        Raises ValueError if the packet does not match the expected type or value.
        Raises asyncio.TimeoutError (the builtin TimeoutError on Python 3.11+)
        if no packet is received after `timeout` seconds.
        Returns the received packet on success.
        """
        packet = await asyncio.wait_for(self._recv_data(), timeout=timeout)
        received = uci.DataPacket.parse_all(packet)
        self._check_expected(received, expected, uci.DataPacket)
        return received
172