xref: /aosp_15_r20/external/webrtc/p2p/stunprober/stun_prober.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright 2015 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "p2p/stunprober/stun_prober.h"
12 
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
18 
19 #include "api/packet_socket_factory.h"
20 #include "api/task_queue/pending_task_safety_flag.h"
21 #include "api/transport/stun.h"
22 #include "api/units/time_delta.h"
23 #include "rtc_base/async_packet_socket.h"
24 #include "rtc_base/async_resolver_interface.h"
25 #include "rtc_base/checks.h"
26 #include "rtc_base/helpers.h"
27 #include "rtc_base/logging.h"
28 #include "rtc_base/thread.h"
29 #include "rtc_base/time_utils.h"
30 
31 namespace stunprober {
32 
33 namespace {
34 using ::webrtc::SafeTask;
35 using ::webrtc::TimeDelta;
36 
// Scheduler wake-up period, in milliseconds, used whenever the configured
// request interval is at least this long.
constexpr int THREAD_WAKE_UP_INTERVAL_MS = 5;
38 
// Adds one to the counter stored under `ip`. An address that has not been seen
// before gets an entry that starts at zero and is then incremented to one.
template <typename T>
void IncrementCounterByAddress(std::map<T, int>* counter_per_ip, const T& ip) {
  // operator[] value-initializes a missing entry to 0 before the increment.
  ++(*counter_per_ip)[ip];
}
43 
44 }  // namespace
45 
46 // A requester tracks the requests and responses from a single socket to many
47 // STUN servers
48 class StunProber::Requester : public sigslot::has_slots<> {
49  public:
50   // Each Request maps to a request and response.
51   struct Request {
52     // Actual time the STUN bind request was sent.
53     int64_t sent_time_ms = 0;
54     // Time the response was received.
55     int64_t received_time_ms = 0;
56 
57     // Server reflexive address from STUN response for this given request.
58     rtc::SocketAddress srflx_addr;
59 
60     rtc::IPAddress server_addr;
61 
rttstunprober::StunProber::Requester::Request62     int64_t rtt() { return received_time_ms - sent_time_ms; }
63     void ProcessResponse(const char* buf, size_t buf_len);
64   };
65 
66   // StunProber provides `server_ips` for Requester to probe. For shared
67   // socket mode, it'll be all the resolved IP addresses. For non-shared mode,
68   // it'll just be a single address.
69   Requester(StunProber* prober,
70             rtc::AsyncPacketSocket* socket,
71             const std::vector<rtc::SocketAddress>& server_ips);
72   ~Requester() override;
73 
74   Requester(const Requester&) = delete;
75   Requester& operator=(const Requester&) = delete;
76 
77   // There is no callback for SendStunRequest as the underneath socket send is
78   // expected to be completed immediately. Otherwise, it'll skip this request
79   // and move to the next one.
80   void SendStunRequest();
81 
82   void OnStunResponseReceived(rtc::AsyncPacketSocket* socket,
83                               const char* buf,
84                               size_t size,
85                               const rtc::SocketAddress& addr,
86                               const int64_t& packet_time_us);
87 
requests()88   const std::vector<Request*>& requests() { return requests_; }
89 
90   // Whether this Requester has completed all requests.
Done()91   bool Done() {
92     return static_cast<size_t>(num_request_sent_) == server_ips_.size();
93   }
94 
95  private:
96   Request* GetRequestByAddress(const rtc::IPAddress& ip);
97 
98   StunProber* prober_;
99 
100   // The socket for this session.
101   std::unique_ptr<rtc::AsyncPacketSocket> socket_;
102 
103   // Temporary SocketAddress and buffer for RecvFrom.
104   rtc::SocketAddress addr_;
105   std::unique_ptr<rtc::ByteBufferWriter> response_packet_;
106 
107   std::vector<Request*> requests_;
108   std::vector<rtc::SocketAddress> server_ips_;
109   int16_t num_request_sent_ = 0;
110   int16_t num_response_received_ = 0;
111 
112   webrtc::SequenceChecker& thread_checker_;
113 };
114 
Requester(StunProber * prober,rtc::AsyncPacketSocket * socket,const std::vector<rtc::SocketAddress> & server_ips)115 StunProber::Requester::Requester(
116     StunProber* prober,
117     rtc::AsyncPacketSocket* socket,
118     const std::vector<rtc::SocketAddress>& server_ips)
119     : prober_(prober),
120       socket_(socket),
121       response_packet_(new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize)),
122       server_ips_(server_ips),
123       thread_checker_(prober->thread_checker_) {
124   socket_->SignalReadPacket.connect(
125       this, &StunProber::Requester::OnStunResponseReceived);
126 }
127 
~Requester()128 StunProber::Requester::~Requester() {
129   if (socket_) {
130     socket_->Close();
131   }
132   for (auto* req : requests_) {
133     if (req) {
134       delete req;
135     }
136   }
137 }
138 
SendStunRequest()139 void StunProber::Requester::SendStunRequest() {
140   RTC_DCHECK(thread_checker_.IsCurrent());
141   requests_.push_back(new Request());
142   Request& request = *(requests_.back());
143   // Random transaction ID, STUN_BINDING_REQUEST
144   cricket::StunMessage message(cricket::STUN_BINDING_REQUEST);
145 
146   std::unique_ptr<rtc::ByteBufferWriter> request_packet(
147       new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize));
148   if (!message.Write(request_packet.get())) {
149     prober_->ReportOnFinished(WRITE_FAILED);
150     return;
151   }
152 
153   auto addr = server_ips_[num_request_sent_];
154   request.server_addr = addr.ipaddr();
155 
156   // The write must succeed immediately. Otherwise, the calculating of the STUN
157   // request timing could become too complicated. Callback is ignored by passing
158   // empty AsyncCallback.
159   rtc::PacketOptions options;
160   int rv = socket_->SendTo(const_cast<char*>(request_packet->Data()),
161                            request_packet->Length(), addr, options);
162   if (rv < 0) {
163     prober_->ReportOnFinished(WRITE_FAILED);
164     return;
165   }
166 
167   request.sent_time_ms = rtc::TimeMillis();
168 
169   num_request_sent_++;
170   RTC_DCHECK(static_cast<size_t>(num_request_sent_) <= server_ips_.size());
171 }
172 
ProcessResponse(const char * buf,size_t buf_len)173 void StunProber::Requester::Request::ProcessResponse(const char* buf,
174                                                      size_t buf_len) {
175   int64_t now = rtc::TimeMillis();
176   rtc::ByteBufferReader message(buf, buf_len);
177   cricket::StunMessage stun_response;
178   if (!stun_response.Read(&message)) {
179     // Invalid or incomplete STUN packet.
180     received_time_ms = 0;
181     return;
182   }
183 
184   // Get external address of the socket.
185   const cricket::StunAddressAttribute* addr_attr =
186       stun_response.GetAddress(cricket::STUN_ATTR_MAPPED_ADDRESS);
187   if (addr_attr == nullptr) {
188     // Addresses not available to detect whether or not behind a NAT.
189     return;
190   }
191 
192   if (addr_attr->family() != cricket::STUN_ADDRESS_IPV4 &&
193       addr_attr->family() != cricket::STUN_ADDRESS_IPV6) {
194     return;
195   }
196 
197   received_time_ms = now;
198 
199   srflx_addr = addr_attr->GetAddress();
200 }
201 
OnStunResponseReceived(rtc::AsyncPacketSocket * socket,const char * buf,size_t size,const rtc::SocketAddress & addr,const int64_t &)202 void StunProber::Requester::OnStunResponseReceived(
203     rtc::AsyncPacketSocket* socket,
204     const char* buf,
205     size_t size,
206     const rtc::SocketAddress& addr,
207     const int64_t& /* packet_time_us */) {
208   RTC_DCHECK(thread_checker_.IsCurrent());
209   RTC_DCHECK(socket_);
210   Request* request = GetRequestByAddress(addr.ipaddr());
211   if (!request) {
212     // Something is wrong, finish the test.
213     prober_->ReportOnFinished(GENERIC_FAILURE);
214     return;
215   }
216 
217   num_response_received_++;
218   request->ProcessResponse(buf, size);
219 }
220 
GetRequestByAddress(const rtc::IPAddress & ipaddr)221 StunProber::Requester::Request* StunProber::Requester::GetRequestByAddress(
222     const rtc::IPAddress& ipaddr) {
223   RTC_DCHECK(thread_checker_.IsCurrent());
224   for (auto* request : requests_) {
225     if (request->server_addr == ipaddr) {
226       return request;
227     }
228   }
229 
230   return nullptr;
231 }
232 
233 StunProber::Stats::Stats() = default;
234 
235 StunProber::Stats::~Stats() = default;
236 
237 StunProber::ObserverAdapter::ObserverAdapter() = default;
238 
239 StunProber::ObserverAdapter::~ObserverAdapter() = default;
240 
OnPrepared(StunProber * stunprober,Status status)241 void StunProber::ObserverAdapter::OnPrepared(StunProber* stunprober,
242                                              Status status) {
243   if (status == SUCCESS) {
244     stunprober->Start(this);
245   } else {
246     callback_(stunprober, status);
247   }
248 }
249 
OnFinished(StunProber * stunprober,Status status)250 void StunProber::ObserverAdapter::OnFinished(StunProber* stunprober,
251                                              Status status) {
252   callback_(stunprober, status);
253 }
254 
StunProber(rtc::PacketSocketFactory * socket_factory,rtc::Thread * thread,std::vector<const rtc::Network * > networks)255 StunProber::StunProber(rtc::PacketSocketFactory* socket_factory,
256                        rtc::Thread* thread,
257                        std::vector<const rtc::Network*> networks)
258     : interval_ms_(0),
259       socket_factory_(socket_factory),
260       thread_(thread),
261       networks_(std::move(networks)) {}
262 
~StunProber()263 StunProber::~StunProber() {
264   RTC_DCHECK(thread_checker_.IsCurrent());
265   for (auto* req : requesters_) {
266     if (req) {
267       delete req;
268     }
269   }
270   for (auto* s : sockets_) {
271     if (s) {
272       delete s;
273     }
274   }
275 }
276 
Start(const std::vector<rtc::SocketAddress> & servers,bool shared_socket_mode,int interval_ms,int num_request_per_ip,int timeout_ms,const AsyncCallback callback)277 bool StunProber::Start(const std::vector<rtc::SocketAddress>& servers,
278                        bool shared_socket_mode,
279                        int interval_ms,
280                        int num_request_per_ip,
281                        int timeout_ms,
282                        const AsyncCallback callback) {
283   observer_adapter_.set_callback(callback);
284   return Prepare(servers, shared_socket_mode, interval_ms, num_request_per_ip,
285                  timeout_ms, &observer_adapter_);
286 }
287 
Prepare(const std::vector<rtc::SocketAddress> & servers,bool shared_socket_mode,int interval_ms,int num_request_per_ip,int timeout_ms,StunProber::Observer * observer)288 bool StunProber::Prepare(const std::vector<rtc::SocketAddress>& servers,
289                          bool shared_socket_mode,
290                          int interval_ms,
291                          int num_request_per_ip,
292                          int timeout_ms,
293                          StunProber::Observer* observer) {
294   RTC_DCHECK(thread_checker_.IsCurrent());
295   interval_ms_ = interval_ms;
296   shared_socket_mode_ = shared_socket_mode;
297 
298   requests_per_ip_ = num_request_per_ip;
299   if (requests_per_ip_ == 0 || servers.size() == 0) {
300     return false;
301   }
302 
303   timeout_ms_ = timeout_ms;
304   servers_ = servers;
305   observer_ = observer;
306   // Remove addresses that are already resolved.
307   for (auto it = servers_.begin(); it != servers_.end();) {
308     if (it->ipaddr().family() != AF_UNSPEC) {
309       all_servers_addrs_.push_back(*it);
310       it = servers_.erase(it);
311     } else {
312       ++it;
313     }
314   }
315   if (servers_.empty()) {
316     CreateSockets();
317     return true;
318   }
319   return ResolveServerName(servers_.back());
320 }
321 
Start(StunProber::Observer * observer)322 bool StunProber::Start(StunProber::Observer* observer) {
323   observer_ = observer;
324   if (total_ready_sockets_ != total_socket_required()) {
325     return false;
326   }
327   MaybeScheduleStunRequests();
328   return true;
329 }
330 
ResolveServerName(const rtc::SocketAddress & addr)331 bool StunProber::ResolveServerName(const rtc::SocketAddress& addr) {
332   rtc::AsyncResolverInterface* resolver =
333       socket_factory_->CreateAsyncResolver();
334   if (!resolver) {
335     return false;
336   }
337   resolver->SignalDone.connect(this, &StunProber::OnServerResolved);
338   resolver->Start(addr);
339   return true;
340 }
341 
OnSocketReady(rtc::AsyncPacketSocket * socket,const rtc::SocketAddress & addr)342 void StunProber::OnSocketReady(rtc::AsyncPacketSocket* socket,
343                                const rtc::SocketAddress& addr) {
344   total_ready_sockets_++;
345   if (total_ready_sockets_ == total_socket_required()) {
346     ReportOnPrepared(SUCCESS);
347   }
348 }
349 
OnServerResolved(rtc::AsyncResolverInterface * resolver)350 void StunProber::OnServerResolved(rtc::AsyncResolverInterface* resolver) {
351   RTC_DCHECK(thread_checker_.IsCurrent());
352 
353   if (resolver->GetError() == 0) {
354     rtc::SocketAddress addr(resolver->address().ipaddr(),
355                             resolver->address().port());
356     all_servers_addrs_.push_back(addr);
357   }
358 
359   // Deletion of AsyncResolverInterface can't be done in OnResolveResult which
360   // handles SignalDone.
361   thread_->PostTask([resolver] { resolver->Destroy(false); });
362   servers_.pop_back();
363 
364   if (servers_.size()) {
365     if (!ResolveServerName(servers_.back())) {
366       ReportOnPrepared(RESOLVE_FAILED);
367     }
368     return;
369   }
370 
371   if (all_servers_addrs_.size() == 0) {
372     ReportOnPrepared(RESOLVE_FAILED);
373     return;
374   }
375 
376   CreateSockets();
377 }
378 
CreateSockets()379 void StunProber::CreateSockets() {
380   // Dedupe.
381   std::set<rtc::SocketAddress> addrs(all_servers_addrs_.begin(),
382                                      all_servers_addrs_.end());
383   all_servers_addrs_.assign(addrs.begin(), addrs.end());
384 
385   // Prepare all the sockets beforehand. All of them will bind to "any" address.
386   while (sockets_.size() < total_socket_required()) {
387     std::unique_ptr<rtc::AsyncPacketSocket> socket(
388         socket_factory_->CreateUdpSocket(rtc::SocketAddress(INADDR_ANY, 0), 0,
389                                          0));
390     if (!socket) {
391       ReportOnPrepared(GENERIC_FAILURE);
392       return;
393     }
394     // Chrome and WebRTC behave differently in terms of the state of a socket
395     // once returned from PacketSocketFactory::CreateUdpSocket.
396     if (socket->GetState() == rtc::AsyncPacketSocket::STATE_BINDING) {
397       socket->SignalAddressReady.connect(this, &StunProber::OnSocketReady);
398     } else {
399       OnSocketReady(socket.get(), rtc::SocketAddress(INADDR_ANY, 0));
400     }
401     sockets_.push_back(socket.release());
402   }
403 }
404 
CreateRequester()405 StunProber::Requester* StunProber::CreateRequester() {
406   RTC_DCHECK(thread_checker_.IsCurrent());
407   if (!sockets_.size()) {
408     return nullptr;
409   }
410   StunProber::Requester* requester;
411   if (shared_socket_mode_) {
412     requester = new Requester(this, sockets_.back(), all_servers_addrs_);
413   } else {
414     std::vector<rtc::SocketAddress> server_ip;
415     server_ip.push_back(
416         all_servers_addrs_[(num_request_sent_ % all_servers_addrs_.size())]);
417     requester = new Requester(this, sockets_.back(), server_ip);
418   }
419 
420   sockets_.pop_back();
421   return requester;
422 }
423 
SendNextRequest()424 bool StunProber::SendNextRequest() {
425   if (!current_requester_ || current_requester_->Done()) {
426     current_requester_ = CreateRequester();
427     requesters_.push_back(current_requester_);
428   }
429   if (!current_requester_) {
430     return false;
431   }
432   current_requester_->SendStunRequest();
433   num_request_sent_++;
434   return true;
435 }
436 
should_send_next_request(int64_t now)437 bool StunProber::should_send_next_request(int64_t now) {
438   if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
439     return now >= next_request_time_ms_;
440   } else {
441     return (now + (THREAD_WAKE_UP_INTERVAL_MS / 2)) >= next_request_time_ms_;
442   }
443 }
444 
get_wake_up_interval_ms()445 int StunProber::get_wake_up_interval_ms() {
446   if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) {
447     return 1;
448   } else {
449     return THREAD_WAKE_UP_INTERVAL_MS;
450   }
451 }
452 
MaybeScheduleStunRequests()453 void StunProber::MaybeScheduleStunRequests() {
454   RTC_DCHECK_RUN_ON(thread_);
455   int64_t now = rtc::TimeMillis();
456 
457   if (Done()) {
458     thread_->PostDelayedTask(
459         SafeTask(task_safety_.flag(), [this] { ReportOnFinished(SUCCESS); }),
460         TimeDelta::Millis(timeout_ms_));
461     return;
462   }
463   if (should_send_next_request(now)) {
464     if (!SendNextRequest()) {
465       ReportOnFinished(GENERIC_FAILURE);
466       return;
467     }
468     next_request_time_ms_ = now + interval_ms_;
469   }
470   thread_->PostDelayedTask(
471       SafeTask(task_safety_.flag(), [this] { MaybeScheduleStunRequests(); }),
472       TimeDelta::Millis(get_wake_up_interval_ms()));
473 }
474 
GetStats(StunProber::Stats * prob_stats) const475 bool StunProber::GetStats(StunProber::Stats* prob_stats) const {
476   // No need to be on the same thread.
477   if (!prob_stats) {
478     return false;
479   }
480 
481   StunProber::Stats stats;
482 
483   int rtt_sum = 0;
484   int64_t first_sent_time = 0;
485   int64_t last_sent_time = 0;
486   NatType nat_type = NATTYPE_INVALID;
487 
488   // Track of how many srflx IP that we have seen.
489   std::set<rtc::IPAddress> srflx_ips;
490 
491   // If we're not receiving any response on a given IP, all requests sent to
492   // that IP should be ignored as this could just be an DNS error.
493   std::map<rtc::IPAddress, int> num_response_per_server;
494   std::map<rtc::IPAddress, int> num_request_per_server;
495 
496   for (auto* requester : requesters_) {
497     std::map<rtc::SocketAddress, int> num_response_per_srflx_addr;
498     for (auto* request : requester->requests()) {
499       if (request->sent_time_ms <= 0) {
500         continue;
501       }
502 
503       ++stats.raw_num_request_sent;
504       IncrementCounterByAddress(&num_request_per_server, request->server_addr);
505 
506       if (!first_sent_time) {
507         first_sent_time = request->sent_time_ms;
508       }
509       last_sent_time = request->sent_time_ms;
510 
511       if (request->received_time_ms < request->sent_time_ms) {
512         continue;
513       }
514 
515       IncrementCounterByAddress(&num_response_per_server, request->server_addr);
516       IncrementCounterByAddress(&num_response_per_srflx_addr,
517                                 request->srflx_addr);
518       rtt_sum += request->rtt();
519       stats.srflx_addrs.insert(request->srflx_addr.ToString());
520       srflx_ips.insert(request->srflx_addr.ipaddr());
521     }
522 
523     // If we're using shared mode and seeing >1 srflx addresses for a single
524     // requester, it's symmetric NAT.
525     if (shared_socket_mode_ && num_response_per_srflx_addr.size() > 1) {
526       nat_type = NATTYPE_SYMMETRIC;
527     }
528   }
529 
530   // We're probably not behind a regular NAT. We have more than 1 distinct
531   // server reflexive IPs.
532   if (srflx_ips.size() > 1) {
533     return false;
534   }
535 
536   int num_sent = 0;
537   int num_received = 0;
538   int num_server_ip_with_response = 0;
539 
540   for (const auto& kv : num_response_per_server) {
541     RTC_DCHECK_GT(kv.second, 0);
542     num_server_ip_with_response++;
543     num_received += kv.second;
544     num_sent += num_request_per_server[kv.first];
545   }
546 
547   // Shared mode is only true if we use the shared socket and there are more
548   // than 1 responding servers.
549   stats.shared_socket_mode =
550       shared_socket_mode_ && (num_server_ip_with_response > 1);
551 
552   if (stats.shared_socket_mode && nat_type == NATTYPE_INVALID) {
553     nat_type = NATTYPE_NON_SYMMETRIC;
554   }
555 
556   // If we could find a local IP matching srflx, we're not behind a NAT.
557   rtc::SocketAddress srflx_addr;
558   if (stats.srflx_addrs.size() &&
559       !srflx_addr.FromString(*(stats.srflx_addrs.begin()))) {
560     return false;
561   }
562   for (const auto* net : networks_) {
563     if (srflx_addr.ipaddr() == net->GetBestIP()) {
564       nat_type = stunprober::NATTYPE_NONE;
565       stats.host_ip = net->GetBestIP().ToString();
566       break;
567     }
568   }
569 
570   // Finally, we know we're behind a NAT but can't determine which type it is.
571   if (nat_type == NATTYPE_INVALID) {
572     nat_type = NATTYPE_UNKNOWN;
573   }
574 
575   stats.nat_type = nat_type;
576   stats.num_request_sent = num_sent;
577   stats.num_response_received = num_received;
578   stats.target_request_interval_ns = interval_ms_ * 1000;
579 
580   if (num_sent) {
581     stats.success_percent = static_cast<int>(100 * num_received / num_sent);
582   }
583 
584   if (stats.raw_num_request_sent > 1) {
585     stats.actual_request_interval_ns =
586         (1000 * (last_sent_time - first_sent_time)) /
587         (stats.raw_num_request_sent - 1);
588   }
589 
590   if (num_received) {
591     stats.average_rtt_ms = static_cast<int>((rtt_sum / num_received));
592   }
593 
594   *prob_stats = stats;
595   return true;
596 }
597 
ReportOnPrepared(StunProber::Status status)598 void StunProber::ReportOnPrepared(StunProber::Status status) {
599   if (observer_) {
600     observer_->OnPrepared(this, status);
601   }
602 }
603 
ReportOnFinished(StunProber::Status status)604 void StunProber::ReportOnFinished(StunProber::Status status) {
605   if (observer_) {
606     observer_->OnFinished(this, status);
607   }
608 }
609 
610 }  // namespace stunprober
611