xref: /aosp_15_r20/external/cronet/net/socket/socks_client_socket.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socks_client_socket.h"
6 
7 #include <utility>
8 
9 #include "base/compiler_specific.h"
10 #include "base/functional/bind.h"
11 #include "base/functional/callback_helpers.h"
12 #include "base/sys_byteorder.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/dns/public/dns_query_type.h"
16 #include "net/dns/public/secure_dns_policy.h"
17 #include "net/log/net_log.h"
18 #include "net/log/net_log_event_type.h"
19 #include "net/traffic_annotation/network_traffic_annotation.h"
20 
21 namespace net {
22 
// Wire-format constants and structs for the SOCKS4 handshake. They are only
// meaningful inside this translation unit, so they live in an anonymous
// namespace (the structs would otherwise have external linkage in `net` and
// risk ODR clashes with same-named types elsewhere).
namespace {

// Every SOCKS server requests a user-id from the client. It is optional, so we
// send an empty string.
constexpr char kEmptyUserId[] = "";

// For SOCKS4, the client sends 8 bytes plus the size of the user-id.
constexpr unsigned int kWriteHeaderSize = 8;

// For SOCKS4, the server sends 8 bytes as acknowledgement.
constexpr unsigned int kReadHeaderSize = 8;

// Server response codes for SOCKS4.
constexpr uint8_t kServerResponseOk = 0x5A;
constexpr uint8_t kServerResponseRejected = 0x5B;
constexpr uint8_t kServerResponseNotReachable = 0x5C;
constexpr uint8_t kServerResponseMismatchedUserId = 0x5D;

constexpr uint8_t kSOCKSVersion4 = 0x04;
constexpr uint8_t kSOCKSStreamRequest = 0x01;

// The fixed 8-byte header of a SOCKS4 request. Field order is the on-the-wire
// byte order, so it must not be rearranged. |nw_port| is stored in network
// byte order.
struct SOCKS4ServerRequest {
  uint8_t version;
  uint8_t command;
  uint16_t nw_port;
  uint8_t ip[4];
};
static_assert(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
              "socks4 server request struct has incorrect size");

// The fixed 8-byte SOCKS4 reply. Field order mirrors the wire format.
struct SOCKS4ServerResponse {
  uint8_t reserved_null;
  uint8_t code;
  uint16_t port;
  uint8_t ip[4];
};
static_assert(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
              "socks4 server response struct has incorrect size");

}  // namespace
62 
SOCKSClientSocket(std::unique_ptr<StreamSocket> transport_socket,const HostPortPair & destination,const NetworkAnonymizationKey & network_anonymization_key,RequestPriority priority,HostResolver * host_resolver,SecureDnsPolicy secure_dns_policy,const NetworkTrafficAnnotationTag & traffic_annotation)63 SOCKSClientSocket::SOCKSClientSocket(
64     std::unique_ptr<StreamSocket> transport_socket,
65     const HostPortPair& destination,
66     const NetworkAnonymizationKey& network_anonymization_key,
67     RequestPriority priority,
68     HostResolver* host_resolver,
69     SecureDnsPolicy secure_dns_policy,
70     const NetworkTrafficAnnotationTag& traffic_annotation)
71     : transport_socket_(std::move(transport_socket)),
72       host_resolver_(host_resolver),
73       secure_dns_policy_(secure_dns_policy),
74       destination_(destination),
75       network_anonymization_key_(network_anonymization_key),
76       priority_(priority),
77       net_log_(transport_socket_->NetLog()),
78       traffic_annotation_(traffic_annotation) {}
79 
~SOCKSClientSocket()80 SOCKSClientSocket::~SOCKSClientSocket() {
81   Disconnect();
82 }
83 
Connect(CompletionOnceCallback callback)84 int SOCKSClientSocket::Connect(CompletionOnceCallback callback) {
85   DCHECK(transport_socket_);
86   DCHECK_EQ(STATE_NONE, next_state_);
87   DCHECK(user_callback_.is_null());
88 
89   // If already connected, then just return OK.
90   if (completed_handshake_)
91     return OK;
92 
93   next_state_ = STATE_RESOLVE_HOST;
94 
95   net_log_.BeginEvent(NetLogEventType::SOCKS_CONNECT);
96 
97   int rv = DoLoop(OK);
98   if (rv == ERR_IO_PENDING) {
99     user_callback_ = std::move(callback);
100   } else {
101     net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
102   }
103   return rv;
104 }
105 
Disconnect()106 void SOCKSClientSocket::Disconnect() {
107   completed_handshake_ = false;
108   resolve_host_request_.reset();
109   transport_socket_->Disconnect();
110 
111   // Reset other states to make sure they aren't mistakenly used later.
112   // These are the states initialized by Connect().
113   next_state_ = STATE_NONE;
114   user_callback_.Reset();
115 }
116 
IsConnected() const117 bool SOCKSClientSocket::IsConnected() const {
118   return completed_handshake_ && transport_socket_->IsConnected();
119 }
120 
IsConnectedAndIdle() const121 bool SOCKSClientSocket::IsConnectedAndIdle() const {
122   return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
123 }
124 
NetLog() const125 const NetLogWithSource& SOCKSClientSocket::NetLog() const {
126   return net_log_;
127 }
128 
WasEverUsed() const129 bool SOCKSClientSocket::WasEverUsed() const {
130   return was_ever_used_;
131 }
132 
GetNegotiatedProtocol() const133 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
134   if (transport_socket_)
135     return transport_socket_->GetNegotiatedProtocol();
136   NOTREACHED();
137   return kProtoUnknown;
138 }
139 
GetSSLInfo(SSLInfo * ssl_info)140 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
141   if (transport_socket_)
142     return transport_socket_->GetSSLInfo(ssl_info);
143   NOTREACHED();
144   return false;
145 }
146 
GetTotalReceivedBytes() const147 int64_t SOCKSClientSocket::GetTotalReceivedBytes() const {
148   return transport_socket_->GetTotalReceivedBytes();
149 }
150 
ApplySocketTag(const SocketTag & tag)151 void SOCKSClientSocket::ApplySocketTag(const SocketTag& tag) {
152   return transport_socket_->ApplySocketTag(tag);
153 }
154 
155 // Read is called by the transport layer above to read. This can only be done
156 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)157 int SOCKSClientSocket::Read(IOBuffer* buf,
158                             int buf_len,
159                             CompletionOnceCallback callback) {
160   DCHECK(completed_handshake_);
161   DCHECK_EQ(STATE_NONE, next_state_);
162   DCHECK(user_callback_.is_null());
163   DCHECK(!callback.is_null());
164 
165   int rv = transport_socket_->Read(
166       buf, buf_len,
167       base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
168                      base::Unretained(this), std::move(callback)));
169   if (rv > 0)
170     was_ever_used_ = true;
171   return rv;
172 }
173 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)174 int SOCKSClientSocket::ReadIfReady(IOBuffer* buf,
175                                    int buf_len,
176                                    CompletionOnceCallback callback) {
177   DCHECK(completed_handshake_);
178   DCHECK_EQ(STATE_NONE, next_state_);
179   DCHECK(user_callback_.is_null());
180   DCHECK(!callback.is_null());
181 
182   // Pass |callback| directly instead of wrapping it with OnReadWriteComplete.
183   // This is to avoid setting |was_ever_used_| unless data is actually read.
184   int rv = transport_socket_->ReadIfReady(buf, buf_len, std::move(callback));
185   if (rv > 0)
186     was_ever_used_ = true;
187   return rv;
188 }
189 
CancelReadIfReady()190 int SOCKSClientSocket::CancelReadIfReady() {
191   return transport_socket_->CancelReadIfReady();
192 }
193 
194 // Write is called by the transport layer. This can only be done if the
195 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)196 int SOCKSClientSocket::Write(
197     IOBuffer* buf,
198     int buf_len,
199     CompletionOnceCallback callback,
200     const NetworkTrafficAnnotationTag& traffic_annotation) {
201   DCHECK(completed_handshake_);
202   DCHECK_EQ(STATE_NONE, next_state_);
203   DCHECK(user_callback_.is_null());
204   DCHECK(!callback.is_null());
205 
206   int rv = transport_socket_->Write(
207       buf, buf_len,
208       base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
209                      base::Unretained(this), std::move(callback)),
210       traffic_annotation);
211   if (rv > 0)
212     was_ever_used_ = true;
213   return rv;
214 }
215 
SetReceiveBufferSize(int32_t size)216 int SOCKSClientSocket::SetReceiveBufferSize(int32_t size) {
217   return transport_socket_->SetReceiveBufferSize(size);
218 }
219 
SetSendBufferSize(int32_t size)220 int SOCKSClientSocket::SetSendBufferSize(int32_t size) {
221   return transport_socket_->SetSendBufferSize(size);
222 }
223 
DoCallback(int result)224 void SOCKSClientSocket::DoCallback(int result) {
225   DCHECK_NE(ERR_IO_PENDING, result);
226   DCHECK(!user_callback_.is_null());
227 
228   // Since Run() may result in Read being called,
229   // clear user_callback_ up front.
230   DVLOG(1) << "Finished setting up SOCKS handshake";
231   std::move(user_callback_).Run(result);
232 }
233 
OnIOComplete(int result)234 void SOCKSClientSocket::OnIOComplete(int result) {
235   DCHECK_NE(STATE_NONE, next_state_);
236   int rv = DoLoop(result);
237   if (rv != ERR_IO_PENDING) {
238     net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
239     DoCallback(rv);
240   }
241 }
242 
OnReadWriteComplete(CompletionOnceCallback callback,int result)243 void SOCKSClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
244                                             int result) {
245   DCHECK_NE(ERR_IO_PENDING, result);
246   DCHECK(!callback.is_null());
247 
248   if (result > 0)
249     was_ever_used_ = true;
250   std::move(callback).Run(result);
251 }
252 
DoLoop(int last_io_result)253 int SOCKSClientSocket::DoLoop(int last_io_result) {
254   DCHECK_NE(next_state_, STATE_NONE);
255   int rv = last_io_result;
256   do {
257     State state = next_state_;
258     next_state_ = STATE_NONE;
259     switch (state) {
260       case STATE_RESOLVE_HOST:
261         DCHECK_EQ(OK, rv);
262         rv = DoResolveHost();
263         break;
264       case STATE_RESOLVE_HOST_COMPLETE:
265         rv = DoResolveHostComplete(rv);
266         break;
267       case STATE_HANDSHAKE_WRITE:
268         DCHECK_EQ(OK, rv);
269         rv = DoHandshakeWrite();
270         break;
271       case STATE_HANDSHAKE_WRITE_COMPLETE:
272         rv = DoHandshakeWriteComplete(rv);
273         break;
274       case STATE_HANDSHAKE_READ:
275         DCHECK_EQ(OK, rv);
276         rv = DoHandshakeRead();
277         break;
278       case STATE_HANDSHAKE_READ_COMPLETE:
279         rv = DoHandshakeReadComplete(rv);
280         break;
281       default:
282         NOTREACHED() << "bad state";
283         rv = ERR_UNEXPECTED;
284         break;
285     }
286   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
287   return rv;
288 }
289 
DoResolveHost()290 int SOCKSClientSocket::DoResolveHost() {
291   next_state_ = STATE_RESOLVE_HOST_COMPLETE;
292   // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
293   // addresses for the target host.
294   HostResolver::ResolveHostParameters parameters;
295   parameters.dns_query_type = DnsQueryType::A;
296   parameters.initial_priority = priority_;
297   parameters.secure_dns_policy = secure_dns_policy_;
298   resolve_host_request_ = host_resolver_->CreateRequest(
299       destination_, network_anonymization_key_, net_log_, parameters);
300 
301   return resolve_host_request_->Start(
302       base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
303 }
304 
DoResolveHostComplete(int result)305 int SOCKSClientSocket::DoResolveHostComplete(int result) {
306   resolve_error_info_ = resolve_host_request_->GetResolveErrorInfo();
307   if (result != OK) {
308     // Resolving the hostname failed; fail the request rather than automatically
309     // falling back to SOCKS4a (since it can be confusing to see invalid IP
310     // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
311     return result;
312   }
313 
314   next_state_ = STATE_HANDSHAKE_WRITE;
315   return OK;
316 }
317 
318 // Builds the buffer that is to be sent to the server.
BuildHandshakeWriteBuffer() const319 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
320   SOCKS4ServerRequest request;
321   request.version = kSOCKSVersion4;
322   request.command = kSOCKSStreamRequest;
323   request.nw_port = base::HostToNet16(destination_.port());
324 
325   DCHECK(resolve_host_request_->GetAddressResults() &&
326          !resolve_host_request_->GetAddressResults()->empty());
327   const IPEndPoint& endpoint =
328       resolve_host_request_->GetAddressResults()->front();
329 
330   // We disabled IPv6 results when resolving the hostname, so none of the
331   // results in the list will be IPv6.
332   // TODO(eroman): we only ever use the first address in the list. It would be
333   //               more robust to try all the IP addresses we have before
334   //               failing the connect attempt.
335   CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
336   CHECK_LE(endpoint.address().size(), sizeof(request.ip));
337   memcpy(&request.ip, &endpoint.address().bytes()[0],
338          endpoint.address().size());
339 
340   DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
341 
342   std::string handshake_data(reinterpret_cast<char*>(&request),
343                              sizeof(request));
344   handshake_data.append(kEmptyUserId, std::size(kEmptyUserId));
345 
346   return handshake_data;
347 }
348 
349 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()350 int SOCKSClientSocket::DoHandshakeWrite() {
351   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
352 
353   if (buffer_.empty()) {
354     buffer_ = BuildHandshakeWriteBuffer();
355     bytes_sent_ = 0;
356   }
357 
358   int handshake_buf_len = buffer_.size() - bytes_sent_;
359   DCHECK_GT(handshake_buf_len, 0);
360   handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
361   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
362          handshake_buf_len);
363   return transport_socket_->Write(
364       handshake_buf_.get(), handshake_buf_len,
365       base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
366       traffic_annotation_);
367 }
368 
DoHandshakeWriteComplete(int result)369 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
370   if (result < 0)
371     return result;
372 
373   // We ignore the case when result is 0, since the underlying Write
374   // may return spurious writes while waiting on the socket.
375 
376   bytes_sent_ += result;
377   if (bytes_sent_ == buffer_.size()) {
378     next_state_ = STATE_HANDSHAKE_READ;
379     buffer_.clear();
380   } else if (bytes_sent_ < buffer_.size()) {
381     next_state_ = STATE_HANDSHAKE_WRITE;
382   } else {
383     return ERR_UNEXPECTED;
384   }
385 
386   return OK;
387 }
388 
DoHandshakeRead()389 int SOCKSClientSocket::DoHandshakeRead() {
390   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
391 
392   if (buffer_.empty()) {
393     bytes_received_ = 0;
394   }
395 
396   int handshake_buf_len = kReadHeaderSize - bytes_received_;
397   handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
398   return transport_socket_->Read(
399       handshake_buf_.get(), handshake_buf_len,
400       base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
401 }
402 
DoHandshakeReadComplete(int result)403 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
404   if (result < 0)
405     return result;
406 
407   // The underlying socket closed unexpectedly.
408   if (result == 0)
409     return ERR_CONNECTION_CLOSED;
410 
411   if (bytes_received_ + result > kReadHeaderSize) {
412     // TODO(eroman): Describe failure in NetLog.
413     return ERR_SOCKS_CONNECTION_FAILED;
414   }
415 
416   buffer_.append(handshake_buf_->data(), result);
417   bytes_received_ += result;
418   if (bytes_received_ < kReadHeaderSize) {
419     next_state_ = STATE_HANDSHAKE_READ;
420     return OK;
421   }
422 
423   const SOCKS4ServerResponse* response =
424       reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
425 
426   if (response->reserved_null != 0x00) {
427     DVLOG(1) << "Unknown response from SOCKS server.";
428     return ERR_SOCKS_CONNECTION_FAILED;
429   }
430 
431   switch (response->code) {
432     case kServerResponseOk:
433       completed_handshake_ = true;
434       return OK;
435     case kServerResponseRejected:
436       DVLOG(1) << "SOCKS request rejected or failed";
437       return ERR_SOCKS_CONNECTION_FAILED;
438     case kServerResponseNotReachable:
439       DVLOG(1) << "SOCKS request failed because client is not running "
440                << "identd (or not reachable from the server)";
441       return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
442     case kServerResponseMismatchedUserId:
443       DVLOG(1) << "SOCKS request failed because client's identd could "
444                << "not confirm the user ID string in the request";
445       return ERR_SOCKS_CONNECTION_FAILED;
446     default:
447       DVLOG(1) << "SOCKS server sent unknown response";
448       return ERR_SOCKS_CONNECTION_FAILED;
449   }
450 
451   // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
452 }
453 
GetPeerAddress(IPEndPoint * address) const454 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
455   return transport_socket_->GetPeerAddress(address);
456 }
457 
GetLocalAddress(IPEndPoint * address) const458 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
459   return transport_socket_->GetLocalAddress(address);
460 }
461 
GetResolveErrorInfo() const462 ResolveErrorInfo SOCKSClientSocket::GetResolveErrorInfo() const {
463   return resolve_error_info_;
464 }
465 
466 }  // namespace net
467