1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "common/libs/utils/vsock_connection.h"
18
19 #include <sys/types.h>
20 #include <sys/socket.h>
21 #include <sys/time.h>
22
23 #include <functional>
24 #include <future>
25 #include <memory>
26 #include <mutex>
27 #include <new>
28 #include <ostream>
29 #include <string>
30 #include <tuple>
31 #include <utility>
32 #include <vector>
33
34 #include <android-base/logging.h>
35 #include <json/json.h>
36
37 #include "common/libs/fs/shared_buf.h"
38 #include "common/libs/fs/shared_select.h"
39
40 namespace cuttlefish {
41
~VsockConnection()42 VsockConnection::~VsockConnection() { Disconnect(); }
43
ConnectAsync(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid_)44 std::future<bool> VsockConnection::ConnectAsync(
45 unsigned int port, unsigned int cid,
46 std::optional<int> vhost_user_vsock_cid_) {
47 return std::async(std::launch::async,
48 [this, port, cid, vhost_user_vsock_cid_]() {
49 return Connect(port, cid, vhost_user_vsock_cid_);
50 });
51 }
52
Disconnect()53 void VsockConnection::Disconnect() {
54 // We need to serialize all accesses to the SharedFD.
55 std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
56 std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
57
58 LOG(INFO) << "Disconnecting with fd status:" << fd_->StrError();
59 fd_->Shutdown(SHUT_RDWR);
60 if (disconnect_callback_) {
61 disconnect_callback_();
62 }
63 fd_->Close();
64 }
65
SetDisconnectCallback(std::function<void ()> callback)66 void VsockConnection::SetDisconnectCallback(std::function<void()> callback) {
67 disconnect_callback_ = callback;
68 }
69
70 // This method created due to a race condition in IsConnected().
71 // TODO(b/345285391): remove this method once a fix found
IsConnected_Unguarded()72 bool VsockConnection::IsConnected_Unguarded() { return fd_->IsOpen(); }
73
IsConnected()74 bool VsockConnection::IsConnected() {
75 // We need to serialize all accesses to the SharedFD.
76 std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
77 std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
78
79 return fd_->IsOpen();
80 }
81
DataAvailable()82 bool VsockConnection::DataAvailable() {
83 SharedFDSet read_set;
84
85 // We need to serialize all accesses to the SharedFD.
86 std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
87 std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
88
89 read_set.Set(fd_);
90 struct timeval timeout = {0, 0};
91 return Select(&read_set, nullptr, nullptr, &timeout) > 0;
92 }
93
Read()94 int32_t VsockConnection::Read() {
95 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
96 int32_t result;
97 if (ReadExactBinary(fd_, &result) != sizeof(result)) {
98 Disconnect();
99 return 0;
100 }
101 return result;
102 }
103
Read(std::vector<char> & data)104 bool VsockConnection::Read(std::vector<char>& data) {
105 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
106 return ReadExact(fd_, &data) == data.size();
107 }
108
Read(size_t size)109 std::vector<char> VsockConnection::Read(size_t size) {
110 if (size == 0) {
111 return {};
112 }
113 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
114 std::vector<char> result(size);
115 if (ReadExact(fd_, &result) != size) {
116 Disconnect();
117 return {};
118 }
119 return result;
120 }
121
ReadAsync(size_t size)122 std::future<std::vector<char>> VsockConnection::ReadAsync(size_t size) {
123 return std::async(std::launch::async, [this, size]() { return Read(size); });
124 }
125
126 // Message format is buffer size followed by buffer data
ReadMessage()127 std::vector<char> VsockConnection::ReadMessage() {
128 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
129 auto size = Read();
130 if (size < 0) {
131 Disconnect();
132 return {};
133 }
134 return Read(size);
135 }
136
ReadMessage(std::vector<char> & data)137 bool VsockConnection::ReadMessage(std::vector<char>& data) {
138 std::lock_guard<std::recursive_mutex> lock(read_mutex_);
139 auto size = Read();
140 if (size < 0) {
141 Disconnect();
142 return false;
143 }
144 data.resize(size);
145 return Read(data);
146 }
147
ReadMessageAsync()148 std::future<std::vector<char>> VsockConnection::ReadMessageAsync() {
149 return std::async(std::launch::async, [this]() { return ReadMessage(); });
150 }
151
ReadJsonMessage()152 Json::Value VsockConnection::ReadJsonMessage() {
153 auto msg = ReadMessage();
154 Json::CharReaderBuilder builder;
155 std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
156 Json::Value json_msg;
157 std::string errors;
158 if (!reader->parse(msg.data(), msg.data() + msg.size(), &json_msg, &errors)) {
159 return {};
160 }
161 return json_msg;
162 }
163
ReadJsonMessageAsync()164 std::future<Json::Value> VsockConnection::ReadJsonMessageAsync() {
165 return std::async(std::launch::async, [this]() { return ReadJsonMessage(); });
166 }
167
Write(int32_t data)168 bool VsockConnection::Write(int32_t data) {
169 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
170 if (WriteAllBinary(fd_, &data) != sizeof(data)) {
171 Disconnect();
172 return false;
173 }
174 return true;
175 }
176
Write(const char * data,unsigned int size)177 bool VsockConnection::Write(const char* data, unsigned int size) {
178 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
179 if (WriteAll(fd_, data, size) != size) {
180 Disconnect();
181 return false;
182 }
183 return true;
184 }
185
Write(const std::vector<char> & data)186 bool VsockConnection::Write(const std::vector<char>& data) {
187 return Write(data.data(), data.size());
188 }
189
190 // Message format is buffer size followed by buffer data
WriteMessage(const std::string & data)191 bool VsockConnection::WriteMessage(const std::string& data) {
192 return Write(data.size()) && Write(data.c_str(), data.length());
193 }
194
WriteMessage(const std::vector<char> & data)195 bool VsockConnection::WriteMessage(const std::vector<char>& data) {
196 std::lock_guard<std::recursive_mutex> lock(write_mutex_);
197 return Write(data.size()) && Write(data);
198 }
199
WriteMessage(const Json::Value & data)200 bool VsockConnection::WriteMessage(const Json::Value& data) {
201 Json::StreamWriterBuilder factory;
202 std::string message_str = Json::writeString(factory, data);
203 return WriteMessage(message_str);
204 }
205
WriteStrides(const char * data,unsigned int size,unsigned int num_strides,int stride_size)206 bool VsockConnection::WriteStrides(const char* data, unsigned int size,
207 unsigned int num_strides, int stride_size) {
208 const char* src = data;
209 for (unsigned int i = 0; i < num_strides; ++i, src += stride_size) {
210 if (!Write(src, size)) {
211 return false;
212 }
213 }
214 return true;
215 }
216
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user)217 bool VsockClientConnection::Connect(unsigned int port, unsigned int cid,
218 std::optional<int> vhost_user) {
219 fd_ =
220 SharedFD::VsockClient(cid, port, SOCK_STREAM, vhost_user ? true : false);
221 if (!fd_->IsOpen()) {
222 LOG(ERROR) << "Failed to connect:" << fd_->StrError();
223 }
224 return fd_->IsOpen();
225 }
226
~VsockServerConnection()227 VsockServerConnection::~VsockServerConnection() { ServerShutdown(); }
228
ServerShutdown()229 void VsockServerConnection::ServerShutdown() {
230 if (server_fd_->IsOpen()) {
231 LOG(INFO) << __FUNCTION__
232 << ": server fd status:" << server_fd_->StrError();
233 server_fd_->Shutdown(SHUT_RDWR);
234 server_fd_->Close();
235 }
236 }
237
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid)238 bool VsockServerConnection::Connect(unsigned int port, unsigned int cid,
239 std::optional<int> vhost_user_vsock_cid) {
240 if (!server_fd_->IsOpen()) {
241 server_fd_ = cuttlefish::SharedFD::VsockServer(port, SOCK_STREAM,
242 vhost_user_vsock_cid, cid);
243 }
244 if (server_fd_->IsOpen()) {
245 fd_ = SharedFD::Accept(*server_fd_);
246 return fd_->IsOpen();
247 } else {
248 return false;
249 }
250 }
251
252 } // namespace cuttlefish
253