1 /*
2 * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10 #include "net/dcsctp/socket/callback_deferrer.h"
11
12 #include "api/make_ref_counted.h"
13
14 namespace dcsctp {
15 namespace {
16 // A wrapper around the move-only DcSctpMessage, to let it be captured in a
17 // lambda.
18 class MessageDeliverer {
19 public:
MessageDeliverer(DcSctpMessage && message)20 explicit MessageDeliverer(DcSctpMessage&& message)
21 : state_(rtc::make_ref_counted<State>(std::move(message))) {}
22
Deliver(DcSctpSocketCallbacks & c)23 void Deliver(DcSctpSocketCallbacks& c) {
24 // Really ensure that it's only called once.
25 RTC_DCHECK(!state_->has_delivered);
26 state_->has_delivered = true;
27 c.OnMessageReceived(std::move(state_->message));
28 }
29
30 private:
31 struct State : public rtc::RefCountInterface {
Statedcsctp::__anon33e266ca0111::MessageDeliverer::State32 explicit State(DcSctpMessage&& m)
33 : has_delivered(false), message(std::move(m)) {}
34 bool has_delivered;
35 DcSctpMessage message;
36 };
37 rtc::scoped_refptr<State> state_;
38 };
39 } // namespace
40
Prepare()41 void CallbackDeferrer::Prepare() {
42 RTC_DCHECK(!prepared_);
43 prepared_ = true;
44 }
45
TriggerDeferred()46 void CallbackDeferrer::TriggerDeferred() {
47 // Need to swap here. The client may call into the library from within a
48 // callback, and that might result in adding new callbacks to this instance,
49 // and the vector can't be modified while iterated on.
50 RTC_DCHECK(prepared_);
51 std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred;
52 deferred.swap(deferred_);
53 prepared_ = false;
54
55 for (auto& cb : deferred) {
56 cb(underlying_);
57 }
58 }
59
SendPacketWithStatus(rtc::ArrayView<const uint8_t> data)60 SendPacketStatus CallbackDeferrer::SendPacketWithStatus(
61 rtc::ArrayView<const uint8_t> data) {
62 // Will not be deferred - call directly.
63 return underlying_.SendPacketWithStatus(data);
64 }
65
CreateTimeout(webrtc::TaskQueueBase::DelayPrecision precision)66 std::unique_ptr<Timeout> CallbackDeferrer::CreateTimeout(
67 webrtc::TaskQueueBase::DelayPrecision precision) {
68 // Will not be deferred - call directly.
69 return underlying_.CreateTimeout(precision);
70 }
71
TimeMillis()72 TimeMs CallbackDeferrer::TimeMillis() {
73 // Will not be deferred - call directly.
74 return underlying_.TimeMillis();
75 }
76
GetRandomInt(uint32_t low,uint32_t high)77 uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) {
78 // Will not be deferred - call directly.
79 return underlying_.GetRandomInt(low, high);
80 }
81
OnMessageReceived(DcSctpMessage message)82 void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) {
83 RTC_DCHECK(prepared_);
84 deferred_.emplace_back(
85 [deliverer = MessageDeliverer(std::move(message))](
86 DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); });
87 }
88
OnError(ErrorKind error,absl::string_view message)89 void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
90 RTC_DCHECK(prepared_);
91 deferred_.emplace_back(
92 [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
93 cb.OnError(error, message);
94 });
95 }
96
OnAborted(ErrorKind error,absl::string_view message)97 void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
98 RTC_DCHECK(prepared_);
99 deferred_.emplace_back(
100 [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
101 cb.OnAborted(error, message);
102 });
103 }
104
OnConnected()105 void CallbackDeferrer::OnConnected() {
106 RTC_DCHECK(prepared_);
107 deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); });
108 }
109
OnClosed()110 void CallbackDeferrer::OnClosed() {
111 RTC_DCHECK(prepared_);
112 deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); });
113 }
114
OnConnectionRestarted()115 void CallbackDeferrer::OnConnectionRestarted() {
116 RTC_DCHECK(prepared_);
117 deferred_.emplace_back(
118 [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); });
119 }
120
OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,absl::string_view reason)121 void CallbackDeferrer::OnStreamsResetFailed(
122 rtc::ArrayView<const StreamID> outgoing_streams,
123 absl::string_view reason) {
124 RTC_DCHECK(prepared_);
125 deferred_.emplace_back(
126 [streams = std::vector<StreamID>(outgoing_streams.begin(),
127 outgoing_streams.end()),
128 reason = std::string(reason)](DcSctpSocketCallbacks& cb) {
129 cb.OnStreamsResetFailed(streams, reason);
130 });
131 }
132
OnStreamsResetPerformed(rtc::ArrayView<const StreamID> outgoing_streams)133 void CallbackDeferrer::OnStreamsResetPerformed(
134 rtc::ArrayView<const StreamID> outgoing_streams) {
135 RTC_DCHECK(prepared_);
136 deferred_.emplace_back(
137 [streams = std::vector<StreamID>(outgoing_streams.begin(),
138 outgoing_streams.end())](
139 DcSctpSocketCallbacks& cb) { cb.OnStreamsResetPerformed(streams); });
140 }
141
OnIncomingStreamsReset(rtc::ArrayView<const StreamID> incoming_streams)142 void CallbackDeferrer::OnIncomingStreamsReset(
143 rtc::ArrayView<const StreamID> incoming_streams) {
144 RTC_DCHECK(prepared_);
145 deferred_.emplace_back(
146 [streams = std::vector<StreamID>(incoming_streams.begin(),
147 incoming_streams.end())](
148 DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); });
149 }
150
OnBufferedAmountLow(StreamID stream_id)151 void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) {
152 RTC_DCHECK(prepared_);
153 deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) {
154 cb.OnBufferedAmountLow(stream_id);
155 });
156 }
157
OnTotalBufferedAmountLow()158 void CallbackDeferrer::OnTotalBufferedAmountLow() {
159 RTC_DCHECK(prepared_);
160 deferred_.emplace_back(
161 [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); });
162 }
163
OnLifecycleMessageExpired(LifecycleId lifecycle_id,bool maybe_delivered)164 void CallbackDeferrer::OnLifecycleMessageExpired(LifecycleId lifecycle_id,
165 bool maybe_delivered) {
166 // Will not be deferred - call directly.
167 underlying_.OnLifecycleMessageExpired(lifecycle_id, maybe_delivered);
168 }
OnLifecycleMessageFullySent(LifecycleId lifecycle_id)169 void CallbackDeferrer::OnLifecycleMessageFullySent(LifecycleId lifecycle_id) {
170 // Will not be deferred - call directly.
171 underlying_.OnLifecycleMessageFullySent(lifecycle_id);
172 }
OnLifecycleMessageDelivered(LifecycleId lifecycle_id)173 void CallbackDeferrer::OnLifecycleMessageDelivered(LifecycleId lifecycle_id) {
174 // Will not be deferred - call directly.
175 underlying_.OnLifecycleMessageDelivered(lifecycle_id);
176 }
OnLifecycleEnd(LifecycleId lifecycle_id)177 void CallbackDeferrer::OnLifecycleEnd(LifecycleId lifecycle_id) {
178 // Will not be deferred - call directly.
179 underlying_.OnLifecycleEnd(lifecycle_id);
180 }
181 } // namespace dcsctp
182