1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/libs/utils/vsock_connection.h"
18 
19 #include <sys/types.h>
20 #include <sys/socket.h>
21 #include <sys/time.h>
22 
23 #include <functional>
24 #include <future>
25 #include <memory>
26 #include <mutex>
27 #include <new>
28 #include <ostream>
29 #include <string>
30 #include <tuple>
31 #include <utility>
32 #include <vector>
33 
34 #include <android-base/logging.h>
35 #include <json/json.h>
36 
37 #include "common/libs/fs/shared_buf.h"
38 #include "common/libs/fs/shared_select.h"
39 
40 namespace cuttlefish {
41 
~VsockConnection()42 VsockConnection::~VsockConnection() { Disconnect(); }
43 
ConnectAsync(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid_)44 std::future<bool> VsockConnection::ConnectAsync(
45     unsigned int port, unsigned int cid,
46     std::optional<int> vhost_user_vsock_cid_) {
47   return std::async(std::launch::async,
48                     [this, port, cid, vhost_user_vsock_cid_]() {
49                       return Connect(port, cid, vhost_user_vsock_cid_);
50                     });
51 }
52 
Disconnect()53 void VsockConnection::Disconnect() {
54   // We need to serialize all accesses to the SharedFD.
55   std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
56   std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
57 
58   LOG(INFO) << "Disconnecting with fd status:" << fd_->StrError();
59   fd_->Shutdown(SHUT_RDWR);
60   if (disconnect_callback_) {
61     disconnect_callback_();
62   }
63   fd_->Close();
64 }
65 
SetDisconnectCallback(std::function<void ()> callback)66 void VsockConnection::SetDisconnectCallback(std::function<void()> callback) {
67   disconnect_callback_ = callback;
68 }
69 
70 // This method created due to a race condition in IsConnected().
71 // TODO(b/345285391): remove this method once a fix found
IsConnected_Unguarded()72 bool VsockConnection::IsConnected_Unguarded() { return fd_->IsOpen(); }
73 
IsConnected()74 bool VsockConnection::IsConnected() {
75   // We need to serialize all accesses to the SharedFD.
76   std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
77   std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
78 
79   return fd_->IsOpen();
80 }
81 
DataAvailable()82 bool VsockConnection::DataAvailable() {
83   SharedFDSet read_set;
84 
85   // We need to serialize all accesses to the SharedFD.
86   std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
87   std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);
88 
89   read_set.Set(fd_);
90   struct timeval timeout = {0, 0};
91   return Select(&read_set, nullptr, nullptr, &timeout) > 0;
92 }
93 
Read()94 int32_t VsockConnection::Read() {
95   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
96   int32_t result;
97   if (ReadExactBinary(fd_, &result) != sizeof(result)) {
98     Disconnect();
99     return 0;
100   }
101   return result;
102 }
103 
Read(std::vector<char> & data)104 bool VsockConnection::Read(std::vector<char>& data) {
105   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
106   return ReadExact(fd_, &data) == data.size();
107 }
108 
Read(size_t size)109 std::vector<char> VsockConnection::Read(size_t size) {
110   if (size == 0) {
111     return {};
112   }
113   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
114   std::vector<char> result(size);
115   if (ReadExact(fd_, &result) != size) {
116     Disconnect();
117     return {};
118   }
119   return result;
120 }
121 
ReadAsync(size_t size)122 std::future<std::vector<char>> VsockConnection::ReadAsync(size_t size) {
123   return std::async(std::launch::async, [this, size]() { return Read(size); });
124 }
125 
126 // Message format is buffer size followed by buffer data
ReadMessage()127 std::vector<char> VsockConnection::ReadMessage() {
128   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
129   auto size = Read();
130   if (size < 0) {
131     Disconnect();
132     return {};
133   }
134   return Read(size);
135 }
136 
ReadMessage(std::vector<char> & data)137 bool VsockConnection::ReadMessage(std::vector<char>& data) {
138   std::lock_guard<std::recursive_mutex> lock(read_mutex_);
139   auto size = Read();
140   if (size < 0) {
141     Disconnect();
142     return false;
143   }
144   data.resize(size);
145   return Read(data);
146 }
147 
ReadMessageAsync()148 std::future<std::vector<char>> VsockConnection::ReadMessageAsync() {
149   return std::async(std::launch::async, [this]() { return ReadMessage(); });
150 }
151 
ReadJsonMessage()152 Json::Value VsockConnection::ReadJsonMessage() {
153   auto msg = ReadMessage();
154   Json::CharReaderBuilder builder;
155   std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
156   Json::Value json_msg;
157   std::string errors;
158   if (!reader->parse(msg.data(), msg.data() + msg.size(), &json_msg, &errors)) {
159     return {};
160   }
161   return json_msg;
162 }
163 
ReadJsonMessageAsync()164 std::future<Json::Value> VsockConnection::ReadJsonMessageAsync() {
165   return std::async(std::launch::async, [this]() { return ReadJsonMessage(); });
166 }
167 
Write(int32_t data)168 bool VsockConnection::Write(int32_t data) {
169   std::lock_guard<std::recursive_mutex> lock(write_mutex_);
170   if (WriteAllBinary(fd_, &data) != sizeof(data)) {
171     Disconnect();
172     return false;
173   }
174   return true;
175 }
176 
Write(const char * data,unsigned int size)177 bool VsockConnection::Write(const char* data, unsigned int size) {
178   std::lock_guard<std::recursive_mutex> lock(write_mutex_);
179   if (WriteAll(fd_, data, size) != size) {
180     Disconnect();
181     return false;
182   }
183   return true;
184 }
185 
Write(const std::vector<char> & data)186 bool VsockConnection::Write(const std::vector<char>& data) {
187   return Write(data.data(), data.size());
188 }
189 
190 // Message format is buffer size followed by buffer data
WriteMessage(const std::string & data)191 bool VsockConnection::WriteMessage(const std::string& data) {
192   return Write(data.size()) && Write(data.c_str(), data.length());
193 }
194 
WriteMessage(const std::vector<char> & data)195 bool VsockConnection::WriteMessage(const std::vector<char>& data) {
196   std::lock_guard<std::recursive_mutex> lock(write_mutex_);
197   return Write(data.size()) && Write(data);
198 }
199 
WriteMessage(const Json::Value & data)200 bool VsockConnection::WriteMessage(const Json::Value& data) {
201   Json::StreamWriterBuilder factory;
202   std::string message_str = Json::writeString(factory, data);
203   return WriteMessage(message_str);
204 }
205 
WriteStrides(const char * data,unsigned int size,unsigned int num_strides,int stride_size)206 bool VsockConnection::WriteStrides(const char* data, unsigned int size,
207                                    unsigned int num_strides, int stride_size) {
208   const char* src = data;
209   for (unsigned int i = 0; i < num_strides; ++i, src += stride_size) {
210     if (!Write(src, size)) {
211       return false;
212     }
213   }
214   return true;
215 }
216 
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user)217 bool VsockClientConnection::Connect(unsigned int port, unsigned int cid,
218                                     std::optional<int> vhost_user) {
219   fd_ =
220       SharedFD::VsockClient(cid, port, SOCK_STREAM, vhost_user ? true : false);
221   if (!fd_->IsOpen()) {
222     LOG(ERROR) << "Failed to connect:" << fd_->StrError();
223   }
224   return fd_->IsOpen();
225 }
226 
~VsockServerConnection()227 VsockServerConnection::~VsockServerConnection() { ServerShutdown(); }
228 
ServerShutdown()229 void VsockServerConnection::ServerShutdown() {
230   if (server_fd_->IsOpen()) {
231     LOG(INFO) << __FUNCTION__
232               << ": server fd status:" << server_fd_->StrError();
233     server_fd_->Shutdown(SHUT_RDWR);
234     server_fd_->Close();
235   }
236 }
237 
Connect(unsigned int port,unsigned int cid,std::optional<int> vhost_user_vsock_cid)238 bool VsockServerConnection::Connect(unsigned int port, unsigned int cid,
239                                     std::optional<int> vhost_user_vsock_cid) {
240   if (!server_fd_->IsOpen()) {
241     server_fd_ = cuttlefish::SharedFD::VsockServer(port, SOCK_STREAM,
242                                                    vhost_user_vsock_cid, cid);
243   }
244   if (server_fd_->IsOpen()) {
245     fd_ = SharedFD::Accept(*server_fd_);
246     return fd_->IsOpen();
247   } else {
248     return false;
249   }
250 }
251 
252 }  // namespace cuttlefish
253