xref: /aosp_15_r20/external/webrtc/pc/test/mock_peer_connection_observers.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright 2012 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 // This file contains mock implementations of observers used in PeerConnection.
12 // TODO(steveanton): These aren't really mocks and should be renamed.
13 
14 #ifndef PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
15 #define PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
16 
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "api/data_channel_interface.h"
24 #include "api/jsep_ice_candidate.h"
25 #include "pc/stream_collection.h"
26 #include "rtc_base/checks.h"
27 
28 namespace webrtc {
29 
30 class MockPeerConnectionObserver : public PeerConnectionObserver {
31  public:
32   struct AddTrackEvent {
AddTrackEventAddTrackEvent33     explicit AddTrackEvent(
34         rtc::scoped_refptr<RtpReceiverInterface> event_receiver,
35         std::vector<rtc::scoped_refptr<MediaStreamInterface>> event_streams)
36         : receiver(std::move(event_receiver)),
37           streams(std::move(event_streams)) {
38       for (auto stream : streams) {
39         std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>> tracks;
40         for (auto audio_track : stream->GetAudioTracks()) {
41           tracks.push_back(audio_track);
42         }
43         for (auto video_track : stream->GetVideoTracks()) {
44           tracks.push_back(video_track);
45         }
46         snapshotted_stream_tracks[stream] = tracks;
47       }
48     }
49 
50     rtc::scoped_refptr<RtpReceiverInterface> receiver;
51     std::vector<rtc::scoped_refptr<MediaStreamInterface>> streams;
52     // This map records the tracks present in each stream at the time the
53     // OnAddTrack callback was issued.
54     std::map<rtc::scoped_refptr<MediaStreamInterface>,
55              std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>>>
56         snapshotted_stream_tracks;
57   };
58 
MockPeerConnectionObserver()59   MockPeerConnectionObserver() : remote_streams_(StreamCollection::Create()) {}
~MockPeerConnectionObserver()60   virtual ~MockPeerConnectionObserver() {}
SetPeerConnectionInterface(PeerConnectionInterface * pc)61   void SetPeerConnectionInterface(PeerConnectionInterface* pc) {
62     pc_ = pc;
63     if (pc) {
64       state_ = pc_->signaling_state();
65     }
66   }
OnSignalingChange(PeerConnectionInterface::SignalingState new_state)67   void OnSignalingChange(
68       PeerConnectionInterface::SignalingState new_state) override {
69     RTC_DCHECK(pc_);
70     RTC_DCHECK(pc_->signaling_state() == new_state);
71     state_ = new_state;
72   }
73 
RemoteStream(const std::string & label)74   MediaStreamInterface* RemoteStream(const std::string& label) {
75     return remote_streams_->find(label);
76   }
remote_streams()77   StreamCollectionInterface* remote_streams() const {
78     return remote_streams_.get();
79   }
OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream)80   void OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream) override {
81     last_added_stream_ = stream;
82     remote_streams_->AddStream(stream);
83   }
OnRemoveStream(rtc::scoped_refptr<MediaStreamInterface> stream)84   void OnRemoveStream(
85       rtc::scoped_refptr<MediaStreamInterface> stream) override {
86     last_removed_stream_ = stream;
87     remote_streams_->RemoveStream(stream.get());
88   }
OnRenegotiationNeeded()89   void OnRenegotiationNeeded() override { renegotiation_needed_ = true; }
OnNegotiationNeededEvent(uint32_t event_id)90   void OnNegotiationNeededEvent(uint32_t event_id) override {
91     latest_negotiation_needed_event_ = event_id;
92   }
OnDataChannel(rtc::scoped_refptr<DataChannelInterface> data_channel)93   void OnDataChannel(
94       rtc::scoped_refptr<DataChannelInterface> data_channel) override {
95     last_datachannel_ = data_channel;
96   }
97 
OnIceConnectionChange(PeerConnectionInterface::IceConnectionState new_state)98   void OnIceConnectionChange(
99       PeerConnectionInterface::IceConnectionState new_state) override {
100     RTC_DCHECK(pc_);
101     RTC_DCHECK(pc_->ice_connection_state() == new_state);
102     // When ICE is finished, the caller will get to a kIceConnectionCompleted
103     // state, because it has the ICE controlling role, while the callee
104     // will get to a kIceConnectionConnected state. This means that both ICE
105     // and DTLS are connected.
106     ice_connected_ =
107         (new_state == PeerConnectionInterface::kIceConnectionConnected) ||
108         (new_state == PeerConnectionInterface::kIceConnectionCompleted);
109     callback_triggered_ = true;
110   }
OnIceGatheringChange(PeerConnectionInterface::IceGatheringState new_state)111   void OnIceGatheringChange(
112       PeerConnectionInterface::IceGatheringState new_state) override {
113     RTC_DCHECK(pc_);
114     RTC_DCHECK(pc_->ice_gathering_state() == new_state);
115     ice_gathering_complete_ =
116         new_state == PeerConnectionInterface::kIceGatheringComplete;
117     callback_triggered_ = true;
118   }
OnIceCandidate(const IceCandidateInterface * candidate)119   void OnIceCandidate(const IceCandidateInterface* candidate) override {
120     RTC_DCHECK(pc_);
121     candidates_.push_back(std::make_unique<JsepIceCandidate>(
122         candidate->sdp_mid(), candidate->sdp_mline_index(),
123         candidate->candidate()));
124     callback_triggered_ = true;
125   }
126 
OnIceCandidatesRemoved(const std::vector<cricket::Candidate> & candidates)127   void OnIceCandidatesRemoved(
128       const std::vector<cricket::Candidate>& candidates) override {
129     num_candidates_removed_++;
130     callback_triggered_ = true;
131   }
132 
OnIceConnectionReceivingChange(bool receiving)133   void OnIceConnectionReceivingChange(bool receiving) override {
134     callback_triggered_ = true;
135   }
136 
OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,const std::vector<rtc::scoped_refptr<MediaStreamInterface>> & streams)137   void OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,
138                   const std::vector<rtc::scoped_refptr<MediaStreamInterface>>&
139                       streams) override {
140     RTC_DCHECK(receiver);
141     num_added_tracks_++;
142     last_added_track_label_ = receiver->id();
143     add_track_events_.push_back(AddTrackEvent(receiver, streams));
144   }
145 
OnTrack(rtc::scoped_refptr<RtpTransceiverInterface> transceiver)146   void OnTrack(
147       rtc::scoped_refptr<RtpTransceiverInterface> transceiver) override {
148     on_track_transceivers_.push_back(transceiver);
149   }
150 
OnRemoveTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver)151   void OnRemoveTrack(
152       rtc::scoped_refptr<RtpReceiverInterface> receiver) override {
153     remove_track_events_.push_back(receiver);
154   }
155 
GetAddTrackReceivers()156   std::vector<rtc::scoped_refptr<RtpReceiverInterface>> GetAddTrackReceivers() {
157     std::vector<rtc::scoped_refptr<RtpReceiverInterface>> receivers;
158     for (const AddTrackEvent& event : add_track_events_) {
159       receivers.push_back(event.receiver);
160     }
161     return receivers;
162   }
163 
CountAddTrackEventsForStream(const std::string & stream_id)164   int CountAddTrackEventsForStream(const std::string& stream_id) {
165     int found_tracks = 0;
166     for (const AddTrackEvent& event : add_track_events_) {
167       bool has_stream_id = false;
168       for (auto stream : event.streams) {
169         if (stream->id() == stream_id) {
170           has_stream_id = true;
171           break;
172         }
173       }
174       if (has_stream_id) {
175         ++found_tracks;
176       }
177     }
178     return found_tracks;
179   }
180 
181   // Returns the id of the last added stream.
182   // Empty string if no stream have been added.
GetLastAddedStreamId()183   std::string GetLastAddedStreamId() {
184     if (last_added_stream_.get())
185       return last_added_stream_->id();
186     return "";
187   }
GetLastRemovedStreamId()188   std::string GetLastRemovedStreamId() {
189     if (last_removed_stream_.get())
190       return last_removed_stream_->id();
191     return "";
192   }
193 
last_candidate()194   IceCandidateInterface* last_candidate() {
195     if (candidates_.empty()) {
196       return nullptr;
197     } else {
198       return candidates_.back().get();
199     }
200   }
201 
GetAllCandidates()202   std::vector<const IceCandidateInterface*> GetAllCandidates() {
203     std::vector<const IceCandidateInterface*> candidates;
204     for (const auto& candidate : candidates_) {
205       candidates.push_back(candidate.get());
206     }
207     return candidates;
208   }
209 
GetCandidatesByMline(int mline_index)210   std::vector<IceCandidateInterface*> GetCandidatesByMline(int mline_index) {
211     std::vector<IceCandidateInterface*> candidates;
212     for (const auto& candidate : candidates_) {
213       if (candidate->sdp_mline_index() == mline_index) {
214         candidates.push_back(candidate.get());
215       }
216     }
217     return candidates;
218   }
219 
legacy_renegotiation_needed()220   bool legacy_renegotiation_needed() const { return renegotiation_needed_; }
clear_legacy_renegotiation_needed()221   void clear_legacy_renegotiation_needed() { renegotiation_needed_ = false; }
222 
has_negotiation_needed_event()223   bool has_negotiation_needed_event() {
224     return latest_negotiation_needed_event_.has_value();
225   }
latest_negotiation_needed_event()226   uint32_t latest_negotiation_needed_event() {
227     return latest_negotiation_needed_event_.value_or(0u);
228   }
clear_latest_negotiation_needed_event()229   void clear_latest_negotiation_needed_event() {
230     latest_negotiation_needed_event_ = absl::nullopt;
231   }
232 
233   rtc::scoped_refptr<PeerConnectionInterface> pc_;
234   PeerConnectionInterface::SignalingState state_;
235   std::vector<std::unique_ptr<IceCandidateInterface>> candidates_;
236   rtc::scoped_refptr<DataChannelInterface> last_datachannel_;
237   rtc::scoped_refptr<StreamCollection> remote_streams_;
238   bool renegotiation_needed_ = false;
239   absl::optional<uint32_t> latest_negotiation_needed_event_;
240   bool ice_gathering_complete_ = false;
241   bool ice_connected_ = false;
242   bool callback_triggered_ = false;
243   int num_added_tracks_ = 0;
244   std::string last_added_track_label_;
245   std::vector<AddTrackEvent> add_track_events_;
246   std::vector<rtc::scoped_refptr<RtpReceiverInterface>> remove_track_events_;
247   std::vector<rtc::scoped_refptr<RtpTransceiverInterface>>
248       on_track_transceivers_;
249   int num_candidates_removed_ = 0;
250 
251  private:
252   rtc::scoped_refptr<MediaStreamInterface> last_added_stream_;
253   rtc::scoped_refptr<MediaStreamInterface> last_removed_stream_;
254 };
255 
256 class MockCreateSessionDescriptionObserver
257     : public webrtc::CreateSessionDescriptionObserver {
258  public:
MockCreateSessionDescriptionObserver()259   MockCreateSessionDescriptionObserver()
260       : called_(false),
261         error_("MockCreateSessionDescriptionObserver not called") {}
~MockCreateSessionDescriptionObserver()262   virtual ~MockCreateSessionDescriptionObserver() {}
OnSuccess(SessionDescriptionInterface * desc)263   void OnSuccess(SessionDescriptionInterface* desc) override {
264     MutexLock lock(&mutex_);
265     called_ = true;
266     error_ = "";
267     desc_.reset(desc);
268   }
OnFailure(webrtc::RTCError error)269   void OnFailure(webrtc::RTCError error) override {
270     MutexLock lock(&mutex_);
271     called_ = true;
272     error_ = error.message();
273   }
called()274   bool called() const {
275     MutexLock lock(&mutex_);
276     return called_;
277   }
result()278   bool result() const {
279     MutexLock lock(&mutex_);
280     return error_.empty();
281   }
error()282   const std::string& error() const {
283     MutexLock lock(&mutex_);
284     return error_;
285   }
MoveDescription()286   std::unique_ptr<SessionDescriptionInterface> MoveDescription() {
287     MutexLock lock(&mutex_);
288     return std::move(desc_);
289   }
290 
291  private:
292   mutable Mutex mutex_;
293   bool called_ RTC_GUARDED_BY(mutex_);
294   std::string error_ RTC_GUARDED_BY(mutex_);
295   std::unique_ptr<SessionDescriptionInterface> desc_ RTC_GUARDED_BY(mutex_);
296 };
297 
298 class MockSetSessionDescriptionObserver
299     : public webrtc::SetSessionDescriptionObserver {
300  public:
Create()301   static rtc::scoped_refptr<MockSetSessionDescriptionObserver> Create() {
302     return rtc::make_ref_counted<MockSetSessionDescriptionObserver>();
303   }
304 
MockSetSessionDescriptionObserver()305   MockSetSessionDescriptionObserver()
306       : called_(false),
307         error_("MockSetSessionDescriptionObserver not called") {}
~MockSetSessionDescriptionObserver()308   ~MockSetSessionDescriptionObserver() override {}
OnSuccess()309   void OnSuccess() override {
310     MutexLock lock(&mutex_);
311 
312     called_ = true;
313     error_ = "";
314   }
OnFailure(webrtc::RTCError error)315   void OnFailure(webrtc::RTCError error) override {
316     MutexLock lock(&mutex_);
317     called_ = true;
318     error_ = error.message();
319   }
320 
called()321   bool called() const {
322     MutexLock lock(&mutex_);
323     return called_;
324   }
result()325   bool result() const {
326     MutexLock lock(&mutex_);
327     return error_.empty();
328   }
error()329   const std::string& error() const {
330     MutexLock lock(&mutex_);
331     return error_;
332   }
333 
334  private:
335   mutable Mutex mutex_;
336   bool called_;
337   std::string error_;
338 };
339 
340 class FakeSetLocalDescriptionObserver
341     : public SetLocalDescriptionObserverInterface {
342  public:
called()343   bool called() const { return error_.has_value(); }
error()344   RTCError& error() {
345     RTC_DCHECK(error_.has_value());
346     return *error_;
347   }
348 
349   // SetLocalDescriptionObserverInterface implementation.
OnSetLocalDescriptionComplete(RTCError error)350   void OnSetLocalDescriptionComplete(RTCError error) override {
351     error_ = std::move(error);
352   }
353 
354  private:
355   // Set on complete, on success this is set to an RTCError::OK() error.
356   absl::optional<RTCError> error_;
357 };
358 
359 class FakeSetRemoteDescriptionObserver
360     : public SetRemoteDescriptionObserverInterface {
361  public:
called()362   bool called() const { return error_.has_value(); }
error()363   RTCError& error() {
364     RTC_DCHECK(error_.has_value());
365     return *error_;
366   }
367 
368   // SetRemoteDescriptionObserverInterface implementation.
OnSetRemoteDescriptionComplete(RTCError error)369   void OnSetRemoteDescriptionComplete(RTCError error) override {
370     error_ = std::move(error);
371   }
372 
373  private:
374   // Set on complete, on success this is set to an RTCError::OK() error.
375   absl::optional<RTCError> error_;
376 };
377 
378 class MockDataChannelObserver : public webrtc::DataChannelObserver {
379  public:
380   struct Message {
381     std::string data;
382     bool binary;
383   };
384 
MockDataChannelObserver(webrtc::DataChannelInterface * channel)385   explicit MockDataChannelObserver(webrtc::DataChannelInterface* channel)
386       : channel_(channel) {
387     channel_->RegisterObserver(this);
388     states_.push_back(channel_->state());
389   }
~MockDataChannelObserver()390   virtual ~MockDataChannelObserver() { channel_->UnregisterObserver(); }
391 
OnBufferedAmountChange(uint64_t previous_amount)392   void OnBufferedAmountChange(uint64_t previous_amount) override {}
393 
OnStateChange()394   void OnStateChange() override { states_.push_back(channel_->state()); }
OnMessage(const DataBuffer & buffer)395   void OnMessage(const DataBuffer& buffer) override {
396     messages_.push_back(
397         {std::string(buffer.data.data<char>(), buffer.data.size()),
398          buffer.binary});
399   }
400 
IsOpen()401   bool IsOpen() const { return state() == DataChannelInterface::kOpen; }
messages()402   std::vector<Message> messages() const { return messages_; }
last_message()403   std::string last_message() const {
404     if (messages_.empty())
405       return {};
406 
407     return messages_.back().data;
408   }
last_message_is_binary()409   bool last_message_is_binary() const {
410     if (messages_.empty())
411       return false;
412     return messages_.back().binary;
413   }
received_message_count()414   size_t received_message_count() const { return messages_.size(); }
415 
state()416   DataChannelInterface::DataState state() const { return states_.back(); }
states()417   const std::vector<DataChannelInterface::DataState>& states() const {
418     return states_;
419   }
420 
421  private:
422   rtc::scoped_refptr<webrtc::DataChannelInterface> channel_;
423   std::vector<DataChannelInterface::DataState> states_;
424   std::vector<Message> messages_;
425 };
426 
427 class MockStatsObserver : public webrtc::StatsObserver {
428  public:
MockStatsObserver()429   MockStatsObserver() : called_(false), stats_() {}
~MockStatsObserver()430   virtual ~MockStatsObserver() {}
431 
OnComplete(const StatsReports & reports)432   virtual void OnComplete(const StatsReports& reports) {
433     RTC_CHECK(!called_);
434     called_ = true;
435     stats_.Clear();
436     stats_.number_of_reports = reports.size();
437     for (const auto* r : reports) {
438       if (r->type() == StatsReport::kStatsReportTypeSsrc) {
439         stats_.timestamp = r->timestamp();
440         GetIntValue(r, StatsReport::kStatsValueNameAudioOutputLevel,
441                     &stats_.audio_output_level);
442         GetIntValue(r, StatsReport::kStatsValueNameAudioInputLevel,
443                     &stats_.audio_input_level);
444         GetIntValue(r, StatsReport::kStatsValueNameBytesReceived,
445                     &stats_.bytes_received);
446         GetIntValue(r, StatsReport::kStatsValueNameBytesSent,
447                     &stats_.bytes_sent);
448         GetInt64Value(r, StatsReport::kStatsValueNameCaptureStartNtpTimeMs,
449                       &stats_.capture_start_ntp_time);
450         stats_.track_ids.emplace_back();
451         GetStringValue(r, StatsReport::kStatsValueNameTrackId,
452                        &stats_.track_ids.back());
453       } else if (r->type() == StatsReport::kStatsReportTypeBwe) {
454         stats_.timestamp = r->timestamp();
455         GetIntValue(r, StatsReport::kStatsValueNameAvailableReceiveBandwidth,
456                     &stats_.available_receive_bandwidth);
457       } else if (r->type() == StatsReport::kStatsReportTypeComponent) {
458         stats_.timestamp = r->timestamp();
459         GetStringValue(r, StatsReport::kStatsValueNameDtlsCipher,
460                        &stats_.dtls_cipher);
461         GetStringValue(r, StatsReport::kStatsValueNameSrtpCipher,
462                        &stats_.srtp_cipher);
463       }
464     }
465   }
466 
called()467   bool called() const { return called_; }
number_of_reports()468   size_t number_of_reports() const { return stats_.number_of_reports; }
timestamp()469   double timestamp() const { return stats_.timestamp; }
470 
AudioOutputLevel()471   int AudioOutputLevel() const {
472     RTC_CHECK(called_);
473     return stats_.audio_output_level;
474   }
475 
AudioInputLevel()476   int AudioInputLevel() const {
477     RTC_CHECK(called_);
478     return stats_.audio_input_level;
479   }
480 
BytesReceived()481   int BytesReceived() const {
482     RTC_CHECK(called_);
483     return stats_.bytes_received;
484   }
485 
BytesSent()486   int BytesSent() const {
487     RTC_CHECK(called_);
488     return stats_.bytes_sent;
489   }
490 
CaptureStartNtpTime()491   int64_t CaptureStartNtpTime() const {
492     RTC_CHECK(called_);
493     return stats_.capture_start_ntp_time;
494   }
495 
AvailableReceiveBandwidth()496   int AvailableReceiveBandwidth() const {
497     RTC_CHECK(called_);
498     return stats_.available_receive_bandwidth;
499   }
500 
DtlsCipher()501   std::string DtlsCipher() const {
502     RTC_CHECK(called_);
503     return stats_.dtls_cipher;
504   }
505 
SrtpCipher()506   std::string SrtpCipher() const {
507     RTC_CHECK(called_);
508     return stats_.srtp_cipher;
509   }
510 
TrackIds()511   std::vector<std::string> TrackIds() const {
512     RTC_CHECK(called_);
513     return stats_.track_ids;
514   }
515 
516  private:
GetIntValue(const StatsReport * report,StatsReport::StatsValueName name,int * value)517   bool GetIntValue(const StatsReport* report,
518                    StatsReport::StatsValueName name,
519                    int* value) {
520     const StatsReport::Value* v = report->FindValue(name);
521     if (v) {
522       // TODO(tommi): We should really just be using an int here :-/
523       *value = rtc::FromString<int>(v->ToString());
524     }
525     return v != nullptr;
526   }
527 
GetInt64Value(const StatsReport * report,StatsReport::StatsValueName name,int64_t * value)528   bool GetInt64Value(const StatsReport* report,
529                      StatsReport::StatsValueName name,
530                      int64_t* value) {
531     const StatsReport::Value* v = report->FindValue(name);
532     if (v) {
533       // TODO(tommi): We should really just be using an int here :-/
534       *value = rtc::FromString<int64_t>(v->ToString());
535     }
536     return v != nullptr;
537   }
538 
GetStringValue(const StatsReport * report,StatsReport::StatsValueName name,std::string * value)539   bool GetStringValue(const StatsReport* report,
540                       StatsReport::StatsValueName name,
541                       std::string* value) {
542     const StatsReport::Value* v = report->FindValue(name);
543     if (v)
544       *value = v->ToString();
545     return v != nullptr;
546   }
547 
548   bool called_;
549   struct {
Clear__anonec6a7c820108550     void Clear() {
551       number_of_reports = 0;
552       timestamp = 0;
553       audio_output_level = 0;
554       audio_input_level = 0;
555       bytes_received = 0;
556       bytes_sent = 0;
557       capture_start_ntp_time = 0;
558       available_receive_bandwidth = 0;
559       dtls_cipher.clear();
560       srtp_cipher.clear();
561       track_ids.clear();
562     }
563 
564     size_t number_of_reports;
565     double timestamp;
566     int audio_output_level;
567     int audio_input_level;
568     int bytes_received;
569     int bytes_sent;
570     int64_t capture_start_ntp_time;
571     int available_receive_bandwidth;
572     std::string dtls_cipher;
573     std::string srtp_cipher;
574     std::vector<std::string> track_ids;
575   } stats_;
576 };
577 
578 // Helper class that just stores the report from the callback.
579 class MockRTCStatsCollectorCallback : public webrtc::RTCStatsCollectorCallback {
580  public:
report()581   rtc::scoped_refptr<const RTCStatsReport> report() { return report_; }
582 
called()583   bool called() const { return called_; }
584 
585  protected:
OnStatsDelivered(const rtc::scoped_refptr<const RTCStatsReport> & report)586   void OnStatsDelivered(
587       const rtc::scoped_refptr<const RTCStatsReport>& report) override {
588     report_ = report;
589     called_ = true;
590   }
591 
592  private:
593   bool called_ = false;
594   rtc::scoped_refptr<const RTCStatsReport> report_;
595 };
596 
597 }  // namespace webrtc
598 
599 #endif  // PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
600