1 /* 2 * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 #ifndef NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ 11 #define NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ 12 13 #include <cstdint> 14 #include <functional> 15 #include <limits> 16 #include <memory> 17 #include <utility> 18 #include <vector> 19 20 #include "absl/types/optional.h" 21 #include "api/task_queue/task_queue_base.h" 22 #include "net/dcsctp/public/timeout.h" 23 #include "rtc_base/checks.h" 24 #include "rtc_base/containers/flat_set.h" 25 26 namespace dcsctp { 27 28 // A timeout used in tests. 29 class FakeTimeout : public Timeout { 30 public: FakeTimeout(std::function<TimeMs ()> get_time,std::function<void (FakeTimeout *)> on_delete)31 FakeTimeout(std::function<TimeMs()> get_time, 32 std::function<void(FakeTimeout*)> on_delete) 33 : get_time_(std::move(get_time)), on_delete_(std::move(on_delete)) {} 34 ~FakeTimeout()35 ~FakeTimeout() override { on_delete_(this); } 36 Start(DurationMs duration_ms,TimeoutID timeout_id)37 void Start(DurationMs duration_ms, TimeoutID timeout_id) override { 38 RTC_DCHECK(expiry_ == TimeMs::InfiniteFuture()); 39 timeout_id_ = timeout_id; 40 expiry_ = get_time_() + duration_ms; 41 } Stop()42 void Stop() override { 43 RTC_DCHECK(expiry_ != TimeMs::InfiniteFuture()); 44 expiry_ = TimeMs::InfiniteFuture(); 45 } 46 EvaluateHasExpired(TimeMs now)47 bool EvaluateHasExpired(TimeMs now) { 48 if (now >= expiry_) { 49 expiry_ = TimeMs::InfiniteFuture(); 50 return true; 51 } 52 return false; 53 } 54 timeout_id()55 TimeoutID timeout_id() const { return timeout_id_; } 56 57 private: 58 const std::function<TimeMs()> get_time_; 59 const std::function<void(FakeTimeout*)> on_delete_; 60 61 TimeoutID timeout_id_ = TimeoutID(0); 62 TimeMs expiry_ = TimeMs::InfiniteFuture(); 63 }; 64 65 class FakeTimeoutManager { 66 public: 67 // The `get_time` function must return the current time, relative to any 68 // epoch. FakeTimeoutManager(std::function<TimeMs ()> get_time)69 explicit FakeTimeoutManager(std::function<TimeMs()> get_time) 70 : get_time_(std::move(get_time)) {} 71 CreateTimeout()72 std::unique_ptr<FakeTimeout> CreateTimeout() { 73 auto timer = std::make_unique<FakeTimeout>( 74 get_time_, [this](FakeTimeout* timer) { timers_.erase(timer); }); 75 timers_.insert(timer.get()); 76 return timer; 77 } CreateTimeout(webrtc::TaskQueueBase::DelayPrecision precision)78 std::unique_ptr<FakeTimeout> CreateTimeout( 79 webrtc::TaskQueueBase::DelayPrecision precision) { 80 // FakeTimeout does not support implement |precision|. 81 return CreateTimeout(); 82 } 83 84 // NOTE: This can't return a vector, as calling EvaluateHasExpired requires 85 // calling socket->HandleTimeout directly afterwards, as the owning Timer 86 // still believes it's running, and it needs to be updated to set 87 // Timer::is_running_ to false before you operate on the Timer or Timeout 88 // again. GetNextExpiredTimeout()89 absl::optional<TimeoutID> GetNextExpiredTimeout() { 90 TimeMs now = get_time_(); 91 std::vector<TimeoutID> expired_timers; 92 for (auto& timer : timers_) { 93 if (timer->EvaluateHasExpired(now)) { 94 return timer->timeout_id(); 95 } 96 } 97 return absl::nullopt; 98 } 99 100 private: 101 const std::function<TimeMs()> get_time_; 102 webrtc::flat_set<FakeTimeout*> timers_; 103 }; 104 105 } // namespace dcsctp 106 107 #endif // NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ 108