xref: /aosp_15_r20/external/webrtc/net/dcsctp/timer/timer.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 #include "net/dcsctp/timer/timer.h"
11 
12 #include <algorithm>
13 #include <cstdint>
14 #include <limits>
15 #include <memory>
16 #include <utility>
17 
18 #include "absl/memory/memory.h"
19 #include "absl/strings/string_view.h"
20 #include "net/dcsctp/public/timeout.h"
21 #include "rtc_base/checks.h"
22 
23 namespace dcsctp {
24 namespace {
MakeTimeoutId(TimerID timer_id,TimerGeneration generation)25 TimeoutID MakeTimeoutId(TimerID timer_id, TimerGeneration generation) {
26   return TimeoutID(static_cast<uint64_t>(*timer_id) << 32 | *generation);
27 }
28 
GetBackoffDuration(const TimerOptions & options,DurationMs base_duration,int expiration_count)29 DurationMs GetBackoffDuration(const TimerOptions& options,
30                               DurationMs base_duration,
31                               int expiration_count) {
32   switch (options.backoff_algorithm) {
33     case TimerBackoffAlgorithm::kFixed:
34       return base_duration;
35     case TimerBackoffAlgorithm::kExponential: {
36       int32_t duration_ms = *base_duration;
37 
38       while (expiration_count > 0 && duration_ms < *Timer::kMaxTimerDuration) {
39         duration_ms *= 2;
40         --expiration_count;
41 
42         if (options.max_backoff_duration.has_value() &&
43             duration_ms > **options.max_backoff_duration) {
44           return *options.max_backoff_duration;
45         }
46       }
47 
48       return DurationMs(std::min(duration_ms, *Timer::kMaxTimerDuration));
49     }
50   }
51 }
52 }  // namespace
53 
// Out-of-class definition of the static constexpr member.
// NOTE(review): presumably required because the constant is ODR-used in this
// file (e.g. passed to std::min in GetBackoffDuration) when built in a
// pre-C++17 mode, where static constexpr members are not implicitly inline —
// confirm before removing.
constexpr DurationMs Timer::kMaxTimerDuration;
55 
// Constructs a timer from its id, name, expiry callback, unregister handler,
// underlying timeout and options. The callbacks and the timeout are moved
// into the timer, the options are copied, and the base duration is
// initialized from `options.duration`.
Timer::Timer(TimerID id,
             absl::string_view name,
             OnExpired on_expired,
             UnregisterHandler unregister_handler,
             std::unique_ptr<Timeout> timeout,
             const TimerOptions& options)
    : id_(id),
      name_(name),
      options_(options),
      on_expired_(std::move(on_expired)),
      unregister_handler_(std::move(unregister_handler)),
      timeout_(std::move(timeout)),
      duration_(options.duration) {}
69 
// Stops the timer (a no-op if it is not running) and then invokes the
// unregister handler supplied at construction. For timers created through
// TimerManager::CreateTimer, that handler erases this timer's id from the
// manager's map.
Timer::~Timer() {
  Stop();
  unregister_handler_();
}
74 
Start()75 void Timer::Start() {
76   expiration_count_ = 0;
77   if (!is_running()) {
78     is_running_ = true;
79     generation_ = TimerGeneration(*generation_ + 1);
80     timeout_->Start(duration_, MakeTimeoutId(id_, generation_));
81   } else {
82     // Timer was running - stop and restart it, to make it expire in `duration_`
83     // from now.
84     generation_ = TimerGeneration(*generation_ + 1);
85     timeout_->Restart(duration_, MakeTimeoutId(id_, generation_));
86   }
87 }
88 
Stop()89 void Timer::Stop() {
90   if (is_running()) {
91     timeout_->Stop();
92     expiration_count_ = 0;
93     is_running_ = false;
94   }
95 }
96 
// Handles an expiration of the underlying timeout tagged with `generation`.
//
// The call is ignored unless the timer is running and `generation` matches
// the timer's current generation. Otherwise the expiration count is
// incremented, the timeout is re-armed with a backed-off duration (unless
// `options_.max_restarts` is set and has been exceeded), and `on_expired_` is
// invoked. If the callback returns a duration different from `duration_`, it
// becomes the new base duration, and a timer that is still running is
// re-armed using it.
void Timer::Trigger(TimerGeneration generation) {
  if (is_running_ && generation == generation_) {
    ++expiration_count_;
    // Mark as stopped for now; set back to running below if a restart is
    // permitted.
    is_running_ = false;
    if (!options_.max_restarts.has_value() ||
        expiration_count_ <= *options_.max_restarts) {
      // The timer should still be running after this triggers. Start a new
      // timer. Note that it might be very quickly restarted again, if the
      // `on_expired_` callback returns a new duration.
      is_running_ = true;
      DurationMs duration =
          GetBackoffDuration(options_, duration_, expiration_count_);
      // A new generation is assigned for every (re)start of the timeout.
      generation_ = TimerGeneration(*generation_ + 1);
      timeout_->Start(duration, MakeTimeoutId(id_, generation_));
    }

    // The callback runs after the timeout has (possibly) been re-armed above.
    absl::optional<DurationMs> new_duration = on_expired_();
    if (new_duration.has_value() && new_duration != duration_) {
      // Adopt the returned value as the new base duration, whether or not the
      // timer is still running.
      duration_ = new_duration.value();
      // `is_running_` is read again here, after the callback has returned.
      if (is_running_) {
        // Restart it with new duration.
        timeout_->Stop();

        // Backoff is recomputed from the new base duration with the current
        // expiration count.
        DurationMs duration =
            GetBackoffDuration(options_, duration_, expiration_count_);
        generation_ = TimerGeneration(*generation_ + 1);
        timeout_->Start(duration, MakeTimeoutId(id_, generation_));
      }
    }
  }
}
128 
HandleTimeout(TimeoutID timeout_id)129 void TimerManager::HandleTimeout(TimeoutID timeout_id) {
130   TimerID timer_id(*timeout_id >> 32);
131   TimerGeneration generation(*timeout_id);
132   auto it = timers_.find(timer_id);
133   if (it != timers_.end()) {
134     it->second->Trigger(generation);
135   }
136 }
137 
CreateTimer(absl::string_view name,Timer::OnExpired on_expired,const TimerOptions & options)138 std::unique_ptr<Timer> TimerManager::CreateTimer(absl::string_view name,
139                                                  Timer::OnExpired on_expired,
140                                                  const TimerOptions& options) {
141   next_id_ = TimerID(*next_id_ + 1);
142   TimerID id = next_id_;
143   // This would overflow after 4 billion timers created, which in SCTP would be
144   // after 800 million reconnections on a single socket. Ensure this will never
145   // happen.
146   RTC_CHECK_NE(*id, std::numeric_limits<uint32_t>::max());
147   std::unique_ptr<Timeout> timeout = create_timeout_(options.precision);
148   RTC_CHECK(timeout != nullptr);
149   auto timer = absl::WrapUnique(new Timer(
150       id, name, std::move(on_expired), [this, id]() { timers_.erase(id); },
151       std::move(timeout), options));
152   timers_[id] = timer.get();
153   return timer;
154 }
155 
156 }  // namespace dcsctp
157