1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks5_client_socket.h"
6
7 #include <utility>
8
9 #include "base/compiler_specific.h"
10 #include "base/format_macros.h"
11 #include "base/functional/bind.h"
12 #include "base/functional/callback_helpers.h"
13 #include "base/strings/string_util.h"
14 #include "base/sys_byteorder.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/sys_addrinfo.h"
17 #include "net/base/tracing.h"
18 #include "net/log/net_log.h"
19 #include "net/log/net_log_event_type.h"
20 #include "net/traffic_annotation/network_traffic_annotation.h"
21
22 namespace net {
23
24 const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
25 const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
26 const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
27 const uint8_t SOCKS5ClientSocket::kSOCKS5Version = 0x05;
28 const uint8_t SOCKS5ClientSocket::kTunnelCommand = 0x01;
29 const uint8_t SOCKS5ClientSocket::kNullByte = 0x00;
30
31 static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
32 static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
33
SOCKS5ClientSocket(std::unique_ptr<StreamSocket> transport_socket,const HostPortPair & destination,const NetworkTrafficAnnotationTag & traffic_annotation)34 SOCKS5ClientSocket::SOCKS5ClientSocket(
35 std::unique_ptr<StreamSocket> transport_socket,
36 const HostPortPair& destination,
37 const NetworkTrafficAnnotationTag& traffic_annotation)
38 : io_callback_(base::BindRepeating(&SOCKS5ClientSocket::OnIOComplete,
39 base::Unretained(this))),
40 transport_socket_(std::move(transport_socket)),
41 read_header_size(kReadHeaderSize),
42 destination_(destination),
43 net_log_(transport_socket_->NetLog()),
44 traffic_annotation_(traffic_annotation) {}
45
~SOCKS5ClientSocket()46 SOCKS5ClientSocket::~SOCKS5ClientSocket() {
47 Disconnect();
48 }
49
Connect(CompletionOnceCallback callback)50 int SOCKS5ClientSocket::Connect(CompletionOnceCallback callback) {
51 DCHECK(transport_socket_);
52 DCHECK_EQ(STATE_NONE, next_state_);
53 DCHECK(user_callback_.is_null());
54
55 // If already connected, then just return OK.
56 if (completed_handshake_)
57 return OK;
58
59 net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
60
61 next_state_ = STATE_GREET_WRITE;
62 buffer_.clear();
63
64 int rv = DoLoop(OK);
65 if (rv == ERR_IO_PENDING) {
66 user_callback_ = std::move(callback);
67 } else {
68 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
69 }
70 return rv;
71 }
72
Disconnect()73 void SOCKS5ClientSocket::Disconnect() {
74 completed_handshake_ = false;
75 transport_socket_->Disconnect();
76
77 // Reset other states to make sure they aren't mistakenly used later.
78 // These are the states initialized by Connect().
79 next_state_ = STATE_NONE;
80 user_callback_.Reset();
81 }
82
IsConnected() const83 bool SOCKS5ClientSocket::IsConnected() const {
84 return completed_handshake_ && transport_socket_->IsConnected();
85 }
86
IsConnectedAndIdle() const87 bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
88 return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
89 }
90
NetLog() const91 const NetLogWithSource& SOCKS5ClientSocket::NetLog() const {
92 return net_log_;
93 }
94
WasEverUsed() const95 bool SOCKS5ClientSocket::WasEverUsed() const {
96 return was_ever_used_;
97 }
98
GetNegotiatedProtocol() const99 NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
100 if (transport_socket_)
101 return transport_socket_->GetNegotiatedProtocol();
102 NOTREACHED();
103 return kProtoUnknown;
104 }
105
GetSSLInfo(SSLInfo * ssl_info)106 bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
107 if (transport_socket_)
108 return transport_socket_->GetSSLInfo(ssl_info);
109 NOTREACHED();
110 return false;
111 }
112
GetTotalReceivedBytes() const113 int64_t SOCKS5ClientSocket::GetTotalReceivedBytes() const {
114 return transport_socket_->GetTotalReceivedBytes();
115 }
116
ApplySocketTag(const SocketTag & tag)117 void SOCKS5ClientSocket::ApplySocketTag(const SocketTag& tag) {
118 return transport_socket_->ApplySocketTag(tag);
119 }
120
121 // Read is called by the transport layer above to read. This can only be done
122 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)123 int SOCKS5ClientSocket::Read(IOBuffer* buf,
124 int buf_len,
125 CompletionOnceCallback callback) {
126 DCHECK(completed_handshake_);
127 DCHECK_EQ(STATE_NONE, next_state_);
128 DCHECK(user_callback_.is_null());
129 DCHECK(!callback.is_null());
130
131 int rv = transport_socket_->Read(
132 buf, buf_len,
133 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
134 base::Unretained(this), std::move(callback)));
135 if (rv > 0)
136 was_ever_used_ = true;
137 return rv;
138 }
139
140 // Write is called by the transport layer. This can only be done if the
141 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)142 int SOCKS5ClientSocket::Write(
143 IOBuffer* buf,
144 int buf_len,
145 CompletionOnceCallback callback,
146 const NetworkTrafficAnnotationTag& traffic_annotation) {
147 DCHECK(completed_handshake_);
148 DCHECK_EQ(STATE_NONE, next_state_);
149 DCHECK(user_callback_.is_null());
150 DCHECK(!callback.is_null());
151
152 int rv = transport_socket_->Write(
153 buf, buf_len,
154 base::BindOnce(&SOCKS5ClientSocket::OnReadWriteComplete,
155 base::Unretained(this), std::move(callback)),
156 traffic_annotation);
157 if (rv > 0)
158 was_ever_used_ = true;
159 return rv;
160 }
161
SetReceiveBufferSize(int32_t size)162 int SOCKS5ClientSocket::SetReceiveBufferSize(int32_t size) {
163 return transport_socket_->SetReceiveBufferSize(size);
164 }
165
SetSendBufferSize(int32_t size)166 int SOCKS5ClientSocket::SetSendBufferSize(int32_t size) {
167 return transport_socket_->SetSendBufferSize(size);
168 }
169
DoCallback(int result)170 void SOCKS5ClientSocket::DoCallback(int result) {
171 DCHECK_NE(ERR_IO_PENDING, result);
172 DCHECK(!user_callback_.is_null());
173
174 // Since Run() may result in Read being called,
175 // clear user_callback_ up front.
176 std::move(user_callback_).Run(result);
177 }
178
OnIOComplete(int result)179 void SOCKS5ClientSocket::OnIOComplete(int result) {
180 DCHECK_NE(STATE_NONE, next_state_);
181 int rv = DoLoop(result);
182 if (rv != ERR_IO_PENDING) {
183 net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
184 DoCallback(rv);
185 }
186 }
187
OnReadWriteComplete(CompletionOnceCallback callback,int result)188 void SOCKS5ClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
189 int result) {
190 DCHECK_NE(ERR_IO_PENDING, result);
191 DCHECK(!callback.is_null());
192
193 if (result > 0)
194 was_ever_used_ = true;
195 std::move(callback).Run(result);
196 }
197
DoLoop(int last_io_result)198 int SOCKS5ClientSocket::DoLoop(int last_io_result) {
199 DCHECK_NE(next_state_, STATE_NONE);
200 int rv = last_io_result;
201 do {
202 State state = next_state_;
203 next_state_ = STATE_NONE;
204 switch (state) {
205 case STATE_GREET_WRITE:
206 DCHECK_EQ(OK, rv);
207 net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
208 rv = DoGreetWrite();
209 break;
210 case STATE_GREET_WRITE_COMPLETE:
211 rv = DoGreetWriteComplete(rv);
212 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
213 rv);
214 break;
215 case STATE_GREET_READ:
216 DCHECK_EQ(OK, rv);
217 net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
218 rv = DoGreetRead();
219 break;
220 case STATE_GREET_READ_COMPLETE:
221 rv = DoGreetReadComplete(rv);
222 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
223 rv);
224 break;
225 case STATE_HANDSHAKE_WRITE:
226 DCHECK_EQ(OK, rv);
227 net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
228 rv = DoHandshakeWrite();
229 break;
230 case STATE_HANDSHAKE_WRITE_COMPLETE:
231 rv = DoHandshakeWriteComplete(rv);
232 net_log_.EndEventWithNetErrorCode(
233 NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
234 break;
235 case STATE_HANDSHAKE_READ:
236 DCHECK_EQ(OK, rv);
237 net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
238 rv = DoHandshakeRead();
239 break;
240 case STATE_HANDSHAKE_READ_COMPLETE:
241 rv = DoHandshakeReadComplete(rv);
242 net_log_.EndEventWithNetErrorCode(
243 NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
244 break;
245 default:
246 NOTREACHED() << "bad state";
247 rv = ERR_UNEXPECTED;
248 break;
249 }
250 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
251 return rv;
252 }
253
254 const char kSOCKS5GreetWriteData[] = { 0x05, 0x01, 0x00 }; // no authentication
255
DoGreetWrite()256 int SOCKS5ClientSocket::DoGreetWrite() {
257 // Since we only have 1 byte to send the hostname length in, if the
258 // URL has a hostname longer than 255 characters we can't send it.
259 if (0xFF < destination_.host().size()) {
260 net_log_.AddEvent(NetLogEventType::SOCKS_HOSTNAME_TOO_BIG);
261 return ERR_SOCKS_CONNECTION_FAILED;
262 }
263
264 if (buffer_.empty()) {
265 buffer_ =
266 std::string(kSOCKS5GreetWriteData, std::size(kSOCKS5GreetWriteData));
267 bytes_sent_ = 0;
268 }
269
270 next_state_ = STATE_GREET_WRITE_COMPLETE;
271 size_t handshake_buf_len = buffer_.size() - bytes_sent_;
272 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
273 memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
274 handshake_buf_len);
275 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
276 io_callback_, traffic_annotation_);
277 }
278
DoGreetWriteComplete(int result)279 int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
280 if (result < 0)
281 return result;
282
283 bytes_sent_ += result;
284 if (bytes_sent_ == buffer_.size()) {
285 buffer_.clear();
286 bytes_received_ = 0;
287 next_state_ = STATE_GREET_READ;
288 } else {
289 next_state_ = STATE_GREET_WRITE;
290 }
291 return OK;
292 }
293
DoGreetRead()294 int SOCKS5ClientSocket::DoGreetRead() {
295 next_state_ = STATE_GREET_READ_COMPLETE;
296 size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
297 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
298 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
299 io_callback_);
300 }
301
DoGreetReadComplete(int result)302 int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
303 if (result < 0)
304 return result;
305
306 if (result == 0) {
307 net_log_.AddEvent(
308 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
309 return ERR_SOCKS_CONNECTION_FAILED;
310 }
311
312 bytes_received_ += result;
313 buffer_.append(handshake_buf_->data(), result);
314 if (bytes_received_ < kGreetReadHeaderSize) {
315 next_state_ = STATE_GREET_READ;
316 return OK;
317 }
318
319 // Got the greet data.
320 if (buffer_[0] != kSOCKS5Version) {
321 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
322 "version", buffer_[0]);
323 return ERR_SOCKS_CONNECTION_FAILED;
324 }
325 if (buffer_[1] != 0x00) {
326 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_AUTH,
327 "method", buffer_[1]);
328 return ERR_SOCKS_CONNECTION_FAILED;
329 }
330
331 buffer_.clear();
332 next_state_ = STATE_HANDSHAKE_WRITE;
333 return OK;
334 }
335
BuildHandshakeWriteBuffer(std::string * handshake) const336 int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
337 const {
338 DCHECK(handshake->empty());
339
340 handshake->push_back(kSOCKS5Version);
341 handshake->push_back(kTunnelCommand); // Connect command
342 handshake->push_back(kNullByte); // Reserved null
343
344 handshake->push_back(kEndPointDomain); // The type of the address.
345
346 DCHECK_GE(static_cast<size_t>(0xFF), destination_.host().size());
347
348 // First add the size of the hostname, followed by the hostname.
349 handshake->push_back(static_cast<unsigned char>(destination_.host().size()));
350 handshake->append(destination_.host());
351
352 uint16_t nw_port = base::HostToNet16(destination_.port());
353 handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
354 return OK;
355 }
356
357 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()358 int SOCKS5ClientSocket::DoHandshakeWrite() {
359 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
360
361 if (buffer_.empty()) {
362 int rv = BuildHandshakeWriteBuffer(&buffer_);
363 if (rv != OK)
364 return rv;
365 bytes_sent_ = 0;
366 }
367
368 int handshake_buf_len = buffer_.size() - bytes_sent_;
369 DCHECK_LT(0, handshake_buf_len);
370 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
371 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
372 handshake_buf_len);
373 return transport_socket_->Write(handshake_buf_.get(), handshake_buf_len,
374 io_callback_, traffic_annotation_);
375 }
376
DoHandshakeWriteComplete(int result)377 int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
378 if (result < 0)
379 return result;
380
381 // We ignore the case when result is 0, since the underlying Write
382 // may return spurious writes while waiting on the socket.
383
384 bytes_sent_ += result;
385 if (bytes_sent_ == buffer_.size()) {
386 next_state_ = STATE_HANDSHAKE_READ;
387 buffer_.clear();
388 } else if (bytes_sent_ < buffer_.size()) {
389 next_state_ = STATE_HANDSHAKE_WRITE;
390 } else {
391 NOTREACHED();
392 }
393
394 return OK;
395 }
396
DoHandshakeRead()397 int SOCKS5ClientSocket::DoHandshakeRead() {
398 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
399
400 if (buffer_.empty()) {
401 bytes_received_ = 0;
402 read_header_size = kReadHeaderSize;
403 }
404
405 int handshake_buf_len = read_header_size - bytes_received_;
406 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
407 return transport_socket_->Read(handshake_buf_.get(), handshake_buf_len,
408 io_callback_);
409 }
410
DoHandshakeReadComplete(int result)411 int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
412 if (result < 0)
413 return result;
414
415 // The underlying socket closed unexpectedly.
416 if (result == 0) {
417 net_log_.AddEvent(
418 NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
419 return ERR_SOCKS_CONNECTION_FAILED;
420 }
421
422 buffer_.append(handshake_buf_->data(), result);
423 bytes_received_ += result;
424
425 // When the first few bytes are read, check how many more are required
426 // and accordingly increase them
427 if (bytes_received_ == kReadHeaderSize) {
428 if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
429 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
430 "version", buffer_[0]);
431 return ERR_SOCKS_CONNECTION_FAILED;
432 }
433 if (buffer_[1] != 0x00) {
434 net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
435 "error_code", buffer_[1]);
436 return ERR_SOCKS_CONNECTION_FAILED;
437 }
438
439 // We check the type of IP/Domain the server returns and accordingly
440 // increase the size of the response. For domains, we need to read the
441 // size of the domain, so the initial request size is upto the domain
442 // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
443 // read, we substract 1 byte from the additional request size.
444 SocksEndPointAddressType address_type =
445 static_cast<SocksEndPointAddressType>(buffer_[3]);
446 if (address_type == kEndPointDomain) {
447 read_header_size += static_cast<uint8_t>(buffer_[4]);
448 } else if (address_type == kEndPointResolvedIPv4) {
449 read_header_size += sizeof(struct in_addr) - 1;
450 } else if (address_type == kEndPointResolvedIPv6) {
451 read_header_size += sizeof(struct in6_addr) - 1;
452 } else {
453 net_log_.AddEventWithIntParams(
454 NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
455 buffer_[3]);
456 return ERR_SOCKS_CONNECTION_FAILED;
457 }
458
459 read_header_size += 2; // for the port.
460 next_state_ = STATE_HANDSHAKE_READ;
461 return OK;
462 }
463
464 // When the final bytes are read, setup handshake. We ignore the rest
465 // of the response since they represent the SOCKSv5 endpoint and have
466 // no use when doing a tunnel connection.
467 if (bytes_received_ == read_header_size) {
468 completed_handshake_ = true;
469 buffer_.clear();
470 next_state_ = STATE_NONE;
471 return OK;
472 }
473
474 next_state_ = STATE_HANDSHAKE_READ;
475 return OK;
476 }
477
GetPeerAddress(IPEndPoint * address) const478 int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
479 return transport_socket_->GetPeerAddress(address);
480 }
481
GetLocalAddress(IPEndPoint * address) const482 int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
483 return transport_socket_->GetLocalAddress(address);
484 }
485
486 } // namespace net
487