1 /*
2 * Copyright 2015 The WebRTC Project Authors. All rights reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include "p2p/stunprober/stun_prober.h"
12
13 #include <map>
14 #include <memory>
15 #include <set>
16 #include <string>
17 #include <utility>
18
19 #include "api/packet_socket_factory.h"
20 #include "api/task_queue/pending_task_safety_flag.h"
21 #include "api/transport/stun.h"
22 #include "api/units/time_delta.h"
23 #include "rtc_base/async_packet_socket.h"
24 #include "rtc_base/async_resolver_interface.h"
25 #include "rtc_base/checks.h"
26 #include "rtc_base/helpers.h"
27 #include "rtc_base/logging.h"
28 #include "rtc_base/thread.h"
29 #include "rtc_base/time_utils.h"
30
31 namespace stunprober {
32
33 namespace {
34 using ::webrtc::SafeTask;
35 using ::webrtc::TimeDelta;
36
// Polling granularity (ms) used by should_send_next_request() and
// get_wake_up_interval_ms() to pace outgoing STUN requests.
constexpr int THREAD_WAKE_UP_INTERVAL_MS = 5;
38
// Adds one to the counter kept for `ip` in `counter_per_ip`. A key seen for
// the first time starts from zero, so its counter becomes 1.
template <typename T>
void IncrementCounterByAddress(std::map<T, int>* counter_per_ip, const T& ip) {
  int& count = (*counter_per_ip)[ip];  // value-initialized to 0 if absent
  ++count;
}
43
44 } // namespace
45
46 // A requester tracks the requests and responses from a single socket to many
47 // STUN servers
48 class StunProber::Requester : public sigslot::has_slots<> {
49 public:
50 // Each Request maps to a request and response.
51 struct Request {
52 // Actual time the STUN bind request was sent.
53 int64_t sent_time_ms = 0;
54 // Time the response was received.
55 int64_t received_time_ms = 0;
56
57 // Server reflexive address from STUN response for this given request.
58 rtc::SocketAddress srflx_addr;
59
60 rtc::IPAddress server_addr;
61
rttstunprober::StunProber::Requester::Request62 int64_t rtt() { return received_time_ms - sent_time_ms; }
63 void ProcessResponse(const char* buf, size_t buf_len);
64 };
65
66 // StunProber provides `server_ips` for Requester to probe. For shared
67 // socket mode, it'll be all the resolved IP addresses. For non-shared mode,
68 // it'll just be a single address.
69 Requester(StunProber* prober,
70 rtc::AsyncPacketSocket* socket,
71 const std::vector<rtc::SocketAddress>& server_ips);
72 ~Requester() override;
73
74 Requester(const Requester&) = delete;
75 Requester& operator=(const Requester&) = delete;
76
77 // There is no callback for SendStunRequest as the underneath socket send is
78 // expected to be completed immediately. Otherwise, it'll skip this request
79 // and move to the next one.
80 void SendStunRequest();
81
82 void OnStunResponseReceived(rtc::AsyncPacketSocket* socket,
83 const char* buf,
84 size_t size,
85 const rtc::SocketAddress& addr,
86 const int64_t& packet_time_us);
87
requests()88 const std::vector<Request*>& requests() { return requests_; }
89
90 // Whether this Requester has completed all requests.
Done()91 bool Done() {
92 return static_cast<size_t>(num_request_sent_) == server_ips_.size();
93 }
94
95 private:
96 Request* GetRequestByAddress(const rtc::IPAddress& ip);
97
98 StunProber* prober_;
99
100 // The socket for this session.
101 std::unique_ptr<rtc::AsyncPacketSocket> socket_;
102
103 // Temporary SocketAddress and buffer for RecvFrom.
104 rtc::SocketAddress addr_;
105 std::unique_ptr<rtc::ByteBufferWriter> response_packet_;
106
107 std::vector<Request*> requests_;
108 std::vector<rtc::SocketAddress> server_ips_;
109 int16_t num_request_sent_ = 0;
110 int16_t num_response_received_ = 0;
111
112 webrtc::SequenceChecker& thread_checker_;
113 };
114
Requester(StunProber * prober,rtc::AsyncPacketSocket * socket,const std::vector<rtc::SocketAddress> & server_ips)115 StunProber::Requester::Requester(
116 StunProber* prober,
117 rtc::AsyncPacketSocket* socket,
118 const std::vector<rtc::SocketAddress>& server_ips)
119 : prober_(prober),
120 socket_(socket),
121 response_packet_(new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize)),
122 server_ips_(server_ips),
123 thread_checker_(prober->thread_checker_) {
124 socket_->SignalReadPacket.connect(
125 this, &StunProber::Requester::OnStunResponseReceived);
126 }
127
~Requester()128 StunProber::Requester::~Requester() {
129 if (socket_) {
130 socket_->Close();
131 }
132 for (auto* req : requests_) {
133 if (req) {
134 delete req;
135 }
136 }
137 }
138
SendStunRequest()139 void StunProber::Requester::SendStunRequest() {
140 RTC_DCHECK(thread_checker_.IsCurrent());
141 requests_.push_back(new Request());
142 Request& request = *(requests_.back());
143 // Random transaction ID, STUN_BINDING_REQUEST
144 cricket::StunMessage message(cricket::STUN_BINDING_REQUEST);
145
146 std::unique_ptr<rtc::ByteBufferWriter> request_packet(
147 new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize));
148 if (!message.Write(request_packet.get())) {
149 prober_->ReportOnFinished(WRITE_FAILED);
150 return;
151 }
152
153 auto addr = server_ips_[num_request_sent_];
154 request.server_addr = addr.ipaddr();
155
156 // The write must succeed immediately. Otherwise, the calculating of the STUN
157 // request timing could become too complicated. Callback is ignored by passing
158 // empty AsyncCallback.
159 rtc::PacketOptions options;
160 int rv = socket_->SendTo(const_cast<char*>(request_packet->Data()),
161 request_packet->Length(), addr, options);
162 if (rv < 0) {
163 prober_->ReportOnFinished(WRITE_FAILED);
164 return;
165 }
166
167 request.sent_time_ms = rtc::TimeMillis();
168
169 num_request_sent_++;
170 RTC_DCHECK(static_cast<size_t>(num_request_sent_) <= server_ips_.size());
171 }
172
ProcessResponse(const char * buf,size_t buf_len)173 void StunProber::Requester::Request::ProcessResponse(const char* buf,
174 size_t buf_len) {
175 int64_t now = rtc::TimeMillis();
176 rtc::ByteBufferReader message(buf, buf_len);
177 cricket::StunMessage stun_response;
178 if (!stun_response.Read(&message)) {
179 // Invalid or incomplete STUN packet.
180 received_time_ms = 0;
181 return;
182 }
183
184 // Get external address of the socket.
185 const cricket::StunAddressAttribute* addr_attr =
186 stun_response.GetAddress(cricket::STUN_ATTR_MAPPED_ADDRESS);
187 if (addr_attr == nullptr) {
188 // Addresses not available to detect whether or not behind a NAT.
189 return;
190 }
191
192 if (addr_attr->family() != cricket::STUN_ADDRESS_IPV4 &&
193 addr_attr->family() != cricket::STUN_ADDRESS_IPV6) {
194 return;
195 }
196
197 received_time_ms = now;
198
199 srflx_addr = addr_attr->GetAddress();
200 }
201
OnStunResponseReceived(rtc::AsyncPacketSocket * socket,const char * buf,size_t size,const rtc::SocketAddress & addr,const int64_t &)202 void StunProber::Requester::OnStunResponseReceived(
203 rtc::AsyncPacketSocket* socket,
204 const char* buf,
205 size_t size,
206 const rtc::SocketAddress& addr,
207 const int64_t& /* packet_time_us */) {
208 RTC_DCHECK(thread_checker_.IsCurrent());
209 RTC_DCHECK(socket_);
210 Request* request = GetRequestByAddress(addr.ipaddr());
211 if (!request) {
212 // Something is wrong, finish the test.
213 prober_->ReportOnFinished(GENERIC_FAILURE);
214 return;
215 }
216
217 num_response_received_++;
218 request->ProcessResponse(buf, size);
219 }
220
GetRequestByAddress(const rtc::IPAddress & ipaddr)221 StunProber::Requester::Request* StunProber::Requester::GetRequestByAddress(
222 const rtc::IPAddress& ipaddr) {
223 RTC_DCHECK(thread_checker_.IsCurrent());
224 for (auto* request : requests_) {
225 if (request->server_addr == ipaddr) {
226 return request;
227 }
228 }
229
230 return nullptr;
231 }
232
233 StunProber::Stats::Stats() = default;
234
235 StunProber::Stats::~Stats() = default;
236
237 StunProber::ObserverAdapter::ObserverAdapter() = default;
238
239 StunProber::ObserverAdapter::~ObserverAdapter() = default;
240
OnPrepared(StunProber * stunprober,Status status)241 void StunProber::ObserverAdapter::OnPrepared(StunProber* stunprober,
242 Status status) {
243 if (status == SUCCESS) {
244 stunprober->Start(this);
245 } else {
246 callback_(stunprober, status);
247 }
248 }
249
OnFinished(StunProber * stunprober,Status status)250 void StunProber::ObserverAdapter::OnFinished(StunProber* stunprober,
251 Status status) {
252 callback_(stunprober, status);
253 }
254
StunProber(rtc::PacketSocketFactory * socket_factory,rtc::Thread * thread,std::vector<const rtc::Network * > networks)255 StunProber::StunProber(rtc::PacketSocketFactory* socket_factory,
256 rtc::Thread* thread,
257 std::vector<const rtc::Network*> networks)
258 : interval_ms_(0),
259 socket_factory_(socket_factory),
260 thread_(thread),
261 networks_(std::move(networks)) {}
262
~StunProber()263 StunProber::~StunProber() {
264 RTC_DCHECK(thread_checker_.IsCurrent());
265 for (auto* req : requesters_) {
266 if (req) {
267 delete req;
268 }
269 }
270 for (auto* s : sockets_) {
271 if (s) {
272 delete s;
273 }
274 }
275 }
276
Start(const std::vector<rtc::SocketAddress> & servers,bool shared_socket_mode,int interval_ms,int num_request_per_ip,int timeout_ms,const AsyncCallback callback)277 bool StunProber::Start(const std::vector<rtc::SocketAddress>& servers,
278 bool shared_socket_mode,
279 int interval_ms,
280 int num_request_per_ip,
281 int timeout_ms,
282 const AsyncCallback callback) {
283 observer_adapter_.set_callback(callback);
284 return Prepare(servers, shared_socket_mode, interval_ms, num_request_per_ip,
285 timeout_ms, &observer_adapter_);
286 }
287
Prepare(const std::vector<rtc::SocketAddress> & servers,bool shared_socket_mode,int interval_ms,int num_request_per_ip,int timeout_ms,StunProber::Observer * observer)288 bool StunProber::Prepare(const std::vector<rtc::SocketAddress>& servers,
289 bool shared_socket_mode,
290 int interval_ms,
291 int num_request_per_ip,
292 int timeout_ms,
293 StunProber::Observer* observer) {
294 RTC_DCHECK(thread_checker_.IsCurrent());
295 interval_ms_ = interval_ms;
296 shared_socket_mode_ = shared_socket_mode;
297
298 requests_per_ip_ = num_request_per_ip;
299 if (requests_per_ip_ == 0 || servers.size() == 0) {
300 return false;
301 }
302
303 timeout_ms_ = timeout_ms;
304 servers_ = servers;
305 observer_ = observer;
306 // Remove addresses that are already resolved.
307 for (auto it = servers_.begin(); it != servers_.end();) {
308 if (it->ipaddr().family() != AF_UNSPEC) {
309 all_servers_addrs_.push_back(*it);
310 it = servers_.erase(it);
311 } else {
312 ++it;
313 }
314 }
315 if (servers_.empty()) {
316 CreateSockets();
317 return true;
318 }
319 return ResolveServerName(servers_.back());
320 }
321
Start(StunProber::Observer * observer)322 bool StunProber::Start(StunProber::Observer* observer) {
323 observer_ = observer;
324 if (total_ready_sockets_ != total_socket_required()) {
325 return false;
326 }
327 MaybeScheduleStunRequests();
328 return true;
329 }
330
ResolveServerName(const rtc::SocketAddress & addr)331 bool StunProber::ResolveServerName(const rtc::SocketAddress& addr) {
332 rtc::AsyncResolverInterface* resolver =
333 socket_factory_->CreateAsyncResolver();
334 if (!resolver) {
335 return false;
336 }
337 resolver->SignalDone.connect(this, &StunProber::OnServerResolved);
338 resolver->Start(addr);
339 return true;
340 }
341
OnSocketReady(rtc::AsyncPacketSocket * socket,const rtc::SocketAddress & addr)342 void StunProber::OnSocketReady(rtc::AsyncPacketSocket* socket,
343 const rtc::SocketAddress& addr) {
344 total_ready_sockets_++;
345 if (total_ready_sockets_ == total_socket_required()) {
346 ReportOnPrepared(SUCCESS);
347 }
348 }
349
OnServerResolved(rtc::AsyncResolverInterface * resolver)350 void StunProber::OnServerResolved(rtc::AsyncResolverInterface* resolver) {
351 RTC_DCHECK(thread_checker_.IsCurrent());
352
353 if (resolver->GetError() == 0) {
354 rtc::SocketAddress addr(resolver->address().ipaddr(),
355 resolver->address().port());
356 all_servers_addrs_.push_back(addr);
357 }
358
359 // Deletion of AsyncResolverInterface can't be done in OnResolveResult which
360 // handles SignalDone.
361 thread_->PostTask([resolver] { resolver->Destroy(false); });
362 servers_.pop_back();
363
364 if (servers_.size()) {
365 if (!ResolveServerName(servers_.back())) {
366 ReportOnPrepared(RESOLVE_FAILED);
367 }
368 return;
369 }
370
371 if (all_servers_addrs_.size() == 0) {
372 ReportOnPrepared(RESOLVE_FAILED);
373 return;
374 }
375
376 CreateSockets();
377 }
378
CreateSockets()379 void StunProber::CreateSockets() {
380 // Dedupe.
381 std::set<rtc::SocketAddress> addrs(all_servers_addrs_.begin(),
382 all_servers_addrs_.end());
383 all_servers_addrs_.assign(addrs.begin(), addrs.end());
384
385 // Prepare all the sockets beforehand. All of them will bind to "any" address.
386 while (sockets_.size() < total_socket_required()) {
387 std::unique_ptr<rtc::AsyncPacketSocket> socket(
388 socket_factory_->CreateUdpSocket(rtc::SocketAddress(INADDR_ANY, 0), 0,
389 0));
390 if (!socket) {
391 ReportOnPrepared(GENERIC_FAILURE);
392 return;
393 }
394 // Chrome and WebRTC behave differently in terms of the state of a socket
395 // once returned from PacketSocketFactory::CreateUdpSocket.
396 if (socket->GetState() == rtc::AsyncPacketSocket::STATE_BINDING) {
397 socket->SignalAddressReady.connect(this, &StunProber::OnSocketReady);
398 } else {
399 OnSocketReady(socket.get(), rtc::SocketAddress(INADDR_ANY, 0));
400 }
401 sockets_.push_back(socket.release());
402 }
403 }
404
CreateRequester()405 StunProber::Requester* StunProber::CreateRequester() {
406 RTC_DCHECK(thread_checker_.IsCurrent());
407 if (!sockets_.size()) {
408 return nullptr;
409 }
410 StunProber::Requester* requester;
411 if (shared_socket_mode_) {
412 requester = new Requester(this, sockets_.back(), all_servers_addrs_);
413 } else {
414 std::vector<rtc::SocketAddress> server_ip;
415 server_ip.push_back(
416 all_servers_addrs_[(num_request_sent_ % all_servers_addrs_.size())]);
417 requester = new Requester(this, sockets_.back(), server_ip);
418 }
419
420 sockets_.pop_back();
421 return requester;
422 }
423
SendNextRequest()424 bool StunProber::SendNextRequest() {
425 if (!current_requester_ || current_requester_->Done()) {
426 current_requester_ = CreateRequester();
427 requesters_.push_back(current_requester_);
428 }
429 if (!current_requester_) {
430 return false;
431 }
432 current_requester_->SendStunRequest();
433 num_request_sent_++;
434 return true;
435 }
436
should_send_next_request(int64_t now)437 bool StunProber::should_send_next_request(int64_t now) {
438 if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
439 return now >= next_request_time_ms_;
440 } else {
441 return (now + (THREAD_WAKE_UP_INTERVAL_MS / 2)) >= next_request_time_ms_;
442 }
443 }
444
get_wake_up_interval_ms()445 int StunProber::get_wake_up_interval_ms() {
446 if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
447 return 1;
448 } else {
449 return THREAD_WAKE_UP_INTERVAL_MS;
450 }
451 }
452
MaybeScheduleStunRequests()453 void StunProber::MaybeScheduleStunRequests() {
454 RTC_DCHECK_RUN_ON(thread_);
455 int64_t now = rtc::TimeMillis();
456
457 if (Done()) {
458 thread_->PostDelayedTask(
459 SafeTask(task_safety_.flag(), [this] { ReportOnFinished(SUCCESS); }),
460 TimeDelta::Millis(timeout_ms_));
461 return;
462 }
463 if (should_send_next_request(now)) {
464 if (!SendNextRequest()) {
465 ReportOnFinished(GENERIC_FAILURE);
466 return;
467 }
468 next_request_time_ms_ = now + interval_ms_;
469 }
470 thread_->PostDelayedTask(
471 SafeTask(task_safety_.flag(), [this] { MaybeScheduleStunRequests(); }),
472 TimeDelta::Millis(get_wake_up_interval_ms()));
473 }
474
GetStats(StunProber::Stats * prob_stats) const475 bool StunProber::GetStats(StunProber::Stats* prob_stats) const {
476 // No need to be on the same thread.
477 if (!prob_stats) {
478 return false;
479 }
480
481 StunProber::Stats stats;
482
483 int rtt_sum = 0;
484 int64_t first_sent_time = 0;
485 int64_t last_sent_time = 0;
486 NatType nat_type = NATTYPE_INVALID;
487
488 // Track of how many srflx IP that we have seen.
489 std::set<rtc::IPAddress> srflx_ips;
490
491 // If we're not receiving any response on a given IP, all requests sent to
492 // that IP should be ignored as this could just be an DNS error.
493 std::map<rtc::IPAddress, int> num_response_per_server;
494 std::map<rtc::IPAddress, int> num_request_per_server;
495
496 for (auto* requester : requesters_) {
497 std::map<rtc::SocketAddress, int> num_response_per_srflx_addr;
498 for (auto* request : requester->requests()) {
499 if (request->sent_time_ms <= 0) {
500 continue;
501 }
502
503 ++stats.raw_num_request_sent;
504 IncrementCounterByAddress(&num_request_per_server, request->server_addr);
505
506 if (!first_sent_time) {
507 first_sent_time = request->sent_time_ms;
508 }
509 last_sent_time = request->sent_time_ms;
510
511 if (request->received_time_ms < request->sent_time_ms) {
512 continue;
513 }
514
515 IncrementCounterByAddress(&num_response_per_server, request->server_addr);
516 IncrementCounterByAddress(&num_response_per_srflx_addr,
517 request->srflx_addr);
518 rtt_sum += request->rtt();
519 stats.srflx_addrs.insert(request->srflx_addr.ToString());
520 srflx_ips.insert(request->srflx_addr.ipaddr());
521 }
522
523 // If we're using shared mode and seeing >1 srflx addresses for a single
524 // requester, it's symmetric NAT.
525 if (shared_socket_mode_ && num_response_per_srflx_addr.size() > 1) {
526 nat_type = NATTYPE_SYMMETRIC;
527 }
528 }
529
530 // We're probably not behind a regular NAT. We have more than 1 distinct
531 // server reflexive IPs.
532 if (srflx_ips.size() > 1) {
533 return false;
534 }
535
536 int num_sent = 0;
537 int num_received = 0;
538 int num_server_ip_with_response = 0;
539
540 for (const auto& kv : num_response_per_server) {
541 RTC_DCHECK_GT(kv.second, 0);
542 num_server_ip_with_response++;
543 num_received += kv.second;
544 num_sent += num_request_per_server[kv.first];
545 }
546
547 // Shared mode is only true if we use the shared socket and there are more
548 // than 1 responding servers.
549 stats.shared_socket_mode =
550 shared_socket_mode_ && (num_server_ip_with_response > 1);
551
552 if (stats.shared_socket_mode && nat_type == NATTYPE_INVALID) {
553 nat_type = NATTYPE_NON_SYMMETRIC;
554 }
555
556 // If we could find a local IP matching srflx, we're not behind a NAT.
557 rtc::SocketAddress srflx_addr;
558 if (stats.srflx_addrs.size() &&
559 !srflx_addr.FromString(*(stats.srflx_addrs.begin()))) {
560 return false;
561 }
562 for (const auto* net : networks_) {
563 if (srflx_addr.ipaddr() == net->GetBestIP()) {
564 nat_type = stunprober::NATTYPE_NONE;
565 stats.host_ip = net->GetBestIP().ToString();
566 break;
567 }
568 }
569
570 // Finally, we know we're behind a NAT but can't determine which type it is.
571 if (nat_type == NATTYPE_INVALID) {
572 nat_type = NATTYPE_UNKNOWN;
573 }
574
575 stats.nat_type = nat_type;
576 stats.num_request_sent = num_sent;
577 stats.num_response_received = num_received;
578 stats.target_request_interval_ns = interval_ms_ * 1000;
579
580 if (num_sent) {
581 stats.success_percent = static_cast<int>(100 * num_received / num_sent);
582 }
583
584 if (stats.raw_num_request_sent > 1) {
585 stats.actual_request_interval_ns =
586 (1000 * (last_sent_time - first_sent_time)) /
587 (stats.raw_num_request_sent - 1);
588 }
589
590 if (num_received) {
591 stats.average_rtt_ms = static_cast<int>((rtt_sum / num_received));
592 }
593
594 *prob_stats = stats;
595 return true;
596 }
597
ReportOnPrepared(StunProber::Status status)598 void StunProber::ReportOnPrepared(StunProber::Status status) {
599 if (observer_) {
600 observer_->OnPrepared(this, status);
601 }
602 }
603
ReportOnFinished(StunProber::Status status)604 void StunProber::ReportOnFinished(StunProber::Status status) {
605 if (observer_) {
606 observer_->OnFinished(this, status);
607 }
608 }
609
610 } // namespace stunprober
611