xref: /aosp_15_r20/external/openscreen/osp/impl/message_demuxer.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1*3f982cf4SFabien Sanglard // Copyright 2018 The Chromium Authors. All rights reserved.
2*3f982cf4SFabien Sanglard // Use of this source code is governed by a BSD-style license that can be
3*3f982cf4SFabien Sanglard // found in the LICENSE file.
4*3f982cf4SFabien Sanglard 
5*3f982cf4SFabien Sanglard #include "osp/public/message_demuxer.h"
6*3f982cf4SFabien Sanglard 
7*3f982cf4SFabien Sanglard #include <memory>
8*3f982cf4SFabien Sanglard #include <utility>
9*3f982cf4SFabien Sanglard 
10*3f982cf4SFabien Sanglard #include "osp/impl/quic/quic_connection.h"
11*3f982cf4SFabien Sanglard #include "platform/base/error.h"
12*3f982cf4SFabien Sanglard #include "util/big_endian.h"
13*3f982cf4SFabien Sanglard #include "util/osp_logging.h"
14*3f982cf4SFabien Sanglard 
15*3f982cf4SFabien Sanglard namespace openscreen {
16*3f982cf4SFabien Sanglard namespace osp {
17*3f982cf4SFabien Sanglard 
18*3f982cf4SFabien Sanglard // static
19*3f982cf4SFabien Sanglard // Decodes a varUint, expecting it to follow the encoding format described here:
20*3f982cf4SFabien Sanglard // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeVarUint(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)21*3f982cf4SFabien Sanglard ErrorOr<uint64_t> MessageTypeDecoder::DecodeVarUint(
22*3f982cf4SFabien Sanglard     const std::vector<uint8_t>& buffer,
23*3f982cf4SFabien Sanglard     size_t* num_bytes_decoded) {
24*3f982cf4SFabien Sanglard   if (buffer.size() == 0) {
25*3f982cf4SFabien Sanglard     return Error::Code::kCborIncompleteMessage;
26*3f982cf4SFabien Sanglard   }
27*3f982cf4SFabien Sanglard 
28*3f982cf4SFabien Sanglard   uint8_t num_type_bytes = static_cast<uint8_t>(buffer[0] >> 6 & 0x03);
29*3f982cf4SFabien Sanglard   *num_bytes_decoded = 0x1 << num_type_bytes;
30*3f982cf4SFabien Sanglard 
31*3f982cf4SFabien Sanglard   // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also,
32*3f982cf4SFabien Sanglard   // since we expect the id to be followed by the message, equality is not valid
33*3f982cf4SFabien Sanglard   if (buffer.size() <= *num_bytes_decoded) {
34*3f982cf4SFabien Sanglard     return Error::Code::kCborIncompleteMessage;
35*3f982cf4SFabien Sanglard   }
36*3f982cf4SFabien Sanglard 
37*3f982cf4SFabien Sanglard   switch (num_type_bytes) {
38*3f982cf4SFabien Sanglard     case 0:
39*3f982cf4SFabien Sanglard       return ReadBigEndian<uint8_t>(&buffer[0]) & ~0xC0;
40*3f982cf4SFabien Sanglard     case 1:
41*3f982cf4SFabien Sanglard       return ReadBigEndian<uint16_t>(&buffer[0]) & ~(0xC0 << 8);
42*3f982cf4SFabien Sanglard     case 2:
43*3f982cf4SFabien Sanglard       return ReadBigEndian<uint32_t>(&buffer[0]) & ~(0xC0 << 24);
44*3f982cf4SFabien Sanglard     case 3:
45*3f982cf4SFabien Sanglard       return ReadBigEndian<uint64_t>(&buffer[0]) & ~(uint64_t{0xC0} << 56);
46*3f982cf4SFabien Sanglard     default:
47*3f982cf4SFabien Sanglard       OSP_NOTREACHED();
48*3f982cf4SFabien Sanglard   }
49*3f982cf4SFabien Sanglard }
50*3f982cf4SFabien Sanglard 
51*3f982cf4SFabien Sanglard // static
52*3f982cf4SFabien Sanglard // Decodes the Type of message, expecting it to follow the encoding format
53*3f982cf4SFabien Sanglard // described here:
54*3f982cf4SFabien Sanglard // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
DecodeType(const std::vector<uint8_t> & buffer,size_t * num_bytes_decoded)55*3f982cf4SFabien Sanglard ErrorOr<msgs::Type> MessageTypeDecoder::DecodeType(
56*3f982cf4SFabien Sanglard     const std::vector<uint8_t>& buffer,
57*3f982cf4SFabien Sanglard     size_t* num_bytes_decoded) {
58*3f982cf4SFabien Sanglard   ErrorOr<uint64_t> message_type =
59*3f982cf4SFabien Sanglard       MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded);
60*3f982cf4SFabien Sanglard   if (message_type.is_error()) {
61*3f982cf4SFabien Sanglard     return message_type.error();
62*3f982cf4SFabien Sanglard   }
63*3f982cf4SFabien Sanglard 
64*3f982cf4SFabien Sanglard   msgs::Type parsed_type =
65*3f982cf4SFabien Sanglard       msgs::TypeEnumValidator::SafeCast(message_type.value());
66*3f982cf4SFabien Sanglard   if (parsed_type == msgs::Type::kUnknown) {
67*3f982cf4SFabien Sanglard     return Error::Code::kCborInvalidMessage;
68*3f982cf4SFabien Sanglard   }
69*3f982cf4SFabien Sanglard 
70*3f982cf4SFabien Sanglard   return parsed_type;
71*3f982cf4SFabien Sanglard }
72*3f982cf4SFabien Sanglard 
73*3f982cf4SFabien Sanglard // static
74*3f982cf4SFabien Sanglard constexpr size_t MessageDemuxer::kDefaultBufferLimit;
75*3f982cf4SFabien Sanglard 
76*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::MessageWatch() = default;
77*3f982cf4SFabien Sanglard 
MessageWatch(MessageDemuxer * parent,bool is_default,uint64_t endpoint_id,msgs::Type message_type)78*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent,
79*3f982cf4SFabien Sanglard                                            bool is_default,
80*3f982cf4SFabien Sanglard                                            uint64_t endpoint_id,
81*3f982cf4SFabien Sanglard                                            msgs::Type message_type)
82*3f982cf4SFabien Sanglard     : parent_(parent),
83*3f982cf4SFabien Sanglard       is_default_(is_default),
84*3f982cf4SFabien Sanglard       endpoint_id_(endpoint_id),
85*3f982cf4SFabien Sanglard       message_type_(message_type) {}
86*3f982cf4SFabien Sanglard 
MessageWatch(MessageDemuxer::MessageWatch && other)87*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::MessageWatch(
88*3f982cf4SFabien Sanglard     MessageDemuxer::MessageWatch&& other) noexcept
89*3f982cf4SFabien Sanglard     : parent_(other.parent_),
90*3f982cf4SFabien Sanglard       is_default_(other.is_default_),
91*3f982cf4SFabien Sanglard       endpoint_id_(other.endpoint_id_),
92*3f982cf4SFabien Sanglard       message_type_(other.message_type_) {
93*3f982cf4SFabien Sanglard   other.parent_ = nullptr;
94*3f982cf4SFabien Sanglard }
95*3f982cf4SFabien Sanglard 
~MessageWatch()96*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch::~MessageWatch() {
97*3f982cf4SFabien Sanglard   if (parent_) {
98*3f982cf4SFabien Sanglard     if (is_default_) {
99*3f982cf4SFabien Sanglard       OSP_VLOG << "dropping default handler for type: "
100*3f982cf4SFabien Sanglard                << static_cast<int>(message_type_);
101*3f982cf4SFabien Sanglard       parent_->StopDefaultMessageTypeWatch(message_type_);
102*3f982cf4SFabien Sanglard     } else {
103*3f982cf4SFabien Sanglard       OSP_VLOG << "dropping handler for type: "
104*3f982cf4SFabien Sanglard                << static_cast<int>(message_type_);
105*3f982cf4SFabien Sanglard       parent_->StopWatchingMessageType(endpoint_id_, message_type_);
106*3f982cf4SFabien Sanglard     }
107*3f982cf4SFabien Sanglard   }
108*3f982cf4SFabien Sanglard }
109*3f982cf4SFabien Sanglard 
operator =(MessageWatch && other)110*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
111*3f982cf4SFabien Sanglard     MessageWatch&& other) noexcept {
112*3f982cf4SFabien Sanglard   using std::swap;
113*3f982cf4SFabien Sanglard   swap(parent_, other.parent_);
114*3f982cf4SFabien Sanglard   swap(is_default_, other.is_default_);
115*3f982cf4SFabien Sanglard   swap(endpoint_id_, other.endpoint_id_);
116*3f982cf4SFabien Sanglard   swap(message_type_, other.message_type_);
117*3f982cf4SFabien Sanglard   return *this;
118*3f982cf4SFabien Sanglard }
119*3f982cf4SFabien Sanglard 
MessageDemuxer(ClockNowFunctionPtr now_function,size_t buffer_limit=kDefaultBufferLimit)120*3f982cf4SFabien Sanglard MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function,
121*3f982cf4SFabien Sanglard                                size_t buffer_limit = kDefaultBufferLimit)
122*3f982cf4SFabien Sanglard     : now_function_(now_function), buffer_limit_(buffer_limit) {
123*3f982cf4SFabien Sanglard   OSP_DCHECK(now_function_);
124*3f982cf4SFabien Sanglard }
125*3f982cf4SFabien Sanglard 
126*3f982cf4SFabien Sanglard MessageDemuxer::~MessageDemuxer() = default;
127*3f982cf4SFabien Sanglard 
WatchMessageType(uint64_t endpoint_id,msgs::Type message_type,MessageCallback * callback)128*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType(
129*3f982cf4SFabien Sanglard     uint64_t endpoint_id,
130*3f982cf4SFabien Sanglard     msgs::Type message_type,
131*3f982cf4SFabien Sanglard     MessageCallback* callback) {
132*3f982cf4SFabien Sanglard   auto callbacks_entry = message_callbacks_.find(endpoint_id);
133*3f982cf4SFabien Sanglard   if (callbacks_entry == message_callbacks_.end()) {
134*3f982cf4SFabien Sanglard     callbacks_entry =
135*3f982cf4SFabien Sanglard         message_callbacks_
136*3f982cf4SFabien Sanglard             .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{})
137*3f982cf4SFabien Sanglard             .first;
138*3f982cf4SFabien Sanglard   }
139*3f982cf4SFabien Sanglard   auto emplace_result = callbacks_entry->second.emplace(message_type, callback);
140*3f982cf4SFabien Sanglard   if (!emplace_result.second)
141*3f982cf4SFabien Sanglard     return MessageWatch();
142*3f982cf4SFabien Sanglard   auto endpoint_entry = buffers_.find(endpoint_id);
143*3f982cf4SFabien Sanglard   if (endpoint_entry != buffers_.end()) {
144*3f982cf4SFabien Sanglard     for (auto& buffer : endpoint_entry->second) {
145*3f982cf4SFabien Sanglard       if (buffer.second.empty())
146*3f982cf4SFabien Sanglard         continue;
147*3f982cf4SFabien Sanglard       auto buffered_type = static_cast<msgs::Type>(buffer.second[0]);
148*3f982cf4SFabien Sanglard       if (message_type == buffered_type) {
149*3f982cf4SFabien Sanglard         HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry,
150*3f982cf4SFabien Sanglard                                &buffer.second);
151*3f982cf4SFabien Sanglard       }
152*3f982cf4SFabien Sanglard     }
153*3f982cf4SFabien Sanglard   }
154*3f982cf4SFabien Sanglard   return MessageWatch(this, false, endpoint_id, message_type);
155*3f982cf4SFabien Sanglard }
156*3f982cf4SFabien Sanglard 
SetDefaultMessageTypeWatch(msgs::Type message_type,MessageCallback * callback)157*3f982cf4SFabien Sanglard MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch(
158*3f982cf4SFabien Sanglard     msgs::Type message_type,
159*3f982cf4SFabien Sanglard     MessageCallback* callback) {
160*3f982cf4SFabien Sanglard   auto emplace_result = default_callbacks_.emplace(message_type, callback);
161*3f982cf4SFabien Sanglard   if (!emplace_result.second)
162*3f982cf4SFabien Sanglard     return MessageWatch();
163*3f982cf4SFabien Sanglard   for (auto& endpoint_buffers : buffers_) {
164*3f982cf4SFabien Sanglard     auto endpoint_id = endpoint_buffers.first;
165*3f982cf4SFabien Sanglard     for (auto& stream_map : endpoint_buffers.second) {
166*3f982cf4SFabien Sanglard       if (stream_map.second.empty())
167*3f982cf4SFabien Sanglard         continue;
168*3f982cf4SFabien Sanglard       auto buffered_type = static_cast<msgs::Type>(stream_map.second[0]);
169*3f982cf4SFabien Sanglard       if (message_type == buffered_type) {
170*3f982cf4SFabien Sanglard         auto connection_id = stream_map.first;
171*3f982cf4SFabien Sanglard         auto callbacks_entry = message_callbacks_.find(endpoint_id);
172*3f982cf4SFabien Sanglard         HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry,
173*3f982cf4SFabien Sanglard                                &stream_map.second);
174*3f982cf4SFabien Sanglard       }
175*3f982cf4SFabien Sanglard     }
176*3f982cf4SFabien Sanglard   }
177*3f982cf4SFabien Sanglard   return MessageWatch(this, true, 0, message_type);
178*3f982cf4SFabien Sanglard }
179*3f982cf4SFabien Sanglard 
OnStreamData(uint64_t endpoint_id,uint64_t connection_id,const uint8_t * data,size_t data_size)180*3f982cf4SFabien Sanglard void MessageDemuxer::OnStreamData(uint64_t endpoint_id,
181*3f982cf4SFabien Sanglard                                   uint64_t connection_id,
182*3f982cf4SFabien Sanglard                                   const uint8_t* data,
183*3f982cf4SFabien Sanglard                                   size_t data_size) {
184*3f982cf4SFabien Sanglard   OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id
185*3f982cf4SFabien Sanglard            << "] - (" << data_size << ")";
186*3f982cf4SFabien Sanglard   auto& stream_map = buffers_[endpoint_id];
187*3f982cf4SFabien Sanglard   if (!data_size) {
188*3f982cf4SFabien Sanglard     stream_map.erase(connection_id);
189*3f982cf4SFabien Sanglard     if (stream_map.empty())
190*3f982cf4SFabien Sanglard       buffers_.erase(endpoint_id);
191*3f982cf4SFabien Sanglard     return;
192*3f982cf4SFabien Sanglard   }
193*3f982cf4SFabien Sanglard   std::vector<uint8_t>& buffer = stream_map[connection_id];
194*3f982cf4SFabien Sanglard   buffer.insert(buffer.end(), data, data + data_size);
195*3f982cf4SFabien Sanglard 
196*3f982cf4SFabien Sanglard   auto callbacks_entry = message_callbacks_.find(endpoint_id);
197*3f982cf4SFabien Sanglard   HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer);
198*3f982cf4SFabien Sanglard 
199*3f982cf4SFabien Sanglard   if (buffer.size() > buffer_limit_)
200*3f982cf4SFabien Sanglard     stream_map.erase(connection_id);
201*3f982cf4SFabien Sanglard }
202*3f982cf4SFabien Sanglard 
StopWatchingMessageType(uint64_t endpoint_id,msgs::Type message_type)203*3f982cf4SFabien Sanglard void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id,
204*3f982cf4SFabien Sanglard                                              msgs::Type message_type) {
205*3f982cf4SFabien Sanglard   auto& message_map = message_callbacks_[endpoint_id];
206*3f982cf4SFabien Sanglard   auto it = message_map.find(message_type);
207*3f982cf4SFabien Sanglard   message_map.erase(it);
208*3f982cf4SFabien Sanglard }
209*3f982cf4SFabien Sanglard 
StopDefaultMessageTypeWatch(msgs::Type message_type)210*3f982cf4SFabien Sanglard void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) {
211*3f982cf4SFabien Sanglard   default_callbacks_.erase(message_type);
212*3f982cf4SFabien Sanglard }
213*3f982cf4SFabien Sanglard 
HandleStreamBufferLoop(uint64_t endpoint_id,uint64_t connection_id,std::map<uint64_t,std::map<msgs::Type,MessageCallback * >>::iterator callbacks_entry,std::vector<uint8_t> * buffer)214*3f982cf4SFabien Sanglard MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop(
215*3f982cf4SFabien Sanglard     uint64_t endpoint_id,
216*3f982cf4SFabien Sanglard     uint64_t connection_id,
217*3f982cf4SFabien Sanglard     std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
218*3f982cf4SFabien Sanglard         callbacks_entry,
219*3f982cf4SFabien Sanglard     std::vector<uint8_t>* buffer) {
220*3f982cf4SFabien Sanglard   HandleStreamBufferResult result;
221*3f982cf4SFabien Sanglard   do {
222*3f982cf4SFabien Sanglard     result = {false, 0};
223*3f982cf4SFabien Sanglard     if (callbacks_entry != message_callbacks_.end()) {
224*3f982cf4SFabien Sanglard       OSP_VLOG << "attempting endpoint-specific handling";
225*3f982cf4SFabien Sanglard       result = HandleStreamBuffer(endpoint_id, connection_id,
226*3f982cf4SFabien Sanglard                                   &callbacks_entry->second, buffer);
227*3f982cf4SFabien Sanglard     }
228*3f982cf4SFabien Sanglard     if (!result.handled) {
229*3f982cf4SFabien Sanglard       if (!default_callbacks_.empty()) {
230*3f982cf4SFabien Sanglard         OSP_VLOG << "attempting generic message handling";
231*3f982cf4SFabien Sanglard         result = HandleStreamBuffer(endpoint_id, connection_id,
232*3f982cf4SFabien Sanglard                                     &default_callbacks_, buffer);
233*3f982cf4SFabien Sanglard       }
234*3f982cf4SFabien Sanglard     }
235*3f982cf4SFabien Sanglard     OSP_VLOG_IF(!result.handled) << "no message handler matched";
236*3f982cf4SFabien Sanglard   } while (result.consumed && !buffer->empty());
237*3f982cf4SFabien Sanglard   return result;
238*3f982cf4SFabien Sanglard }
239*3f982cf4SFabien Sanglard 
240*3f982cf4SFabien Sanglard // TODO(rwkeane) Use absl::Span for the buffer
HandleStreamBuffer(uint64_t endpoint_id,uint64_t connection_id,std::map<msgs::Type,MessageCallback * > * message_callbacks,std::vector<uint8_t> * buffer)241*3f982cf4SFabien Sanglard MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer(
242*3f982cf4SFabien Sanglard     uint64_t endpoint_id,
243*3f982cf4SFabien Sanglard     uint64_t connection_id,
244*3f982cf4SFabien Sanglard     std::map<msgs::Type, MessageCallback*>* message_callbacks,
245*3f982cf4SFabien Sanglard     std::vector<uint8_t>* buffer) {
246*3f982cf4SFabien Sanglard   size_t consumed = 0;
247*3f982cf4SFabien Sanglard   size_t total_consumed = 0;
248*3f982cf4SFabien Sanglard   bool handled = false;
249*3f982cf4SFabien Sanglard   do {
250*3f982cf4SFabien Sanglard     consumed = 0;
251*3f982cf4SFabien Sanglard     size_t msg_type_byte_length;
252*3f982cf4SFabien Sanglard     ErrorOr<msgs::Type> message_type =
253*3f982cf4SFabien Sanglard         MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length);
254*3f982cf4SFabien Sanglard     if (message_type.is_error()) {
255*3f982cf4SFabien Sanglard       buffer->clear();
256*3f982cf4SFabien Sanglard       break;
257*3f982cf4SFabien Sanglard     }
258*3f982cf4SFabien Sanglard     auto callback_entry = message_callbacks->find(message_type.value());
259*3f982cf4SFabien Sanglard     if (callback_entry == message_callbacks->end())
260*3f982cf4SFabien Sanglard       break;
261*3f982cf4SFabien Sanglard     handled = true;
262*3f982cf4SFabien Sanglard     OSP_VLOG << "handling message type "
263*3f982cf4SFabien Sanglard              << static_cast<int>(message_type.value());
264*3f982cf4SFabien Sanglard     auto consumed_or_error = callback_entry->second->OnStreamMessage(
265*3f982cf4SFabien Sanglard         endpoint_id, connection_id, message_type.value(),
266*3f982cf4SFabien Sanglard         buffer->data() + msg_type_byte_length,
267*3f982cf4SFabien Sanglard         buffer->size() - msg_type_byte_length, now_function_());
268*3f982cf4SFabien Sanglard     if (!consumed_or_error) {
269*3f982cf4SFabien Sanglard       if (consumed_or_error.error().code() !=
270*3f982cf4SFabien Sanglard           Error::Code::kCborIncompleteMessage) {
271*3f982cf4SFabien Sanglard         buffer->clear();
272*3f982cf4SFabien Sanglard         break;
273*3f982cf4SFabien Sanglard       }
274*3f982cf4SFabien Sanglard     } else {
275*3f982cf4SFabien Sanglard       consumed = consumed_or_error.value();
276*3f982cf4SFabien Sanglard       buffer->erase(buffer->begin(),
277*3f982cf4SFabien Sanglard                     buffer->begin() + consumed + msg_type_byte_length);
278*3f982cf4SFabien Sanglard     }
279*3f982cf4SFabien Sanglard     total_consumed += consumed;
280*3f982cf4SFabien Sanglard   } while (consumed && !buffer->empty());
281*3f982cf4SFabien Sanglard   return HandleStreamBufferResult{handled, total_consumed};
282*3f982cf4SFabien Sanglard }
283*3f982cf4SFabien Sanglard 
StopWatching(MessageDemuxer::MessageWatch * watch)284*3f982cf4SFabien Sanglard void StopWatching(MessageDemuxer::MessageWatch* watch) {
285*3f982cf4SFabien Sanglard   *watch = MessageDemuxer::MessageWatch();
286*3f982cf4SFabien Sanglard }
287*3f982cf4SFabien Sanglard 
288*3f982cf4SFabien Sanglard }  // namespace osp
289*3f982cf4SFabien Sanglard }  // namespace openscreen
290