xref: /aosp_15_r20/external/webrtc/p2p/base/fake_dtls_transport.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright 2017 The WebRTC Project Authors. All rights reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef P2P_BASE_FAKE_DTLS_TRANSPORT_H_
12 #define P2P_BASE_FAKE_DTLS_TRANSPORT_H_
13 
14 #include <memory>
15 #include <string>
16 #include <utility>
17 #include <vector>
18 
19 #include "absl/strings/string_view.h"
20 #include "api/crypto/crypto_options.h"
21 #include "api/dtls_transport_interface.h"
22 #include "p2p/base/dtls_transport_internal.h"
23 #include "p2p/base/fake_ice_transport.h"
24 #include "rtc_base/fake_ssl_identity.h"
25 #include "rtc_base/rtc_certificate.h"
26 
27 namespace cricket {
28 
29 // Fake DTLS transport which is implemented by wrapping a fake ICE transport.
30 // Doesn't interact directly with fake ICE transport for anything other than
31 // sending packets.
32 class FakeDtlsTransport : public DtlsTransportInternal {
33  public:
FakeDtlsTransport(FakeIceTransport * ice_transport)34   explicit FakeDtlsTransport(FakeIceTransport* ice_transport)
35       : ice_transport_(ice_transport),
36         transport_name_(ice_transport->transport_name()),
37         component_(ice_transport->component()),
38         dtls_fingerprint_("", nullptr) {
39     RTC_DCHECK(ice_transport_);
40     ice_transport_->SignalReadPacket.connect(
41         this, &FakeDtlsTransport::OnIceTransportReadPacket);
42     ice_transport_->SignalNetworkRouteChanged.connect(
43         this, &FakeDtlsTransport::OnNetworkRouteChanged);
44   }
45 
FakeDtlsTransport(std::unique_ptr<FakeIceTransport> ice)46   explicit FakeDtlsTransport(std::unique_ptr<FakeIceTransport> ice)
47       : owned_ice_transport_(std::move(ice)),
48         transport_name_(owned_ice_transport_->transport_name()),
49         component_(owned_ice_transport_->component()),
50         dtls_fingerprint_("", rtc::ArrayView<const uint8_t>()) {
51     ice_transport_ = owned_ice_transport_.get();
52     ice_transport_->SignalReadPacket.connect(
53         this, &FakeDtlsTransport::OnIceTransportReadPacket);
54     ice_transport_->SignalNetworkRouteChanged.connect(
55         this, &FakeDtlsTransport::OnNetworkRouteChanged);
56   }
57 
58   // If this constructor is called, a new fake ICE transport will be created,
59   // and this FakeDtlsTransport will take the ownership.
FakeDtlsTransport(const std::string & name,int component)60   FakeDtlsTransport(const std::string& name, int component)
61       : FakeDtlsTransport(std::make_unique<FakeIceTransport>(name, component)) {
62   }
FakeDtlsTransport(const std::string & name,int component,rtc::Thread * network_thread)63   FakeDtlsTransport(const std::string& name,
64                     int component,
65                     rtc::Thread* network_thread)
66       : FakeDtlsTransport(std::make_unique<FakeIceTransport>(name,
67                                                              component,
68                                                              network_thread)) {}
69 
~FakeDtlsTransport()70   ~FakeDtlsTransport() override {
71     if (dest_ && dest_->dest_ == this) {
72       dest_->dest_ = nullptr;
73     }
74   }
75 
76   // Get inner fake ICE transport.
fake_ice_transport()77   FakeIceTransport* fake_ice_transport() { return ice_transport_; }
78 
79   // If async, will send packets by "Post"-ing to message queue instead of
80   // synchronously "Send"-ing.
SetAsync(bool async)81   void SetAsync(bool async) { ice_transport_->SetAsync(async); }
SetAsyncDelay(int delay_ms)82   void SetAsyncDelay(int delay_ms) { ice_transport_->SetAsyncDelay(delay_ms); }
83 
84   // SetWritable, SetReceiving and SetDestination are the main methods that can
85   // be used for testing, to simulate connectivity or lack thereof.
SetWritable(bool writable)86   void SetWritable(bool writable) {
87     ice_transport_->SetWritable(writable);
88     set_writable(writable);
89   }
SetReceiving(bool receiving)90   void SetReceiving(bool receiving) {
91     ice_transport_->SetReceiving(receiving);
92     set_receiving(receiving);
93   }
SetDtlsState(webrtc::DtlsTransportState state)94   void SetDtlsState(webrtc::DtlsTransportState state) {
95     dtls_state_ = state;
96     SendDtlsState(this, dtls_state_);
97   }
98 
99   // Simulates the two DTLS transports connecting to each other.
100   // If `asymmetric` is true this method only affects this FakeDtlsTransport.
101   // If false, it affects `dest` as well.
102   void SetDestination(FakeDtlsTransport* dest, bool asymmetric = false) {
103     if (dest == dest_) {
104       return;
105     }
106     RTC_DCHECK(!dest || !dest_)
107         << "Changing fake destination from one to another is not supported.";
108     if (dest && !dest_) {
109       // This simulates the DTLS handshake.
110       dest_ = dest;
111       if (local_cert_ && dest_->local_cert_) {
112         do_dtls_ = true;
113         RTC_LOG(LS_INFO) << "FakeDtlsTransport is doing DTLS";
114       } else {
115         do_dtls_ = false;
116         RTC_LOG(LS_INFO) << "FakeDtlsTransport is not doing DTLS";
117       }
118       SetWritable(true);
119       if (!asymmetric) {
120         dest->SetDestination(this, true);
121       }
122       // If the `dtls_role_` is unset, set it to SSL_CLIENT by default.
123       if (!dtls_role_) {
124         dtls_role_ = std::move(rtc::SSL_CLIENT);
125       }
126       SetDtlsState(webrtc::DtlsTransportState::kConnected);
127       ice_transport_->SetDestination(
128           static_cast<FakeIceTransport*>(dest->ice_transport()), asymmetric);
129     } else {
130       // Simulates loss of connectivity, by asymmetrically forgetting dest_.
131       dest_ = nullptr;
132       SetWritable(false);
133       ice_transport_->SetDestination(nullptr, asymmetric);
134     }
135   }
136 
137   // Fake DtlsTransportInternal implementation.
dtls_state()138   webrtc::DtlsTransportState dtls_state() const override { return dtls_state_; }
transport_name()139   const std::string& transport_name() const override { return transport_name_; }
component()140   int component() const override { return component_; }
dtls_fingerprint()141   const rtc::SSLFingerprint& dtls_fingerprint() const {
142     return dtls_fingerprint_;
143   }
SetRemoteParameters(absl::string_view alg,const uint8_t * digest,size_t digest_len,absl::optional<rtc::SSLRole> role)144   webrtc::RTCError SetRemoteParameters(absl::string_view alg,
145                                        const uint8_t* digest,
146                                        size_t digest_len,
147                                        absl::optional<rtc::SSLRole> role) {
148     if (role) {
149       SetDtlsRole(*role);
150     }
151     SetRemoteFingerprint(alg, digest, digest_len);
152     return webrtc::RTCError::OK();
153   }
SetRemoteFingerprint(absl::string_view alg,const uint8_t * digest,size_t digest_len)154   bool SetRemoteFingerprint(absl::string_view alg,
155                             const uint8_t* digest,
156                             size_t digest_len) {
157     dtls_fingerprint_ =
158         rtc::SSLFingerprint(alg, rtc::MakeArrayView(digest, digest_len));
159     return true;
160   }
SetDtlsRole(rtc::SSLRole role)161   bool SetDtlsRole(rtc::SSLRole role) override {
162     dtls_role_ = std::move(role);
163     return true;
164   }
GetDtlsRole(rtc::SSLRole * role)165   bool GetDtlsRole(rtc::SSLRole* role) const override {
166     if (!dtls_role_) {
167       return false;
168     }
169     *role = *dtls_role_;
170     return true;
171   }
SetLocalCertificate(const rtc::scoped_refptr<rtc::RTCCertificate> & certificate)172   bool SetLocalCertificate(
173       const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override {
174     do_dtls_ = true;
175     local_cert_ = certificate;
176     return true;
177   }
SetRemoteSSLCertificate(rtc::FakeSSLCertificate * cert)178   void SetRemoteSSLCertificate(rtc::FakeSSLCertificate* cert) {
179     remote_cert_ = cert;
180   }
IsDtlsActive()181   bool IsDtlsActive() const override { return do_dtls_; }
GetSslVersionBytes(int * version)182   bool GetSslVersionBytes(int* version) const override {
183     if (!do_dtls_) {
184       return false;
185     }
186     *version = 0x0102;
187     return true;
188   }
GetSrtpCryptoSuite(int * crypto_suite)189   bool GetSrtpCryptoSuite(int* crypto_suite) override {
190     if (!do_dtls_) {
191       return false;
192     }
193     *crypto_suite = crypto_suite_;
194     return true;
195   }
SetSrtpCryptoSuite(int crypto_suite)196   void SetSrtpCryptoSuite(int crypto_suite) { crypto_suite_ = crypto_suite; }
197 
GetSslCipherSuite(int * cipher_suite)198   bool GetSslCipherSuite(int* cipher_suite) override {
199     if (ssl_cipher_suite_) {
200       *cipher_suite = *ssl_cipher_suite_;
201       return true;
202     }
203     return false;
204   }
SetSslCipherSuite(absl::optional<int> cipher_suite)205   void SetSslCipherSuite(absl::optional<int> cipher_suite) {
206     ssl_cipher_suite_ = cipher_suite;
207   }
GetLocalCertificate()208   rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const override {
209     return local_cert_;
210   }
GetRemoteSSLCertChain()211   std::unique_ptr<rtc::SSLCertChain> GetRemoteSSLCertChain() const override {
212     if (!remote_cert_) {
213       return nullptr;
214     }
215     return std::make_unique<rtc::SSLCertChain>(remote_cert_->Clone());
216   }
ExportKeyingMaterial(absl::string_view label,const uint8_t * context,size_t context_len,bool use_context,uint8_t * result,size_t result_len)217   bool ExportKeyingMaterial(absl::string_view label,
218                             const uint8_t* context,
219                             size_t context_len,
220                             bool use_context,
221                             uint8_t* result,
222                             size_t result_len) override {
223     if (!do_dtls_) {
224       return false;
225     }
226     memset(result, 0xff, result_len);
227     return true;
228   }
set_ssl_max_protocol_version(rtc::SSLProtocolVersion version)229   void set_ssl_max_protocol_version(rtc::SSLProtocolVersion version) {
230     ssl_max_version_ = version;
231   }
ssl_max_protocol_version()232   rtc::SSLProtocolVersion ssl_max_protocol_version() const {
233     return ssl_max_version_;
234   }
235 
ice_transport()236   IceTransportInternal* ice_transport() override { return ice_transport_; }
237 
238   // PacketTransportInternal implementation, which passes through to fake ICE
239   // transport for sending actual packets.
writable()240   bool writable() const override { return writable_; }
receiving()241   bool receiving() const override { return receiving_; }
SendPacket(const char * data,size_t len,const rtc::PacketOptions & options,int flags)242   int SendPacket(const char* data,
243                  size_t len,
244                  const rtc::PacketOptions& options,
245                  int flags) override {
246     // We expect only SRTP packets to be sent through this interface.
247     if (flags != PF_SRTP_BYPASS && flags != 0) {
248       return -1;
249     }
250     return ice_transport_->SendPacket(data, len, options, flags);
251   }
SetOption(rtc::Socket::Option opt,int value)252   int SetOption(rtc::Socket::Option opt, int value) override {
253     return ice_transport_->SetOption(opt, value);
254   }
GetOption(rtc::Socket::Option opt,int * value)255   bool GetOption(rtc::Socket::Option opt, int* value) override {
256     return ice_transport_->GetOption(opt, value);
257   }
GetError()258   int GetError() override { return ice_transport_->GetError(); }
259 
network_route()260   absl::optional<rtc::NetworkRoute> network_route() const override {
261     return ice_transport_->network_route();
262   }
263 
264  private:
OnIceTransportReadPacket(PacketTransportInternal * ice_,const char * data,size_t len,const int64_t & packet_time_us,int flags)265   void OnIceTransportReadPacket(PacketTransportInternal* ice_,
266                                 const char* data,
267                                 size_t len,
268                                 const int64_t& packet_time_us,
269                                 int flags) {
270     SignalReadPacket(this, data, len, packet_time_us, flags);
271   }
272 
set_receiving(bool receiving)273   void set_receiving(bool receiving) {
274     if (receiving_ == receiving) {
275       return;
276     }
277     receiving_ = receiving;
278     SignalReceivingState(this);
279   }
280 
set_writable(bool writable)281   void set_writable(bool writable) {
282     if (writable_ == writable) {
283       return;
284     }
285     writable_ = writable;
286     if (writable_) {
287       SignalReadyToSend(this);
288     }
289     SignalWritableState(this);
290   }
291 
OnNetworkRouteChanged(absl::optional<rtc::NetworkRoute> network_route)292   void OnNetworkRouteChanged(absl::optional<rtc::NetworkRoute> network_route) {
293     SignalNetworkRouteChanged(network_route);
294   }
295 
296   FakeIceTransport* ice_transport_;
297   std::unique_ptr<FakeIceTransport> owned_ice_transport_;
298   std::string transport_name_;
299   int component_;
300   FakeDtlsTransport* dest_ = nullptr;
301   rtc::scoped_refptr<rtc::RTCCertificate> local_cert_;
302   rtc::FakeSSLCertificate* remote_cert_ = nullptr;
303   bool do_dtls_ = false;
304   rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_12;
305   rtc::SSLFingerprint dtls_fingerprint_;
306   absl::optional<rtc::SSLRole> dtls_role_;
307   int crypto_suite_ = rtc::kSrtpAes128CmSha1_80;
308   absl::optional<int> ssl_cipher_suite_;
309 
310   webrtc::DtlsTransportState dtls_state_ = webrtc::DtlsTransportState::kNew;
311 
312   bool receiving_ = false;
313   bool writable_ = false;
314 };
315 
316 }  // namespace cricket
317 
318 #endif  // P2P_BASE_FAKE_DTLS_TRANSPORT_H_
319