1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "SocketProfilingConnection.hpp"
7
8 #include <common/include/SocketConnectionException.hpp>
9
10 #include <cerrno>
11 #include <cstring>
12 #include <fcntl.h>
13 #include <string>
14
15
16 namespace arm
17 {
18 namespace pipe
19 {
20
SocketProfilingConnection()21 SocketProfilingConnection::SocketProfilingConnection()
22 {
23 #if !defined(ARMNN_DISABLE_SOCKETS)
24 arm::pipe::Initialize();
25 memset(m_Socket, 0, sizeof(m_Socket));
26 // Note: we're using Linux specific SOCK_CLOEXEC flag.
27 m_Socket[0].fd = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
28 if (m_Socket[0].fd == -1)
29 {
30 throw arm::pipe::SocketConnectionException(
31 std::string("SocketProfilingConnection: Socket construction failed: ") + strerror(errno),
32 m_Socket[0].fd,
33 errno);
34 }
35
36 // Connect to the named unix domain socket.
37 sockaddr_un server{};
38 memset(&server, 0, sizeof(sockaddr_un));
39 // As m_GatorNamespace begins with a null character we need to ignore that when getting its length.
40 memcpy(server.sun_path, m_GatorNamespace, strlen(m_GatorNamespace + 1) + 1);
41 server.sun_family = AF_UNIX;
42 if (0 != connect(m_Socket[0].fd, reinterpret_cast<const sockaddr*>(&server), sizeof(sockaddr_un)))
43 {
44 Close();
45 throw arm::pipe::SocketConnectionException(
46 std::string("SocketProfilingConnection: Cannot connect to stream socket: ") + strerror(errno),
47 m_Socket[0].fd,
48 errno);
49 }
50
51 // Our socket will only be interested in polling reads.
52 m_Socket[0].events = POLLIN;
53
54 // Make the socket non blocking.
55 if (!arm::pipe::SetNonBlocking(m_Socket[0].fd))
56 {
57 Close();
58 throw arm::pipe::SocketConnectionException(
59 std::string("SocketProfilingConnection: Failed to set socket as non blocking: ") + strerror(errno),
60 m_Socket[0].fd,
61 errno);
62 }
63 #endif
64 }
65
IsOpen() const66 bool SocketProfilingConnection::IsOpen() const
67 {
68 #if !defined(ARMNN_DISABLE_SOCKETS)
69 return m_Socket[0].fd > 0;
70 #else
71 return false;
72 #endif
73 }
74
Close()75 void SocketProfilingConnection::Close()
76 {
77 #if !defined(ARMNN_DISABLE_SOCKETS)
78 if (arm::pipe::Close(m_Socket[0].fd) != 0)
79 {
80 throw arm::pipe::SocketConnectionException(
81 std::string("SocketProfilingConnection: Cannot close stream socket: ") + strerror(errno),
82 m_Socket[0].fd,
83 errno);
84 }
85
86 memset(m_Socket, 0, sizeof(m_Socket));
87 #endif
88 }
89
WritePacket(const unsigned char * buffer,uint32_t length)90 bool SocketProfilingConnection::WritePacket(const unsigned char* buffer, uint32_t length)
91 {
92 if (buffer == nullptr || length == 0)
93 {
94 return false;
95 }
96 #if !defined(ARMNN_DISABLE_SOCKETS)
97 return arm::pipe::Write(m_Socket[0].fd, buffer, length) != -1;
98 #else
99 return false;
100 #endif
101 }
102
ReadPacket(uint32_t timeout)103 arm::pipe::Packet SocketProfilingConnection::ReadPacket(uint32_t timeout)
104 {
105 #if !defined(ARMNN_DISABLE_SOCKETS)
106 // Is there currently at least a header worth of data waiting to be read?
107 int bytes_available = 0;
108 arm::pipe::Ioctl(m_Socket[0].fd, FIONREAD, &bytes_available);
109 if (bytes_available >= 8)
110 {
111 // Yes there is. Read it:
112 return ReceivePacket();
113 }
114
115 // Poll for data on the socket or until timeout occurs
116 int pollResult = arm::pipe::Poll(&m_Socket[0], 1, static_cast<int>(timeout));
117
118 switch (pollResult)
119 {
120 case -1: // Error
121 throw arm::pipe::SocketConnectionException(
122 std::string("SocketProfilingConnection: Error occured while reading from socket: ") + strerror(errno),
123 m_Socket[0].fd,
124 errno);
125
126 case 0: // Timeout
127 throw arm::pipe::TimeoutException("SocketProfilingConnection: Timeout while reading from socket");
128
129 default: // Normal poll return but it could still contain an error signal
130 // Check if the socket reported an error
131 if (m_Socket[0].revents & (POLLNVAL | POLLERR | POLLHUP))
132 {
133 if (m_Socket[0].revents == POLLNVAL)
134 {
135 // This is an unrecoverable error.
136 Close();
137 throw arm::pipe::SocketConnectionException(
138 std::string("SocketProfilingConnection: Error occured while polling receiving socket: POLLNVAL."),
139 m_Socket[0].fd);
140 }
141 if (m_Socket[0].revents == POLLERR)
142 {
143 throw arm::pipe::SocketConnectionException(
144 std::string(
145 "SocketProfilingConnection: Error occured while polling receiving socket: POLLERR: ")
146 + strerror(errno),
147 m_Socket[0].fd,
148 errno);
149 }
150 if (m_Socket[0].revents == POLLHUP)
151 {
152 // This is an unrecoverable error.
153 Close();
154 throw arm::pipe::SocketConnectionException(
155 std::string("SocketProfilingConnection: Connection closed by remote client: POLLHUP."),
156 m_Socket[0].fd);
157 }
158 }
159
160 // Check if there is data to read
161 if (!(m_Socket[0].revents & (POLLIN)))
162 {
163 // This is a corner case. The socket as been woken up but not with any data.
164 // We'll throw a timeout exception to loop around again.
165 throw arm::pipe::TimeoutException(
166 "SocketProfilingConnection: File descriptor was polled but no data was available to receive.");
167 }
168
169 return ReceivePacket();
170 }
171 #else
172 IgnoreUnused(timeout);
173 throw arm::pipe::TimeoutException(
174 "SocketProfilingConnection: Cannot use ReadPacket function with sockets disabled");
175 #endif
176 }
177
ReceivePacket()178 arm::pipe::Packet SocketProfilingConnection::ReceivePacket()
179 {
180 #if !defined(ARMNN_DISABLE_SOCKETS)
181 char header[8] = {};
182 long receiveResult = arm::pipe::Read(m_Socket[0].fd, &header, sizeof(header));
183 // We expect 8 as the result here. 0 means EOF, socket is closed. -1 means there been some other kind of error.
184 switch( receiveResult )
185 {
186 case 0:
187 // Socket has closed.
188 Close();
189 throw arm::pipe::SocketConnectionException(
190 std::string("SocketProfilingConnection: Remote socket has closed the connection."),
191 m_Socket[0].fd);
192 case -1:
193 // There's been a socket error. We will presume it's unrecoverable.
194 Close();
195 throw arm::pipe::SocketConnectionException(
196 std::string("SocketProfilingConnection: Error occured while reading the packet: ") + strerror(errno),
197 m_Socket[0].fd,
198 errno);
199 default:
200 if (receiveResult < 8)
201 {
202 throw arm::pipe::SocketConnectionException(
203 std::string(
204 "SocketProfilingConnection: The received packet did not contains a valid PIPE header."),
205 m_Socket[0].fd);
206 }
207 break;
208 }
209
210 // stream_metadata_identifier is the first 4 bytes
211 uint32_t metadataIdentifier = 0;
212 std::memcpy(&metadataIdentifier, header, sizeof(metadataIdentifier));
213
214 // data_length is the next 4 bytes
215 uint32_t dataLength = 0;
216 std::memcpy(&dataLength, header + 4u, sizeof(dataLength));
217
218 std::unique_ptr<unsigned char[]> packetData;
219 if (dataLength > 0)
220 {
221 packetData = std::make_unique<unsigned char[]>(dataLength);
222 long receivedLength = arm::pipe::Read(m_Socket[0].fd, packetData.get(), dataLength);
223 if (receivedLength < 0)
224 {
225 throw arm::pipe::SocketConnectionException(
226 std::string("SocketProfilingConnection: Error occured while reading the packet: ") + strerror(errno),
227 m_Socket[0].fd,
228 errno);
229 }
230 if (dataLength != static_cast<uint32_t>(receivedLength))
231 {
232 // What do we do here if we can't read in a full packet?
233 throw arm::pipe::SocketConnectionException(
234 std::string("SocketProfilingConnection: Invalid PIPE packet."),
235 m_Socket[0].fd);
236 }
237 }
238
239 return arm::pipe::Packet(metadataIdentifier, dataLength, packetData);
240 #else
241 throw arm::pipe::TimeoutException(
242 "SocketProfilingConnection: Cannot use ReceivePacket function with sockets disabled");
243 #endif
244 }
245
246 } // namespace pipe
247 } // namespace arm
248