1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/socket/socks_client_socket.h"
6
7 #include <utility>
8
9 #include "base/compiler_specific.h"
10 #include "base/functional/bind.h"
11 #include "base/functional/callback_helpers.h"
12 #include "base/sys_byteorder.h"
13 #include "net/base/address_list.h"
14 #include "net/base/io_buffer.h"
15 #include "net/dns/public/dns_query_type.h"
16 #include "net/dns/public/secure_dns_policy.h"
17 #include "net/log/net_log.h"
18 #include "net/log/net_log_event_type.h"
19 #include "net/traffic_annotation/network_traffic_annotation.h"
20
21 namespace net {
22
// Every SOCKS4 server asks the client for a user-id. Supplying one is
// optional, so an empty string is sent. The array contains only its
// terminating NUL; that NUL is written on the wire to end the user-id field.
constexpr char kEmptyUserId[] = "";

// Size of the fixed part of a SOCKS4 request; the user-id is appended after it.
constexpr unsigned int kWriteHeaderSize = 8;

// Size of the fixed reply a SOCKS4 server sends to acknowledge a request.
constexpr unsigned int kReadHeaderSize = 8;

// Reply codes a SOCKS4 server may put in the |code| byte of its reply.
constexpr uint8_t kServerResponseOk = 0x5A;
constexpr uint8_t kServerResponseRejected = 0x5B;
constexpr uint8_t kServerResponseNotReachable = 0x5C;
constexpr uint8_t kServerResponseMismatchedUserId = 0x5D;

// Protocol version byte and the stream-connection request command.
constexpr uint8_t kSOCKSVersion4 = 0x04;
constexpr uint8_t kSOCKSStreamRequest = 0x01;
41
// A struct holding the essential details of the SOCKS4 Server Request.
// The port in the header is stored in network byte order.
//
// The field order mirrors the fixed request header as it is sent on the wire
// (version, command, port, IPv4 address). BuildHandshakeWriteBuffer() copies
// the struct's raw bytes into the outgoing buffer, so the static_assert below
// guards against any compiler-inserted padding changing that layout.
struct SOCKS4ServerRequest {
  uint8_t version;   // Protocol version; set to kSOCKSVersion4.
  uint8_t command;   // Request command; set to kSOCKSStreamRequest.
  uint16_t nw_port;  // Destination port, network byte order.
  uint8_t ip[4];     // Destination IPv4 address bytes.
};
static_assert(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
              "socks4 server request struct has incorrect size");
52
// A struct holding details of the SOCKS4 Server Response.
//
// Mirrors the fixed 8-byte reply layout (guarded by the static_assert below).
// The handshake only inspects |reserved_null| and |code|; |port| and |ip| are
// ignored by this client.
struct SOCKS4ServerResponse {
  uint8_t reserved_null;  // Expected to be 0x00; other values are rejected.
  uint8_t code;           // One of the kServerResponse* reply codes.
  uint16_t port;          // Not used by this client.
  uint8_t ip[4];          // Not used by this client.
};
static_assert(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
              "socks4 server response struct has incorrect size");
62
// Creates a SOCKS4 client socket layered on top of |transport_socket|.
// NOTE(review): Connect() only performs the SOCKS handshake and never connects
// the transport itself, so |transport_socket| presumably must already be
// connected to the SOCKS proxy — confirm against callers.
//
// |destination| is resolved through |host_resolver| (A records only, using
// |network_anonymization_key|, |priority| and |secure_dns_policy|) and the
// resulting IPv4 address is sent to the proxy. |host_resolver| is stored as a
// raw pointer; presumably it must outlive this socket — TODO confirm.
// |traffic_annotation| is attached to the handshake write.
SOCKSClientSocket::SOCKSClientSocket(
    std::unique_ptr<StreamSocket> transport_socket,
    const HostPortPair& destination,
    const NetworkAnonymizationKey& network_anonymization_key,
    RequestPriority priority,
    HostResolver* host_resolver,
    SecureDnsPolicy secure_dns_policy,
    const NetworkTrafficAnnotationTag& traffic_annotation)
    : transport_socket_(std::move(transport_socket)),
      host_resolver_(host_resolver),
      secure_dns_policy_(secure_dns_policy),
      destination_(destination),
      network_anonymization_key_(network_anonymization_key),
      priority_(priority),
      // Reads |transport_socket_|, so that member must be declared (and thus
      // initialized) before |net_log_| in the class definition.
      net_log_(transport_socket_->NetLog()),
      traffic_annotation_(traffic_annotation) {}
79
~SOCKSClientSocket()80 SOCKSClientSocket::~SOCKSClientSocket() {
81 Disconnect();
82 }
83
Connect(CompletionOnceCallback callback)84 int SOCKSClientSocket::Connect(CompletionOnceCallback callback) {
85 DCHECK(transport_socket_);
86 DCHECK_EQ(STATE_NONE, next_state_);
87 DCHECK(user_callback_.is_null());
88
89 // If already connected, then just return OK.
90 if (completed_handshake_)
91 return OK;
92
93 next_state_ = STATE_RESOLVE_HOST;
94
95 net_log_.BeginEvent(NetLogEventType::SOCKS_CONNECT);
96
97 int rv = DoLoop(OK);
98 if (rv == ERR_IO_PENDING) {
99 user_callback_ = std::move(callback);
100 } else {
101 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
102 }
103 return rv;
104 }
105
Disconnect()106 void SOCKSClientSocket::Disconnect() {
107 completed_handshake_ = false;
108 resolve_host_request_.reset();
109 transport_socket_->Disconnect();
110
111 // Reset other states to make sure they aren't mistakenly used later.
112 // These are the states initialized by Connect().
113 next_state_ = STATE_NONE;
114 user_callback_.Reset();
115 }
116
IsConnected() const117 bool SOCKSClientSocket::IsConnected() const {
118 return completed_handshake_ && transport_socket_->IsConnected();
119 }
120
IsConnectedAndIdle() const121 bool SOCKSClientSocket::IsConnectedAndIdle() const {
122 return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
123 }
124
NetLog() const125 const NetLogWithSource& SOCKSClientSocket::NetLog() const {
126 return net_log_;
127 }
128
WasEverUsed() const129 bool SOCKSClientSocket::WasEverUsed() const {
130 return was_ever_used_;
131 }
132
GetNegotiatedProtocol() const133 NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
134 if (transport_socket_)
135 return transport_socket_->GetNegotiatedProtocol();
136 NOTREACHED();
137 return kProtoUnknown;
138 }
139
GetSSLInfo(SSLInfo * ssl_info)140 bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
141 if (transport_socket_)
142 return transport_socket_->GetSSLInfo(ssl_info);
143 NOTREACHED();
144 return false;
145 }
146
GetTotalReceivedBytes() const147 int64_t SOCKSClientSocket::GetTotalReceivedBytes() const {
148 return transport_socket_->GetTotalReceivedBytes();
149 }
150
ApplySocketTag(const SocketTag & tag)151 void SOCKSClientSocket::ApplySocketTag(const SocketTag& tag) {
152 return transport_socket_->ApplySocketTag(tag);
153 }
154
155 // Read is called by the transport layer above to read. This can only be done
156 // if the SOCKS handshake is complete.
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)157 int SOCKSClientSocket::Read(IOBuffer* buf,
158 int buf_len,
159 CompletionOnceCallback callback) {
160 DCHECK(completed_handshake_);
161 DCHECK_EQ(STATE_NONE, next_state_);
162 DCHECK(user_callback_.is_null());
163 DCHECK(!callback.is_null());
164
165 int rv = transport_socket_->Read(
166 buf, buf_len,
167 base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
168 base::Unretained(this), std::move(callback)));
169 if (rv > 0)
170 was_ever_used_ = true;
171 return rv;
172 }
173
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)174 int SOCKSClientSocket::ReadIfReady(IOBuffer* buf,
175 int buf_len,
176 CompletionOnceCallback callback) {
177 DCHECK(completed_handshake_);
178 DCHECK_EQ(STATE_NONE, next_state_);
179 DCHECK(user_callback_.is_null());
180 DCHECK(!callback.is_null());
181
182 // Pass |callback| directly instead of wrapping it with OnReadWriteComplete.
183 // This is to avoid setting |was_ever_used_| unless data is actually read.
184 int rv = transport_socket_->ReadIfReady(buf, buf_len, std::move(callback));
185 if (rv > 0)
186 was_ever_used_ = true;
187 return rv;
188 }
189
CancelReadIfReady()190 int SOCKSClientSocket::CancelReadIfReady() {
191 return transport_socket_->CancelReadIfReady();
192 }
193
194 // Write is called by the transport layer. This can only be done if the
195 // SOCKS handshake is complete.
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)196 int SOCKSClientSocket::Write(
197 IOBuffer* buf,
198 int buf_len,
199 CompletionOnceCallback callback,
200 const NetworkTrafficAnnotationTag& traffic_annotation) {
201 DCHECK(completed_handshake_);
202 DCHECK_EQ(STATE_NONE, next_state_);
203 DCHECK(user_callback_.is_null());
204 DCHECK(!callback.is_null());
205
206 int rv = transport_socket_->Write(
207 buf, buf_len,
208 base::BindOnce(&SOCKSClientSocket::OnReadWriteComplete,
209 base::Unretained(this), std::move(callback)),
210 traffic_annotation);
211 if (rv > 0)
212 was_ever_used_ = true;
213 return rv;
214 }
215
SetReceiveBufferSize(int32_t size)216 int SOCKSClientSocket::SetReceiveBufferSize(int32_t size) {
217 return transport_socket_->SetReceiveBufferSize(size);
218 }
219
SetSendBufferSize(int32_t size)220 int SOCKSClientSocket::SetSendBufferSize(int32_t size) {
221 return transport_socket_->SetSendBufferSize(size);
222 }
223
DoCallback(int result)224 void SOCKSClientSocket::DoCallback(int result) {
225 DCHECK_NE(ERR_IO_PENDING, result);
226 DCHECK(!user_callback_.is_null());
227
228 // Since Run() may result in Read being called,
229 // clear user_callback_ up front.
230 DVLOG(1) << "Finished setting up SOCKS handshake";
231 std::move(user_callback_).Run(result);
232 }
233
OnIOComplete(int result)234 void SOCKSClientSocket::OnIOComplete(int result) {
235 DCHECK_NE(STATE_NONE, next_state_);
236 int rv = DoLoop(result);
237 if (rv != ERR_IO_PENDING) {
238 net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS_CONNECT, rv);
239 DoCallback(rv);
240 }
241 }
242
OnReadWriteComplete(CompletionOnceCallback callback,int result)243 void SOCKSClientSocket::OnReadWriteComplete(CompletionOnceCallback callback,
244 int result) {
245 DCHECK_NE(ERR_IO_PENDING, result);
246 DCHECK(!callback.is_null());
247
248 if (result > 0)
249 was_ever_used_ = true;
250 std::move(callback).Run(result);
251 }
252
DoLoop(int last_io_result)253 int SOCKSClientSocket::DoLoop(int last_io_result) {
254 DCHECK_NE(next_state_, STATE_NONE);
255 int rv = last_io_result;
256 do {
257 State state = next_state_;
258 next_state_ = STATE_NONE;
259 switch (state) {
260 case STATE_RESOLVE_HOST:
261 DCHECK_EQ(OK, rv);
262 rv = DoResolveHost();
263 break;
264 case STATE_RESOLVE_HOST_COMPLETE:
265 rv = DoResolveHostComplete(rv);
266 break;
267 case STATE_HANDSHAKE_WRITE:
268 DCHECK_EQ(OK, rv);
269 rv = DoHandshakeWrite();
270 break;
271 case STATE_HANDSHAKE_WRITE_COMPLETE:
272 rv = DoHandshakeWriteComplete(rv);
273 break;
274 case STATE_HANDSHAKE_READ:
275 DCHECK_EQ(OK, rv);
276 rv = DoHandshakeRead();
277 break;
278 case STATE_HANDSHAKE_READ_COMPLETE:
279 rv = DoHandshakeReadComplete(rv);
280 break;
281 default:
282 NOTREACHED() << "bad state";
283 rv = ERR_UNEXPECTED;
284 break;
285 }
286 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
287 return rv;
288 }
289
DoResolveHost()290 int SOCKSClientSocket::DoResolveHost() {
291 next_state_ = STATE_RESOLVE_HOST_COMPLETE;
292 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
293 // addresses for the target host.
294 HostResolver::ResolveHostParameters parameters;
295 parameters.dns_query_type = DnsQueryType::A;
296 parameters.initial_priority = priority_;
297 parameters.secure_dns_policy = secure_dns_policy_;
298 resolve_host_request_ = host_resolver_->CreateRequest(
299 destination_, network_anonymization_key_, net_log_, parameters);
300
301 return resolve_host_request_->Start(
302 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
303 }
304
DoResolveHostComplete(int result)305 int SOCKSClientSocket::DoResolveHostComplete(int result) {
306 resolve_error_info_ = resolve_host_request_->GetResolveErrorInfo();
307 if (result != OK) {
308 // Resolving the hostname failed; fail the request rather than automatically
309 // falling back to SOCKS4a (since it can be confusing to see invalid IP
310 // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
311 return result;
312 }
313
314 next_state_ = STATE_HANDSHAKE_WRITE;
315 return OK;
316 }
317
318 // Builds the buffer that is to be sent to the server.
BuildHandshakeWriteBuffer() const319 const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
320 SOCKS4ServerRequest request;
321 request.version = kSOCKSVersion4;
322 request.command = kSOCKSStreamRequest;
323 request.nw_port = base::HostToNet16(destination_.port());
324
325 DCHECK(resolve_host_request_->GetAddressResults() &&
326 !resolve_host_request_->GetAddressResults()->empty());
327 const IPEndPoint& endpoint =
328 resolve_host_request_->GetAddressResults()->front();
329
330 // We disabled IPv6 results when resolving the hostname, so none of the
331 // results in the list will be IPv6.
332 // TODO(eroman): we only ever use the first address in the list. It would be
333 // more robust to try all the IP addresses we have before
334 // failing the connect attempt.
335 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
336 CHECK_LE(endpoint.address().size(), sizeof(request.ip));
337 memcpy(&request.ip, &endpoint.address().bytes()[0],
338 endpoint.address().size());
339
340 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
341
342 std::string handshake_data(reinterpret_cast<char*>(&request),
343 sizeof(request));
344 handshake_data.append(kEmptyUserId, std::size(kEmptyUserId));
345
346 return handshake_data;
347 }
348
349 // Writes the SOCKS handshake data to the underlying socket connection.
DoHandshakeWrite()350 int SOCKSClientSocket::DoHandshakeWrite() {
351 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
352
353 if (buffer_.empty()) {
354 buffer_ = BuildHandshakeWriteBuffer();
355 bytes_sent_ = 0;
356 }
357
358 int handshake_buf_len = buffer_.size() - bytes_sent_;
359 DCHECK_GT(handshake_buf_len, 0);
360 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
361 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
362 handshake_buf_len);
363 return transport_socket_->Write(
364 handshake_buf_.get(), handshake_buf_len,
365 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
366 traffic_annotation_);
367 }
368
DoHandshakeWriteComplete(int result)369 int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
370 if (result < 0)
371 return result;
372
373 // We ignore the case when result is 0, since the underlying Write
374 // may return spurious writes while waiting on the socket.
375
376 bytes_sent_ += result;
377 if (bytes_sent_ == buffer_.size()) {
378 next_state_ = STATE_HANDSHAKE_READ;
379 buffer_.clear();
380 } else if (bytes_sent_ < buffer_.size()) {
381 next_state_ = STATE_HANDSHAKE_WRITE;
382 } else {
383 return ERR_UNEXPECTED;
384 }
385
386 return OK;
387 }
388
DoHandshakeRead()389 int SOCKSClientSocket::DoHandshakeRead() {
390 next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
391
392 if (buffer_.empty()) {
393 bytes_received_ = 0;
394 }
395
396 int handshake_buf_len = kReadHeaderSize - bytes_received_;
397 handshake_buf_ = base::MakeRefCounted<IOBufferWithSize>(handshake_buf_len);
398 return transport_socket_->Read(
399 handshake_buf_.get(), handshake_buf_len,
400 base::BindOnce(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
401 }
402
DoHandshakeReadComplete(int result)403 int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
404 if (result < 0)
405 return result;
406
407 // The underlying socket closed unexpectedly.
408 if (result == 0)
409 return ERR_CONNECTION_CLOSED;
410
411 if (bytes_received_ + result > kReadHeaderSize) {
412 // TODO(eroman): Describe failure in NetLog.
413 return ERR_SOCKS_CONNECTION_FAILED;
414 }
415
416 buffer_.append(handshake_buf_->data(), result);
417 bytes_received_ += result;
418 if (bytes_received_ < kReadHeaderSize) {
419 next_state_ = STATE_HANDSHAKE_READ;
420 return OK;
421 }
422
423 const SOCKS4ServerResponse* response =
424 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
425
426 if (response->reserved_null != 0x00) {
427 DVLOG(1) << "Unknown response from SOCKS server.";
428 return ERR_SOCKS_CONNECTION_FAILED;
429 }
430
431 switch (response->code) {
432 case kServerResponseOk:
433 completed_handshake_ = true;
434 return OK;
435 case kServerResponseRejected:
436 DVLOG(1) << "SOCKS request rejected or failed";
437 return ERR_SOCKS_CONNECTION_FAILED;
438 case kServerResponseNotReachable:
439 DVLOG(1) << "SOCKS request failed because client is not running "
440 << "identd (or not reachable from the server)";
441 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
442 case kServerResponseMismatchedUserId:
443 DVLOG(1) << "SOCKS request failed because client's identd could "
444 << "not confirm the user ID string in the request";
445 return ERR_SOCKS_CONNECTION_FAILED;
446 default:
447 DVLOG(1) << "SOCKS server sent unknown response";
448 return ERR_SOCKS_CONNECTION_FAILED;
449 }
450
451 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
452 }
453
GetPeerAddress(IPEndPoint * address) const454 int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
455 return transport_socket_->GetPeerAddress(address);
456 }
457
GetLocalAddress(IPEndPoint * address) const458 int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
459 return transport_socket_->GetLocalAddress(address);
460 }
461
GetResolveErrorInfo() const462 ResolveErrorInfo SOCKSClientSocket::GetResolveErrorInfo() const {
463 return resolve_error_info_;
464 }
465
466 } // namespace net
467