xref: /aosp_15_r20/external/cronet/net/socket/socks5_client_socket.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/socks5_client_socket.h"
6 
7 #include <utility>
8 
9 #include "base/compiler_specific.h"
10 #include "base/format_macros.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/strings/string_util.h"
14 #include "base/sys_byteorder.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/sys_addrinfo.h"
17 #include "net/base/tracing.h"
18 #include "net/log/net_log.h"
19 #include "net/log/net_log_event_type.h"
20 #include "net/traffic_annotation/network_traffic_annotation.h"
21 
22 namespace net {
23 
24 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
25 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
26 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
27 const uint8_t SOCKS5ClientSocket::kSOCKS5Version = 0x05;
28 const uint8_t SOCKS5ClientSocket::kTunnelCommand = 0x01;
29 const uint8_t SOCKS5ClientSocket::kNullByte = 0x00;
30 
31 static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
32 static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
33 
SOCKS5ClientSocket(std::unique_ptr<StreamSocket> transport_socket,const HostPortPair & destination,const NetworkTrafficAnnotationTag & traffic_annotation)34 SOCKS5ClientSocket::SOCKS5ClientSocket(
35     std::unique_ptr<StreamSocket> transport_socket,
36     const HostPortPair& destination,
37     const NetworkTrafficAnnotationTag& traffic_annotation)
38     : io_callback_(base::BindRepeating(&SOCKS5ClientSocket::OnIOComplete,
39                                        base::Unretained(this))),
40       transport_socket_(std::move(transport_socket)),
41       read_header_size(kReadHeaderSize),
42       destination_(destination),
43       net_log_(transport_socket_->NetLog()),
44       traffic_annotation_(traffic_annotation) {}
45 
~SOCKS5ClientSocket()46 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
47   Disconnect();
48 }
49 
Connect(CompletionOnceCallback callback)50 int SOCKS5ClientSocket::Connect(CompletionOnceCallback callback) {
51   DCHECK(transport_socket_);
52   DCHECK_EQ(STATE_NONE, next_state_);
53   DCHECK(user_callback_.is_null());
54 
55   // If already connected, then just return OK.
56   if (completed_handshake_)
57     return OK;
58 
59   net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
60 
61   next_state_ = STATE_GREET_WRITE;
62   buffer_.clear();
63 
64   int rv = DoLoop(OK);
65   if (rv == ERR_IO_PENDING) {
66     user_callback_ = std::move(callback);
67   } else {
68     net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
69   }
70   return rv;
71 }
72 
Disconnect()73 void SOCKS5ClientSocket::Disconnect() {
74   completed_handshake_ = false;
75   transport_socket_->Disconnect();
76 
77   // Reset other states to make sure they aren't mistakenly used later.
78   // These are the states initialized by Connect().
79   next_state_ = STATE_NONE;
80   user_callback_.Reset();
81 }
82 
IsConnected() const83 bool SOCKS5ClientSocket::IsConnected() const {
84   return completed_handshake_ && transport_socket_->IsConnected();
85 }
86 
IsConnectedAndIdle() const87 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
88   return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
89 }
90 
NetLog() const91 const NetLogWithSource& SOCKS5ClientSocket::NetLog() const {
92   return net_log_;
93 }
94 
WasEverUsed() const95 bool SOCKS5ClientSocket::WasEverUsed() const {
96   return was_ever_used_;
97 }
98 
GetNegotiatedProtocol() const99 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
100   if (transport_socket_)
101     return transport_socket_->GetNegotiatedProtocol();
102   NOTREACHED();
103   return kProtoUnknown;
104 }
105 
GetSSLInfo(SSLInfo * ssl_info)106 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
107   if (transport_socket_)
108     return transport_socket_->GetSSLInfo(ssl_info);
109   NOTREACHED();
110   return false;
111 }
112 
GetTotalReceivedBytes() const113 int64_t SOCKS5ClientSocket::GetTotalReceivedBytes() const {
114   return transport_socket_->GetTotalReceivedBytes();
115 }
116 
ApplySocketTag(const SocketTag & tag)117 void SOCKS5ClientSocket::ApplySocketTag(const SocketTag& tag) {
118   return transport_socket_->ApplySocketTag(tag);
119 }
120 
121 // Read is called by the transport layer above to read. This can only be done
122 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)123 int SOCKS5ClientSocket::Read(IOBuffer* buf,
124                              int buf_len,
125                              CompletionOnceCallback callback) {
126   DCHECK(completed_handshake_);
127   DCHECK_EQ(STATE_NONE, next_state_);
128   DCHECK(user_callback_.is_null());
129   DCHECK(!callback.is_null());
130 
131   int rv = transport_socket_->Read(
132       buf, buf_len,
133       base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
134                      base::Unretained(this), std::move(callback)));
135   if (rv > 0)
136     was_ever_used_ = true;
137   return rv;
138 }
139 
140 // Write is called by the transport layer. This can only be done if the
141 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)142 int SOCKS5ClientSocket::Write(
143     IOBuffer* buf,
144     int buf_len,
145     CompletionOnceCallback callback,
146     const NetworkTrafficAnnotationTag& traffic_annotation) {
147   DCHECK(completed_handshake_);
148   DCHECK_EQ(STATE_NONE, next_state_);
149   DCHECK(user_callback_.is_null());
150   DCHECK(!callback.is_null());
151 
152   int rv = transport_socket_->Write(
153       buf, buf_len,
154       base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
155                      base::Unretained(this), std::move(callback)),
156       traffic_annotation);
157   if (rv > 0)
158     was_ever_used_ = true;
159   return rv;
160 }
161 
SetReceiveBufferSize(int32_t size)162 int SOCKS5ClientSocket::SetReceiveBufferSize(int32_t size) {
163   return transport_socket_->SetReceiveBufferSize(size);
164 }
165 
SetSendBufferSize(int32_t size)166 int SOCKS5ClientSocket::SetSendBufferSize(int32_t size) {
167   return transport_socket_->SetSendBufferSize(size);
168 }
169 
DoCallback(int result)170 void SOCKS5ClientSocket::DoCallback(int result) {
171   DCHECK_NE(ERR_IO_PENDING, result);
172   DCHECK(!user_callback_.is_null());
173 
174   // Since Run() may result in Read being called,
175   // clear user_callback_ up front.
176   std::move(user_callback_).Run(result);
177 }
178 
OnIOComplete(int result)179 void SOCKS5ClientSocket::OnIOComplete(int result) {
180   DCHECK_NE(STATE_NONE, next_state_);
181   int rv = DoLoop(result);
182   if (rv != ERR_IO_PENDING) {
183     net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
184     DoCallback(rv);
185   }
186 }
187 
OnReadWriteComplete(CompletionOnceCallback callback,int result)188 void SOCKS5ClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
189                                              int result) {
190   DCHECK_NE(ERR_IO_PENDING, result);
191   DCHECK(!callback.is_null());
192 
193   if (result > 0)
194     was_ever_used_ = true;
195   std::move(callback).Run(result);
196 }
197 
DoLoop(int last_io_result)198 int SOCKS5ClientSocket::DoLoop(int last_io_result) {
199   DCHECK_NE(next_state_, STATE_NONE);
200   int rv = last_io_result;
201   do {
202     State state = next_state_;
203     next_state_ = STATE_NONE;
204     switch (state) {
205       case STATE_GREET_WRITE:
206         DCHECK_EQ(OK, rv);
207         net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
208         rv = DoGreetWrite();
209         break;
210       case STATE_GREET_WRITE_COMPLETE:
211         rv = DoGreetWriteComplete(rv);
212         net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
213                                           rv);
214         break;
215       case STATE_GREET_READ:
216         DCHECK_EQ(OK, rv);
217         net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
218         rv = DoGreetRead();
219         break;
220       case STATE_GREET_READ_COMPLETE:
221         rv = DoGreetReadComplete(rv);
222         net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
223                                           rv);
224         break;
225       case STATE_HANDSHAKE_WRITE:
226         DCHECK_EQ(OK, rv);
227         net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
228         rv = DoHandshakeWrite();
229         break;
230       case STATE_HANDSHAKE_WRITE_COMPLETE:
231         rv = DoHandshakeWriteComplete(rv);
232         net_log_.EndEventWithNetErrorCode(
233             NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
234         break;
235       case STATE_HANDSHAKE_READ:
236         DCHECK_EQ(OK, rv);
237         net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
238         rv = DoHandshakeRead();
239         break;
240       case STATE_HANDSHAKE_READ_COMPLETE:
241         rv = DoHandshakeReadComplete(rv);
242         net_log_.EndEventWithNetErrorCode(
243             NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
244         break;
245       default:
246         NOTREACHED() << "bad state";
247         rv = ERR_UNEXPECTED;
248         break;
249     }
250   } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
251   return rv;
252 }
253 
// SOCKS5 greeting sent first (RFC 1928 layout: VER, NMETHODS, METHODS):
// version 5, one method offered, method 0x00 = no authentication.
constexpr char kSOCKS5GreetWriteData[] = {0x05, 0x01, 0x00};
255 
DoGreetWrite()256 int SOCKS5ClientSocket::DoGreetWrite() {
257   // Since we only have 1 byte to send the hostname length in, if the
258   // URL has a hostname longer than 255 characters we can't send it.
259   if (0xFF < destination_.host().size()) {
260     net_log_.AddEvent(NetLogEventType::SOCKS_HOSTNAME_TOO_BIG);
261     return ERR_SOCKS_CONNECTION_FAILED;
262   }
263 
264   if (buffer_.empty()) {
265     buffer_ =
266         std::string(kSOCKS5GreetWriteData, std::size(kSOCKS5GreetWriteData));
267     bytes_sent_ = 0;
268   }
269 
270   next_state_ = STATE_GREET_WRITE_COMPLETE;
271   size_t handshake_buf_len = buffer_.size() - bytes_sent_;
272   handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
273   memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
274          handshake_buf_len);
275   return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
276                                   io_callback_, traffic_annotation_);
277 }
278 
DoGreetWriteComplete(int result)279 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
280   if (result < 0)
281     return result;
282 
283   bytes_sent_ += result;
284   if (bytes_sent_ == buffer_.size()) {
285     buffer_.clear();
286     bytes_received_ = 0;
287     next_state_ = STATE_GREET_READ;
288   } else {
289     next_state_ = STATE_GREET_WRITE;
290   }
291   return OK;
292 }
293 
DoGreetRead()294 int SOCKS5ClientSocket::DoGreetRead() {
295   next_state_ = STATE_GREET_READ_COMPLETE;
296   size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
297   handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
298   return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
299                                  io_callback_);
300 }
301 
DoGreetReadComplete(int result)302 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
303   if (result < 0)
304     return result;
305 
306   if (result == 0) {
307     net_log_.AddEvent(
308         NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
309     return ERR_SOCKS_CONNECTION_FAILED;
310   }
311 
312   bytes_received_ += result;
313   buffer_.append(handshake_buf_->data(), result);
314   if (bytes_received_ < kGreetReadHeaderSize) {
315     next_state_ = STATE_GREET_READ;
316     return OK;
317   }
318 
319   // Got the greet data.
320   if (buffer_[0] != kSOCKS5Version) {
321     net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
322                                    "version", buffer_[0]);
323     return ERR_SOCKS_CONNECTION_FAILED;
324   }
325   if (buffer_[1] != 0x00) {
326     net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_AUTH,
327                                    "method", buffer_[1]);
328     return ERR_SOCKS_CONNECTION_FAILED;
329   }
330 
331   buffer_.clear();
332   next_state_ = STATE_HANDSHAKE_WRITE;
333   return OK;
334 }
335 
BuildHandshakeWriteBuffer(std::string * handshake) const336 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
337     const {
338   DCHECK(handshake->empty());
339 
340   handshake->push_back(kSOCKS5Version);
341   handshake->push_back(kTunnelCommand);  // Connect command
342   handshake->push_back(kNullByte);  // Reserved null
343 
344   handshake->push_back(kEndPointDomain);  // The type of the address.
345 
346   DCHECK_GE(static_cast<size_t>(0xFF), destination_.host().size());
347 
348   // First add the size of the hostname, followed by the hostname.
349   handshake->push_back(static_cast<unsigned char>(destination_.host().size()));
350   handshake->append(destination_.host());
351 
352   uint16_t nw_port = base::HostToNet16(destination_.port());
353   handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
354   return OK;
355 }
356 
357 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()358 int SOCKS5ClientSocket::DoHandshakeWrite() {
359   next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
360 
361   if (buffer_.empty()) {
362     int rv = BuildHandshakeWriteBuffer(&buffer_);
363     if (rv != OK)
364       return rv;
365     bytes_sent_ = 0;
366   }
367 
368   int handshake_buf_len = buffer_.size() - bytes_sent_;
369   DCHECK_LT(0, handshake_buf_len);
370   handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
371   memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
372          handshake_buf_len);
373   return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
374                                   io_callback_, traffic_annotation_);
375 }
376 
DoHandshakeWriteComplete(int result)377 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
378   if (result < 0)
379     return result;
380 
381   // We ignore the case when result is 0, since the underlying Write
382   // may return spurious writes while waiting on the socket.
383 
384   bytes_sent_ += result;
385   if (bytes_sent_ == buffer_.size()) {
386     next_state_ = STATE_HANDSHAKE_READ;
387     buffer_.clear();
388   } else if (bytes_sent_ < buffer_.size()) {
389     next_state_ = STATE_HANDSHAKE_WRITE;
390   } else {
391     NOTREACHED();
392   }
393 
394   return OK;
395 }
396 
DoHandshakeRead()397 int SOCKS5ClientSocket::DoHandshakeRead() {
398   next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
399 
400   if (buffer_.empty()) {
401     bytes_received_ = 0;
402     read_header_size = kReadHeaderSize;
403   }
404 
405   int handshake_buf_len = read_header_size - bytes_received_;
406   handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
407   return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
408                                  io_callback_);
409 }
410 
DoHandshakeReadComplete(int result)411 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
412   if (result < 0)
413     return result;
414 
415   // The underlying socket closed unexpectedly.
416   if (result == 0) {
417     net_log_.AddEvent(
418         NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
419     return ERR_SOCKS_CONNECTION_FAILED;
420   }
421 
422   buffer_.append(handshake_buf_->data(), result);
423   bytes_received_ += result;
424 
425   // When the first few bytes are read, check how many more are required
426   // and accordingly increase them
427   if (bytes_received_ == kReadHeaderSize) {
428     if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
429       net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
430                                      "version", buffer_[0]);
431       return ERR_SOCKS_CONNECTION_FAILED;
432     }
433     if (buffer_[1] != 0x00) {
434       net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
435                                      "error_code", buffer_[1]);
436       return ERR_SOCKS_CONNECTION_FAILED;
437     }
438 
439     // We check the type of IP/Domain the server returns and accordingly
440     // increase the size of the response. For domains, we need to read the
441     // size of the domain, so the initial request size is upto the domain
442     // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
443     // read, we substract 1 byte from the additional request size.
444     SocksEndPointAddressType address_type =
445         static_cast<SocksEndPointAddressType>(buffer_[3]);
446     if (address_type == kEndPointDomain) {
447       read_header_size += static_cast<uint8_t>(buffer_[4]);
448     } else if (address_type == kEndPointResolvedIPv4) {
449       read_header_size += sizeof(struct in_addr) - 1;
450     } else if (address_type == kEndPointResolvedIPv6) {
451       read_header_size += sizeof(struct in6_addr) - 1;
452     } else {
453       net_log_.AddEventWithIntParams(
454           NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
455           buffer_[3]);
456       return ERR_SOCKS_CONNECTION_FAILED;
457     }
458 
459     read_header_size += 2;  // for the port.
460     next_state_ = STATE_HANDSHAKE_READ;
461     return OK;
462   }
463 
464   // When the final bytes are read, setup handshake. We ignore the rest
465   // of the response since they represent the SOCKSv5 endpoint and have
466   // no use when doing a tunnel connection.
467   if (bytes_received_ == read_header_size) {
468     completed_handshake_ = true;
469     buffer_.clear();
470     next_state_ = STATE_NONE;
471     return OK;
472   }
473 
474   next_state_ = STATE_HANDSHAKE_READ;
475   return OK;
476 }
477 
GetPeerAddress(IPEndPoint * address) const478 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
479   return transport_socket_->GetPeerAddress(address);
480 }
481 
GetLocalAddress(IPEndPoint * address) const482 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
483   return transport_socket_->GetLocalAddress(address);
484 }
485 
486 }  // namespace net
487