1import asyncio 2from typing import Union, Type, TypeVar 3 4from .packets import uci 5from .packets.uci import CommonPacketHeader, ControlPacketHeader, DataPacketHeader 6 7UciPacket = TypeVar("UciPacket", uci.DataPacket, uci.ControlPacket) 8 9 10class Host: 11 def __init__(self, reader, writer, mac_address: bytes): 12 self.reader = reader 13 self.writer = writer 14 self.control_queue = asyncio.Queue() 15 self.data_queue = asyncio.Queue() 16 self.mac_address = mac_address 17 18 loop = asyncio.get_event_loop() 19 self.reader_task = loop.create_task(self._read_packets()) 20 21 @staticmethod 22 async def connect(address: str, port: int, mac_address: bytes) -> "Host": 23 reader, writer = await asyncio.open_connection(address, port) 24 return Host(reader, writer, mac_address) 25 26 def disconnect(self): 27 self.writer.close() 28 self.reader_task.cancel() 29 30 async def _read_exact(self, expected_len: int) -> bytes: 31 """Read an exact number of bytes from the socket. 32 33 Raises an exception if the socket gets disconnected.""" 34 received = bytes() 35 while len(received) < expected_len: 36 chunk = await self.reader.read(expected_len - len(received)) 37 received += chunk 38 return received 39 40 async def _read_packet(self) -> bytes: 41 """Read a single UCI packet from the socket. 42 43 The packet is automatically re-assembled if segmented on 44 the UCI transport.""" 45 46 complete_packet_bytes = bytes() 47 48 # Note on reassembly: 49 # For each segment of a Control Message, the 50 # header of the Control Packet SHALL contain the same MT, GID and OID 51 # values. It is correct to keep only the last header of the 52 # segmented packet. 53 while True: 54 # Read the common packet header. 55 header_bytes = await self._read_exact(4) 56 common_header: CommonPacketHeader = uci.CommonPacketHeader.parse_all( 57 header_bytes[0:1] 58 ) # type: ignore 59 60 if common_header.mt == uci.MessageType.DATA: 61 # Read the packet payload. 62 data_header: DataPacketHeader = uci.DataPacketHeader.parse_all(header_bytes) # type: ignore 63 payload_bytes = await self._read_exact(data_header.payload_length) 64 65 else: 66 # Read the packet payload. 67 control_header: ControlPacketHeader = uci.ControlPacketHeader.parse_all(header_bytes) # type: ignore 68 payload_bytes = await self._read_exact(control_header.payload_length) 69 70 complete_packet_bytes += payload_bytes 71 72 # Check the Packet Boundary Flag. 73 match common_header.pbf: 74 case uci.PacketBoundaryFlag.COMPLETE: 75 return header_bytes + complete_packet_bytes 76 case uci.PacketBoundaryFlag.NOT_COMPLETE: 77 pass 78 79 async def _read_packets(self): 80 """Loop reading UCI packets from the socket. 81 Receiving packets are added to the control queue.""" 82 try: 83 while True: 84 packet = await self._read_packet() 85 header: CommonPacketHeader = uci.CommonPacketHeader.parse_all(packet[0:1]) # type: ignore 86 if header.mt == uci.MessageType.DATA: 87 await self.data_queue.put(packet) 88 else: 89 await self.control_queue.put(packet) 90 except Exception as exn: 91 print(f"reader task closed") 92 93 async def _recv_control(self) -> bytes: 94 return await self.control_queue.get() 95 96 async def _recv_data(self) -> bytes: 97 return await self.data_queue.get() 98 99 def send_control(self, packet: uci.ControlPacket): 100 # TODO packet fragmentation. 101 packet = bytearray(packet.serialize()) 102 packet[3] = len(packet) - 4 103 self.writer.write(packet) 104 105 def send_data(self, packet: uci.DataPacket): 106 packet = bytearray(packet.serialize()) 107 size = len(packet) - 4 108 size_bytes = size.to_bytes(2, byteorder="little") 109 packet[2] = size_bytes[0] 110 packet[3] = size_bytes[1] 111 self.writer.write(packet) 112 113 async def expect_control( 114 self, 115 expected: Union[Type[uci.ControlPacket], uci.ControlPacket], 116 timeout: float = 1.0, 117 ) -> uci.ControlPacket: 118 """Wait for a control packet being sent from the controller. 119 120 Raises ValueError if the packet is not well formatted. 121 Raises ValueError if the packet does not match the expected type or value. 122 Raises TimeoutError if no packet is received after `timeout` seconds. 123 Returns the received packet on success. 124 """ 125 126 packet = await asyncio.wait_for(self._recv_control(), timeout=timeout) 127 received = uci.ControlPacket.parse_all(packet) 128 129 if isinstance(expected, type) and not isinstance(received, expected): 130 raise ValueError( 131 f"received unexpected packet {received.__class__.__name__}," 132 + f" expected {expected.__name__}" 133 ) 134 135 if isinstance(expected, uci.ControlPacket) and received != expected: 136 raise ValueError( 137 f"received unexpected packet {received.__class__.__name__}," 138 + f" expected {expected.__class__.__name__}" 139 ) 140 141 return received 142 143 async def expect_data( 144 self, 145 expected: Union[Type[uci.DataPacket], uci.DataPacket], 146 timeout: float = 1.0, 147 ) -> uci.DataPacket: 148 """Wait for a data packet being sent from the controller. 149 150 Raises ValueError if the packet is not well formatted. 151 Raises ValueError if the packet does not match the expected type or value. 152 Raises TimeoutError if no packet is received after `timeout` seconds. 153 Returns the received packet on success. 154 """ 155 156 packet = await asyncio.wait_for(self._recv_data(), timeout=timeout) 157 received = uci.DataPacket.parse_all(packet) 158 159 if isinstance(expected, type) and not isinstance(received, expected): 160 raise ValueError( 161 f"received unexpected packet {received.__class__.__name__}," 162 + f" expected {expected.__name__}" 163 ) 164 165 if isinstance(expected, uci.DataPacket) and received != expected: 166 raise ValueError( 167 f"received unexpected packet {received.__class__.__name__}," 168 + f" expected {expected.__class__.__name__}" 169 ) 170 171 return received 172