xref: /aosp_15_r20/hardware/interfaces/automotive/can/1.0/default/libnl++/Socket.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <libnl++/Socket.h>
18 
19 #include <libnl++/printer.h>
20 
21 #include <android-base/logging.h>
22 
23 // Should be in sys/socket.h or linux/socket.h
24 #define SOL_NETLINK 270
25 
26 namespace android::nl {
27 
/**
 * Print all outbound/inbound Netlink messages.
 *
 * Compile-time switch: when set to true, send() and receiveFrom() log every message
 * (formatted with toString()) at VERBOSE level. The checks use `if constexpr`, so the
 * logging statements are discarded at compile time while this stays false.
 */
static constexpr bool kSuperVerbose = false;
32 
Socket(int protocol,unsigned pid,uint32_t groups)33 Socket::Socket(int protocol, unsigned pid, uint32_t groups) : mProtocol(protocol) {
34     mFd.reset(socket(AF_NETLINK, SOCK_RAW, protocol));
35     if (!mFd.ok()) {
36         PLOG(ERROR) << "Can't open Netlink socket";
37         mFailed = true;
38         return;
39     }
40 
41     sockaddr_nl sa = {};
42     sa.nl_family = AF_NETLINK;
43     sa.nl_pid = pid;
44     sa.nl_groups = groups;
45 
46     if (bind(mFd.get(), reinterpret_cast<sockaddr*>(&sa), sizeof(sa)) < 0) {
47         PLOG(ERROR) << "Can't bind Netlink socket";
48         mFd.reset();
49         mFailed = true;
50     }
51 }
52 
clearPollErr()53 void Socket::clearPollErr() {
54     sockaddr_nl sa = {};
55     socklen_t saLen = sizeof(sa);
56     const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), mReceiveBuffer.size(), 0,
57                                         reinterpret_cast<sockaddr*>(&sa), &saLen);
58     if (errno != EINVAL) {
59         PLOG(WARNING) << "clearPollError() caught unexpected error: ";
60     }
61     CHECK_LE(bytesReceived, 0) << "clearPollError() didn't find an error!";
62 }
63 
send(const Buffer<nlmsghdr> & msg,const sockaddr_nl & sa)64 bool Socket::send(const Buffer<nlmsghdr>& msg, const sockaddr_nl& sa) {
65     if constexpr (kSuperVerbose) {
66         LOG(VERBOSE) << (mFailed ? "(not) " : "") << "sending to " << sa.nl_pid << ": "
67                      << toString(msg, mProtocol);
68     }
69     if (mFailed) return false;
70 
71     mSeq = msg->nlmsg_seq;
72     const auto rawMsg = msg.getRaw();
73     const auto bytesSent = sendto(mFd.get(), rawMsg.ptr(), rawMsg.len(), 0,
74                                   reinterpret_cast<const sockaddr*>(&sa), sizeof(sa));
75     if (bytesSent < 0) {
76         PLOG(ERROR) << "Can't send Netlink message";
77         return false;
78     } else if (size_t(bytesSent) != rawMsg.len()) {
79         LOG(ERROR) << "Can't send Netlink message: truncated message";
80         return false;
81     }
82     return true;
83 }
84 
send(const Buffer<nlmsghdr> & msg,uint32_t destination)85 bool Socket::send(const Buffer<nlmsghdr>& msg, uint32_t destination) {
86     sockaddr_nl sa = {.nl_family = AF_NETLINK, .nl_pad = 0, .nl_pid = destination, .nl_groups = 0};
87     return send(msg, sa);
88 }
89 
increaseReceiveBuffer(size_t maxSize)90 bool Socket::increaseReceiveBuffer(size_t maxSize) {
91     if (maxSize == 0) {
92         LOG(ERROR) << "Maximum receive size should not be zero";
93         return false;
94     }
95 
96     if (mReceiveBuffer.size() < maxSize) mReceiveBuffer.resize(maxSize);
97     return true;
98 }
99 
receive(size_t maxSize)100 std::optional<Buffer<nlmsghdr>> Socket::receive(size_t maxSize) {
101     return receiveFrom(maxSize).first;
102 }
103 
receiveFrom(size_t maxSize)104 std::pair<std::optional<Buffer<nlmsghdr>>, sockaddr_nl> Socket::receiveFrom(size_t maxSize) {
105     if (mFailed) return {std::nullopt, {}};
106 
107     if (!increaseReceiveBuffer(maxSize)) return {std::nullopt, {}};
108 
109     sockaddr_nl sa = {};
110     socklen_t saLen = sizeof(sa);
111     const auto bytesReceived = recvfrom(mFd.get(), mReceiveBuffer.data(), maxSize, MSG_TRUNC,
112                                         reinterpret_cast<sockaddr*>(&sa), &saLen);
113 
114     if (bytesReceived <= 0) {
115         PLOG(ERROR) << "Failed to receive Netlink message";
116         return {std::nullopt, {}};
117     } else if (size_t(bytesReceived) > maxSize) {
118         PLOG(ERROR) << "Received data larger than maximum receive size: "  //
119                     << bytesReceived << " > " << maxSize;
120         return {std::nullopt, {}};
121     }
122 
123     Buffer<nlmsghdr> msg(reinterpret_cast<nlmsghdr*>(mReceiveBuffer.data()), bytesReceived);
124     if constexpr (kSuperVerbose) {
125         LOG(VERBOSE) << "received from " << sa.nl_pid << ": " << toString(msg, mProtocol);
126     }
127     long headerByteTotal = 0;
128     for (const auto hdr : msg) {
129         headerByteTotal += hdr->nlmsg_len;
130     }
131     if (bytesReceived != headerByteTotal) {
132         LOG(ERROR) << "received " << bytesReceived << " bytes, header claims " << headerByteTotal;
133     }
134     return {msg, sa};
135 }
136 
receiveAck(uint32_t seq)137 bool Socket::receiveAck(uint32_t seq) {
138     const auto nlerr = receive<nlmsgerr>({NLMSG_ERROR});
139     if (!nlerr.has_value()) return false;
140 
141     if (nlerr->data.msg.nlmsg_seq != seq) {
142         LOG(ERROR) << "Received ACK for a different message (" << nlerr->data.msg.nlmsg_seq
143                    << ", expected " << seq << "). Multi-message tracking is not implemented.";
144         return false;
145     }
146 
147     if (nlerr->data.error == 0) return true;
148 
149     LOG(WARNING) << "Received Netlink error message: " << strerror(-nlerr->data.error);
150     return false;
151 }
152 
receive(const std::set<nlmsgtype_t> & msgtypes,size_t maxSize)153 std::optional<Buffer<nlmsghdr>> Socket::receive(const std::set<nlmsgtype_t>& msgtypes,
154                                                 size_t maxSize) {
155     if (mFailed || !increaseReceiveBuffer(maxSize)) return std::nullopt;
156 
157     for (const auto rawMsg : *this) {
158         if (msgtypes.count(rawMsg->nlmsg_type) == 0) {
159             LOG(WARNING) << "Received (and ignored) unexpected Netlink message of type "
160                          << rawMsg->nlmsg_type;
161             continue;
162         }
163 
164         return rawMsg;
165     }
166 
167     return std::nullopt;
168 }
169 
getPid()170 std::optional<unsigned> Socket::getPid() {
171     if (mFailed) return std::nullopt;
172 
173     sockaddr_nl sa = {};
174     socklen_t sasize = sizeof(sa);
175     if (getsockname(mFd.get(), reinterpret_cast<sockaddr*>(&sa), &sasize) < 0) {
176         PLOG(ERROR) << "Failed to get PID of Netlink socket";
177         return std::nullopt;
178     }
179     return sa.nl_pid;
180 }
181 
preparePoll(short events)182 pollfd Socket::preparePoll(short events) {
183     CHECK(mFd.get() > 0) << "Netlink socket fd is invalid!";
184     return {mFd.get(), events, 0};
185 }
186 
addMembership(unsigned group)187 bool Socket::addMembership(unsigned group) {
188     const auto res =
189             setsockopt(mFd.get(), SOL_NETLINK, NETLINK_ADD_MEMBERSHIP, &group, sizeof(group));
190     if (res < 0) {
191         PLOG(ERROR) << "Failed joining multicast group " << group;
192         return false;
193     }
194     return true;
195 }
196 
dropMembership(unsigned group)197 bool Socket::dropMembership(unsigned group) {
198     const auto res =
199             setsockopt(mFd.get(), SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, &group, sizeof(group));
200     if (res < 0) {
201         PLOG(ERROR) << "Failed leaving multicast group " << group;
202         return false;
203     }
204     return true;
205 }
206 
receive_iterator(Socket & socket,bool end)207 Socket::receive_iterator::receive_iterator(Socket& socket, bool end)
208     : mSocket(socket), mIsEnd(end) {
209     if (!end) receive();
210 }
211 
operator ++()212 Socket::receive_iterator Socket::receive_iterator::operator++() {
213     CHECK(!mIsEnd) << "Trying to increment end iterator";
214     ++mCurrent;
215     if (mCurrent.isEnd()) receive();
216     return *this;
217 }
218 
operator ==(const receive_iterator & other) const219 bool Socket::receive_iterator::operator==(const receive_iterator& other) const {
220     if (mIsEnd != other.mIsEnd) return false;
221     if (mIsEnd && other.mIsEnd) return true;
222     return mCurrent == other.mCurrent;
223 }
224 
operator *() const225 const Buffer<nlmsghdr>& Socket::receive_iterator::operator*() const {
226     CHECK(!mIsEnd) << "Trying to dereference end iterator";
227     return *mCurrent;
228 }
229 
receive()230 void Socket::receive_iterator::receive() {
231     CHECK(!mIsEnd) << "Trying to receive on end iterator";
232     CHECK(mCurrent.isEnd()) << "Trying to receive without draining previous read";
233 
234     const auto buf = mSocket.receive();
235     if (buf.has_value()) {
236         mCurrent = buf->begin();
237     } else {
238         mIsEnd = true;
239     }
240 }
241 
begin()242 Socket::receive_iterator Socket::begin() {
243     return {*this, false};
244 }
245 
end()246 Socket::receive_iterator Socket::end() {
247     return {*this, true};
248 }
249 
250 }  // namespace android::nl
251