1 /* 2 * Copyright 2012 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 // This file contains mock implementations of observers used in PeerConnection. 12 // TODO(steveanton): These aren't really mocks and should be renamed. 13 14 #ifndef PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_ 15 #define PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_ 16 17 #include <map> 18 #include <memory> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "api/data_channel_interface.h" 24 #include "api/jsep_ice_candidate.h" 25 #include "pc/stream_collection.h" 26 #include "rtc_base/checks.h" 27 28 namespace webrtc { 29 30 class MockPeerConnectionObserver : public PeerConnectionObserver { 31 public: 32 struct AddTrackEvent { AddTrackEventAddTrackEvent33 explicit AddTrackEvent( 34 rtc::scoped_refptr<RtpReceiverInterface> event_receiver, 35 std::vector<rtc::scoped_refptr<MediaStreamInterface>> event_streams) 36 : receiver(std::move(event_receiver)), 37 streams(std::move(event_streams)) { 38 for (auto stream : streams) { 39 std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>> tracks; 40 for (auto audio_track : stream->GetAudioTracks()) { 41 tracks.push_back(audio_track); 42 } 43 for (auto video_track : stream->GetVideoTracks()) { 44 tracks.push_back(video_track); 45 } 46 snapshotted_stream_tracks[stream] = tracks; 47 } 48 } 49 50 rtc::scoped_refptr<RtpReceiverInterface> receiver; 51 std::vector<rtc::scoped_refptr<MediaStreamInterface>> streams; 52 // This map records the tracks present in each stream at the time the 53 // OnAddTrack callback was issued. 54 std::map<rtc::scoped_refptr<MediaStreamInterface>, 55 std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>>> 56 snapshotted_stream_tracks; 57 }; 58 MockPeerConnectionObserver()59 MockPeerConnectionObserver() : remote_streams_(StreamCollection::Create()) {} ~MockPeerConnectionObserver()60 virtual ~MockPeerConnectionObserver() {} SetPeerConnectionInterface(PeerConnectionInterface * pc)61 void SetPeerConnectionInterface(PeerConnectionInterface* pc) { 62 pc_ = pc; 63 if (pc) { 64 state_ = pc_->signaling_state(); 65 } 66 } OnSignalingChange(PeerConnectionInterface::SignalingState new_state)67 void OnSignalingChange( 68 PeerConnectionInterface::SignalingState new_state) override { 69 RTC_DCHECK(pc_); 70 RTC_DCHECK(pc_->signaling_state() == new_state); 71 state_ = new_state; 72 } 73 RemoteStream(const std::string & label)74 MediaStreamInterface* RemoteStream(const std::string& label) { 75 return remote_streams_->find(label); 76 } remote_streams()77 StreamCollectionInterface* remote_streams() const { 78 return remote_streams_.get(); 79 } OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream)80 void OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream) override { 81 last_added_stream_ = stream; 82 remote_streams_->AddStream(stream); 83 } OnRemoveStream(rtc::scoped_refptr<MediaStreamInterface> stream)84 void OnRemoveStream( 85 rtc::scoped_refptr<MediaStreamInterface> stream) override { 86 last_removed_stream_ = stream; 87 remote_streams_->RemoveStream(stream.get()); 88 } OnRenegotiationNeeded()89 void OnRenegotiationNeeded() override { renegotiation_needed_ = true; } OnNegotiationNeededEvent(uint32_t event_id)90 void OnNegotiationNeededEvent(uint32_t event_id) override { 91 latest_negotiation_needed_event_ = event_id; 92 } OnDataChannel(rtc::scoped_refptr<DataChannelInterface> data_channel)93 void OnDataChannel( 94 rtc::scoped_refptr<DataChannelInterface> data_channel) override { 95 last_datachannel_ = data_channel; 96 } 97 OnIceConnectionChange(PeerConnectionInterface::IceConnectionState new_state)98 void OnIceConnectionChange( 99 PeerConnectionInterface::IceConnectionState new_state) override { 100 RTC_DCHECK(pc_); 101 RTC_DCHECK(pc_->ice_connection_state() == new_state); 102 // When ICE is finished, the caller will get to a kIceConnectionCompleted 103 // state, because it has the ICE controlling role, while the callee 104 // will get to a kIceConnectionConnected state. This means that both ICE 105 // and DTLS are connected. 106 ice_connected_ = 107 (new_state == PeerConnectionInterface::kIceConnectionConnected) || 108 (new_state == PeerConnectionInterface::kIceConnectionCompleted); 109 callback_triggered_ = true; 110 } OnIceGatheringChange(PeerConnectionInterface::IceGatheringState new_state)111 void OnIceGatheringChange( 112 PeerConnectionInterface::IceGatheringState new_state) override { 113 RTC_DCHECK(pc_); 114 RTC_DCHECK(pc_->ice_gathering_state() == new_state); 115 ice_gathering_complete_ = 116 new_state == PeerConnectionInterface::kIceGatheringComplete; 117 callback_triggered_ = true; 118 } OnIceCandidate(const IceCandidateInterface * candidate)119 void OnIceCandidate(const IceCandidateInterface* candidate) override { 120 RTC_DCHECK(pc_); 121 candidates_.push_back(std::make_unique<JsepIceCandidate>( 122 candidate->sdp_mid(), candidate->sdp_mline_index(), 123 candidate->candidate())); 124 callback_triggered_ = true; 125 } 126 OnIceCandidatesRemoved(const std::vector<cricket::Candidate> & candidates)127 void OnIceCandidatesRemoved( 128 const std::vector<cricket::Candidate>& candidates) override { 129 num_candidates_removed_++; 130 callback_triggered_ = true; 131 } 132 OnIceConnectionReceivingChange(bool receiving)133 void OnIceConnectionReceivingChange(bool receiving) override { 134 callback_triggered_ = true; 135 } 136 OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,const std::vector<rtc::scoped_refptr<MediaStreamInterface>> & streams)137 void OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver, 138 const std::vector<rtc::scoped_refptr<MediaStreamInterface>>& 139 streams) override { 140 RTC_DCHECK(receiver); 141 num_added_tracks_++; 142 last_added_track_label_ = receiver->id(); 143 add_track_events_.push_back(AddTrackEvent(receiver, streams)); 144 } 145 OnTrack(rtc::scoped_refptr<RtpTransceiverInterface> transceiver)146 void OnTrack( 147 rtc::scoped_refptr<RtpTransceiverInterface> transceiver) override { 148 on_track_transceivers_.push_back(transceiver); 149 } 150 OnRemoveTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver)151 void OnRemoveTrack( 152 rtc::scoped_refptr<RtpReceiverInterface> receiver) override { 153 remove_track_events_.push_back(receiver); 154 } 155 GetAddTrackReceivers()156 std::vector<rtc::scoped_refptr<RtpReceiverInterface>> GetAddTrackReceivers() { 157 std::vector<rtc::scoped_refptr<RtpReceiverInterface>> receivers; 158 for (const AddTrackEvent& event : add_track_events_) { 159 receivers.push_back(event.receiver); 160 } 161 return receivers; 162 } 163 CountAddTrackEventsForStream(const std::string & stream_id)164 int CountAddTrackEventsForStream(const std::string& stream_id) { 165 int found_tracks = 0; 166 for (const AddTrackEvent& event : add_track_events_) { 167 bool has_stream_id = false; 168 for (auto stream : event.streams) { 169 if (stream->id() == stream_id) { 170 has_stream_id = true; 171 break; 172 } 173 } 174 if (has_stream_id) { 175 ++found_tracks; 176 } 177 } 178 return found_tracks; 179 } 180 181 // Returns the id of the last added stream. 182 // Empty string if no stream have been added. GetLastAddedStreamId()183 std::string GetLastAddedStreamId() { 184 if (last_added_stream_.get()) 185 return last_added_stream_->id(); 186 return ""; 187 } GetLastRemovedStreamId()188 std::string GetLastRemovedStreamId() { 189 if (last_removed_stream_.get()) 190 return last_removed_stream_->id(); 191 return ""; 192 } 193 last_candidate()194 IceCandidateInterface* last_candidate() { 195 if (candidates_.empty()) { 196 return nullptr; 197 } else { 198 return candidates_.back().get(); 199 } 200 } 201 GetAllCandidates()202 std::vector<const IceCandidateInterface*> GetAllCandidates() { 203 std::vector<const IceCandidateInterface*> candidates; 204 for (const auto& candidate : candidates_) { 205 candidates.push_back(candidate.get()); 206 } 207 return candidates; 208 } 209 GetCandidatesByMline(int mline_index)210 std::vector<IceCandidateInterface*> GetCandidatesByMline(int mline_index) { 211 std::vector<IceCandidateInterface*> candidates; 212 for (const auto& candidate : candidates_) { 213 if (candidate->sdp_mline_index() == mline_index) { 214 candidates.push_back(candidate.get()); 215 } 216 } 217 return candidates; 218 } 219 legacy_renegotiation_needed()220 bool legacy_renegotiation_needed() const { return renegotiation_needed_; } clear_legacy_renegotiation_needed()221 void clear_legacy_renegotiation_needed() { renegotiation_needed_ = false; } 222 has_negotiation_needed_event()223 bool has_negotiation_needed_event() { 224 return latest_negotiation_needed_event_.has_value(); 225 } latest_negotiation_needed_event()226 uint32_t latest_negotiation_needed_event() { 227 return latest_negotiation_needed_event_.value_or(0u); 228 } clear_latest_negotiation_needed_event()229 void clear_latest_negotiation_needed_event() { 230 latest_negotiation_needed_event_ = absl::nullopt; 231 } 232 233 rtc::scoped_refptr<PeerConnectionInterface> pc_; 234 PeerConnectionInterface::SignalingState state_; 235 std::vector<std::unique_ptr<IceCandidateInterface>> candidates_; 236 rtc::scoped_refptr<DataChannelInterface> last_datachannel_; 237 rtc::scoped_refptr<StreamCollection> remote_streams_; 238 bool renegotiation_needed_ = false; 239 absl::optional<uint32_t> latest_negotiation_needed_event_; 240 bool ice_gathering_complete_ = false; 241 bool ice_connected_ = false; 242 bool callback_triggered_ = false; 243 int num_added_tracks_ = 0; 244 std::string last_added_track_label_; 245 std::vector<AddTrackEvent> add_track_events_; 246 std::vector<rtc::scoped_refptr<RtpReceiverInterface>> remove_track_events_; 247 std::vector<rtc::scoped_refptr<RtpTransceiverInterface>> 248 on_track_transceivers_; 249 int num_candidates_removed_ = 0; 250 251 private: 252 rtc::scoped_refptr<MediaStreamInterface> last_added_stream_; 253 rtc::scoped_refptr<MediaStreamInterface> last_removed_stream_; 254 }; 255 256 class MockCreateSessionDescriptionObserver 257 : public webrtc::CreateSessionDescriptionObserver { 258 public: MockCreateSessionDescriptionObserver()259 MockCreateSessionDescriptionObserver() 260 : called_(false), 261 error_("MockCreateSessionDescriptionObserver not called") {} ~MockCreateSessionDescriptionObserver()262 virtual ~MockCreateSessionDescriptionObserver() {} OnSuccess(SessionDescriptionInterface * desc)263 void OnSuccess(SessionDescriptionInterface* desc) override { 264 MutexLock lock(&mutex_); 265 called_ = true; 266 error_ = ""; 267 desc_.reset(desc); 268 } OnFailure(webrtc::RTCError error)269 void OnFailure(webrtc::RTCError error) override { 270 MutexLock lock(&mutex_); 271 called_ = true; 272 error_ = error.message(); 273 } called()274 bool called() const { 275 MutexLock lock(&mutex_); 276 return called_; 277 } result()278 bool result() const { 279 MutexLock lock(&mutex_); 280 return error_.empty(); 281 } error()282 const std::string& error() const { 283 MutexLock lock(&mutex_); 284 return error_; 285 } MoveDescription()286 std::unique_ptr<SessionDescriptionInterface> MoveDescription() { 287 MutexLock lock(&mutex_); 288 return std::move(desc_); 289 } 290 291 private: 292 mutable Mutex mutex_; 293 bool called_ RTC_GUARDED_BY(mutex_); 294 std::string error_ RTC_GUARDED_BY(mutex_); 295 std::unique_ptr<SessionDescriptionInterface> desc_ RTC_GUARDED_BY(mutex_); 296 }; 297 298 class MockSetSessionDescriptionObserver 299 : public webrtc::SetSessionDescriptionObserver { 300 public: Create()301 static rtc::scoped_refptr<MockSetSessionDescriptionObserver> Create() { 302 return rtc::make_ref_counted<MockSetSessionDescriptionObserver>(); 303 } 304 MockSetSessionDescriptionObserver()305 MockSetSessionDescriptionObserver() 306 : called_(false), 307 error_("MockSetSessionDescriptionObserver not called") {} ~MockSetSessionDescriptionObserver()308 ~MockSetSessionDescriptionObserver() override {} OnSuccess()309 void OnSuccess() override { 310 MutexLock lock(&mutex_); 311 312 called_ = true; 313 error_ = ""; 314 } OnFailure(webrtc::RTCError error)315 void OnFailure(webrtc::RTCError error) override { 316 MutexLock lock(&mutex_); 317 called_ = true; 318 error_ = error.message(); 319 } 320 called()321 bool called() const { 322 MutexLock lock(&mutex_); 323 return called_; 324 } result()325 bool result() const { 326 MutexLock lock(&mutex_); 327 return error_.empty(); 328 } error()329 const std::string& error() const { 330 MutexLock lock(&mutex_); 331 return error_; 332 } 333 334 private: 335 mutable Mutex mutex_; 336 bool called_; 337 std::string error_; 338 }; 339 340 class FakeSetLocalDescriptionObserver 341 : public SetLocalDescriptionObserverInterface { 342 public: called()343 bool called() const { return error_.has_value(); } error()344 RTCError& error() { 345 RTC_DCHECK(error_.has_value()); 346 return *error_; 347 } 348 349 // SetLocalDescriptionObserverInterface implementation. OnSetLocalDescriptionComplete(RTCError error)350 void OnSetLocalDescriptionComplete(RTCError error) override { 351 error_ = std::move(error); 352 } 353 354 private: 355 // Set on complete, on success this is set to an RTCError::OK() error. 356 absl::optional<RTCError> error_; 357 }; 358 359 class FakeSetRemoteDescriptionObserver 360 : public SetRemoteDescriptionObserverInterface { 361 public: called()362 bool called() const { return error_.has_value(); } error()363 RTCError& error() { 364 RTC_DCHECK(error_.has_value()); 365 return *error_; 366 } 367 368 // SetRemoteDescriptionObserverInterface implementation. OnSetRemoteDescriptionComplete(RTCError error)369 void OnSetRemoteDescriptionComplete(RTCError error) override { 370 error_ = std::move(error); 371 } 372 373 private: 374 // Set on complete, on success this is set to an RTCError::OK() error. 375 absl::optional<RTCError> error_; 376 }; 377 378 class MockDataChannelObserver : public webrtc::DataChannelObserver { 379 public: 380 struct Message { 381 std::string data; 382 bool binary; 383 }; 384 MockDataChannelObserver(webrtc::DataChannelInterface * channel)385 explicit MockDataChannelObserver(webrtc::DataChannelInterface* channel) 386 : channel_(channel) { 387 channel_->RegisterObserver(this); 388 states_.push_back(channel_->state()); 389 } ~MockDataChannelObserver()390 virtual ~MockDataChannelObserver() { channel_->UnregisterObserver(); } 391 OnBufferedAmountChange(uint64_t previous_amount)392 void OnBufferedAmountChange(uint64_t previous_amount) override {} 393 OnStateChange()394 void OnStateChange() override { states_.push_back(channel_->state()); } OnMessage(const DataBuffer & buffer)395 void OnMessage(const DataBuffer& buffer) override { 396 messages_.push_back( 397 {std::string(buffer.data.data<char>(), buffer.data.size()), 398 buffer.binary}); 399 } 400 IsOpen()401 bool IsOpen() const { return state() == DataChannelInterface::kOpen; } messages()402 std::vector<Message> messages() const { return messages_; } last_message()403 std::string last_message() const { 404 if (messages_.empty()) 405 return {}; 406 407 return messages_.back().data; 408 } last_message_is_binary()409 bool last_message_is_binary() const { 410 if (messages_.empty()) 411 return false; 412 return messages_.back().binary; 413 } received_message_count()414 size_t received_message_count() const { return messages_.size(); } 415 state()416 DataChannelInterface::DataState state() const { return states_.back(); } states()417 const std::vector<DataChannelInterface::DataState>& states() const { 418 return states_; 419 } 420 421 private: 422 rtc::scoped_refptr<webrtc::DataChannelInterface> channel_; 423 std::vector<DataChannelInterface::DataState> states_; 424 std::vector<Message> messages_; 425 }; 426 427 class MockStatsObserver : public webrtc::StatsObserver { 428 public: MockStatsObserver()429 MockStatsObserver() : called_(false), stats_() {} ~MockStatsObserver()430 virtual ~MockStatsObserver() {} 431 OnComplete(const StatsReports & reports)432 virtual void OnComplete(const StatsReports& reports) { 433 RTC_CHECK(!called_); 434 called_ = true; 435 stats_.Clear(); 436 stats_.number_of_reports = reports.size(); 437 for (const auto* r : reports) { 438 if (r->type() == StatsReport::kStatsReportTypeSsrc) { 439 stats_.timestamp = r->timestamp(); 440 GetIntValue(r, StatsReport::kStatsValueNameAudioOutputLevel, 441 &stats_.audio_output_level); 442 GetIntValue(r, StatsReport::kStatsValueNameAudioInputLevel, 443 &stats_.audio_input_level); 444 GetIntValue(r, StatsReport::kStatsValueNameBytesReceived, 445 &stats_.bytes_received); 446 GetIntValue(r, StatsReport::kStatsValueNameBytesSent, 447 &stats_.bytes_sent); 448 GetInt64Value(r, StatsReport::kStatsValueNameCaptureStartNtpTimeMs, 449 &stats_.capture_start_ntp_time); 450 stats_.track_ids.emplace_back(); 451 GetStringValue(r, StatsReport::kStatsValueNameTrackId, 452 &stats_.track_ids.back()); 453 } else if (r->type() == StatsReport::kStatsReportTypeBwe) { 454 stats_.timestamp = r->timestamp(); 455 GetIntValue(r, StatsReport::kStatsValueNameAvailableReceiveBandwidth, 456 &stats_.available_receive_bandwidth); 457 } else if (r->type() == StatsReport::kStatsReportTypeComponent) { 458 stats_.timestamp = r->timestamp(); 459 GetStringValue(r, StatsReport::kStatsValueNameDtlsCipher, 460 &stats_.dtls_cipher); 461 GetStringValue(r, StatsReport::kStatsValueNameSrtpCipher, 462 &stats_.srtp_cipher); 463 } 464 } 465 } 466 called()467 bool called() const { return called_; } number_of_reports()468 size_t number_of_reports() const { return stats_.number_of_reports; } timestamp()469 double timestamp() const { return stats_.timestamp; } 470 AudioOutputLevel()471 int AudioOutputLevel() const { 472 RTC_CHECK(called_); 473 return stats_.audio_output_level; 474 } 475 AudioInputLevel()476 int AudioInputLevel() const { 477 RTC_CHECK(called_); 478 return stats_.audio_input_level; 479 } 480 BytesReceived()481 int BytesReceived() const { 482 RTC_CHECK(called_); 483 return stats_.bytes_received; 484 } 485 BytesSent()486 int BytesSent() const { 487 RTC_CHECK(called_); 488 return stats_.bytes_sent; 489 } 490 CaptureStartNtpTime()491 int64_t CaptureStartNtpTime() const { 492 RTC_CHECK(called_); 493 return stats_.capture_start_ntp_time; 494 } 495 AvailableReceiveBandwidth()496 int AvailableReceiveBandwidth() const { 497 RTC_CHECK(called_); 498 return stats_.available_receive_bandwidth; 499 } 500 DtlsCipher()501 std::string DtlsCipher() const { 502 RTC_CHECK(called_); 503 return stats_.dtls_cipher; 504 } 505 SrtpCipher()506 std::string SrtpCipher() const { 507 RTC_CHECK(called_); 508 return stats_.srtp_cipher; 509 } 510 TrackIds()511 std::vector<std::string> TrackIds() const { 512 RTC_CHECK(called_); 513 return stats_.track_ids; 514 } 515 516 private: GetIntValue(const StatsReport * report,StatsReport::StatsValueName name,int * value)517 bool GetIntValue(const StatsReport* report, 518 StatsReport::StatsValueName name, 519 int* value) { 520 const StatsReport::Value* v = report->FindValue(name); 521 if (v) { 522 // TODO(tommi): We should really just be using an int here :-/ 523 *value = rtc::FromString<int>(v->ToString()); 524 } 525 return v != nullptr; 526 } 527 GetInt64Value(const StatsReport * report,StatsReport::StatsValueName name,int64_t * value)528 bool GetInt64Value(const StatsReport* report, 529 StatsReport::StatsValueName name, 530 int64_t* value) { 531 const StatsReport::Value* v = report->FindValue(name); 532 if (v) { 533 // TODO(tommi): We should really just be using an int here :-/ 534 *value = rtc::FromString<int64_t>(v->ToString()); 535 } 536 return v != nullptr; 537 } 538 GetStringValue(const StatsReport * report,StatsReport::StatsValueName name,std::string * value)539 bool GetStringValue(const StatsReport* report, 540 StatsReport::StatsValueName name, 541 std::string* value) { 542 const StatsReport::Value* v = report->FindValue(name); 543 if (v) 544 *value = v->ToString(); 545 return v != nullptr; 546 } 547 548 bool called_; 549 struct { Clear__anonec6a7c820108550 void Clear() { 551 number_of_reports = 0; 552 timestamp = 0; 553 audio_output_level = 0; 554 audio_input_level = 0; 555 bytes_received = 0; 556 bytes_sent = 0; 557 capture_start_ntp_time = 0; 558 available_receive_bandwidth = 0; 559 dtls_cipher.clear(); 560 srtp_cipher.clear(); 561 track_ids.clear(); 562 } 563 564 size_t number_of_reports; 565 double timestamp; 566 int audio_output_level; 567 int audio_input_level; 568 int bytes_received; 569 int bytes_sent; 570 int64_t capture_start_ntp_time; 571 int available_receive_bandwidth; 572 std::string dtls_cipher; 573 std::string srtp_cipher; 574 std::vector<std::string> track_ids; 575 } stats_; 576 }; 577 578 // Helper class that just stores the report from the callback. 579 class MockRTCStatsCollectorCallback : public webrtc::RTCStatsCollectorCallback { 580 public: report()581 rtc::scoped_refptr<const RTCStatsReport> report() { return report_; } 582 called()583 bool called() const { return called_; } 584 585 protected: OnStatsDelivered(const rtc::scoped_refptr<const RTCStatsReport> & report)586 void OnStatsDelivered( 587 const rtc::scoped_refptr<const RTCStatsReport>& report) override { 588 report_ = report; 589 called_ = true; 590 } 591 592 private: 593 bool called_ = false; 594 rtc::scoped_refptr<const RTCStatsReport> report_; 595 }; 596 597 } // namespace webrtc 598 599 #endif // PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_ 600