xref: /aosp_15_r20/external/openscreen/osp/public/message_demuxer_unittest.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 #include "osp/public/message_demuxer.h"
6 
7 #include "gmock/gmock.h"
8 #include "gtest/gtest.h"
9 #include "osp/msgs/osp_messages.h"
10 #include "osp/public/testing/message_demuxer_test_support.h"
11 #include "platform/test/fake_clock.h"
12 #include "third_party/tinycbor/src/src/cbor.h"
13 
14 namespace openscreen {
15 namespace osp {
16 namespace {
17 
18 using ::testing::_;
19 using ::testing::Invoke;
20 
ConvertDecodeResult(ssize_t result)21 ErrorOr<size_t> ConvertDecodeResult(ssize_t result) {
22   if (result < 0) {
23     if (result == -CborErrorUnexpectedEOF)
24       return Error::Code::kCborIncompleteMessage;
25     else
26       return Error::Code::kCborParsing;
27   } else {
28     return result;
29   }
30 }
31 
32 class MessageDemuxerTest : public ::testing::Test {
33  protected:
SetUp()34   void SetUp() override {
35     ASSERT_TRUE(
36         msgs::EncodePresentationConnectionOpenRequest(request_, &buffer_));
37   }
38 
ExpectDecodedRequest(ssize_t decode_result,const msgs::PresentationConnectionOpenRequest & received_request)39   void ExpectDecodedRequest(
40       ssize_t decode_result,
41       const msgs::PresentationConnectionOpenRequest& received_request) {
42     ASSERT_GT(decode_result, 0);
43     EXPECT_EQ(decode_result, static_cast<ssize_t>(buffer_.size() - 2));
44     EXPECT_EQ(request_.request_id, received_request.request_id);
45     EXPECT_EQ(request_.presentation_id, received_request.presentation_id);
46     EXPECT_EQ(request_.url, received_request.url);
47   }
48 
49   const uint64_t endpoint_id_ = 13;
50   const uint64_t connection_id_ = 45;
51   FakeClock fake_clock_{Clock::time_point(std::chrono::milliseconds(1298424))};
52   msgs::CborEncodeBuffer buffer_;
53   msgs::PresentationConnectionOpenRequest request_{1, "fry-am-the-egg-man",
54                                                    "url"};
55   MockMessageCallback mock_callback_;
56   MessageDemuxer demuxer_{FakeClock::now, MessageDemuxer::kDefaultBufferLimit};
57 };
58 
59 }  // namespace
60 
// A per-endpoint watch must only see messages for its own endpoint, and must
// stop receiving anything once the watch handle is reset.
TEST_F(MessageDemuxerTest, WatchStartStop) {
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE(watch);

  // Data from a different endpoint must not reach the callback.
  EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)).Times(0);
  demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size());

  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  auto decode_into_request = [&decode_result, &received_request](
                                 uint64_t, uint64_t, msgs::Type,
                                 const uint8_t* buffer, size_t buffer_size,
                                 Clock::time_point) {
    decode_result = msgs::DecodePresentationConnectionOpenRequest(
        buffer, buffer_size, &received_request);
    return ConvertDecodeResult(decode_result);
  };

  // Data from the watched endpoint is delivered exactly once.
  EXPECT_CALL(mock_callback_,
              OnStreamMessage(endpoint_id_, connection_id_,
                              msgs::Type::kPresentationConnectionOpenRequest,
                              _, _, _))
      .WillOnce(Invoke(decode_into_request));
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);

  // After the handle is reset, nothing is delivered any more.
  watch = MessageDemuxer::MessageWatch();
  EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)).Times(0);
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
}
93 
// A message split across two OnStreamData() calls: the callback sees the
// truncated prefix first (which must fail to decode) and then the complete
// message once the remaining bytes arrive.
TEST_F(MessageDemuxerTest, BufferPartialMessage) {
  // Number of trailing bytes withheld from the first OnStreamData() call.
  constexpr size_t kTailBytes = 3;

  // Uses the fixture's mock_callback_ and endpoint_id_ directly rather than
  // shadowing them with identical locals.
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE(watch);

  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL(
      mock_callback_,
      OnStreamMessage(endpoint_id_, connection_id_,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .Times(2)
      .WillRepeatedly(Invoke([&decode_result, &received_request](
                                 uint64_t endpoint_id, uint64_t connection_id,
                                 msgs::Type message_type, const uint8_t* buffer,
                                 size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result);
      }));

  // Everything except the tail: the decode attempt must report failure so
  // the demuxer keeps the bytes around.
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size() - kTailBytes);
  EXPECT_LT(decode_result, 0);

  // The tail completes the message.
  demuxer_.OnStreamData(endpoint_id_, connection_id_,
                        buffer_.data() + buffer_.size() - kTailBytes,
                        kTailBytes);
  ExpectDecodedRequest(decode_result, received_request);
}
124 
// A default watch (registered without an endpoint id) must receive a matching
// message from an endpoint that has no per-endpoint watch.
TEST_F(MessageDemuxerTest, DefaultWatch) {
  // Uses the fixture's mock_callback_ and endpoint_id_ directly rather than
  // shadowing them with identical locals.
  MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch(
      msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_);
  ASSERT_TRUE(watch);

  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL(
      mock_callback_,
      OnStreamMessage(endpoint_id_, connection_id_,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}
151 
// A per-endpoint watch must take precedence over a default watch for the same
// message type, while other endpoints still reach the default watch.
TEST_F(MessageDemuxerTest, DefaultWatchOverridden) {
  MockMessageCallback mock_callback_global;

  // Uses the fixture's mock_callback_ and endpoint_id_ directly rather than
  // shadowing them with identical locals.
  MessageDemuxer::MessageWatch default_watch =
      demuxer_.SetDefaultMessageTypeWatch(
          msgs::Type::kPresentationConnectionOpenRequest,
          &mock_callback_global);
  ASSERT_TRUE(default_watch);
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE(watch);

  // A different endpoint falls through to the default watch.
  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)).Times(0);
  EXPECT_CALL(
      mock_callback_global,
      OnStreamMessage(endpoint_id_ + 1, 14,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer_.OnStreamData(endpoint_id_ + 1, 14, buffer_.data(), buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);

  // The watched endpoint goes to the per-endpoint callback. Reset both
  // outputs so results left over from the first delivery cannot satisfy the
  // checks below.
  decode_result = 0;
  received_request = msgs::PresentationConnectionOpenRequest();
  EXPECT_CALL(
      mock_callback_,
      OnStreamMessage(endpoint_id_, connection_id_,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}
202 
// Data that arrives before any watch exists is expected to be held by the
// demuxer and delivered when a matching per-endpoint watch is registered.
TEST_F(MessageDemuxerTest, WatchAfterData) {
  // No watch is registered yet, so this message cannot be delivered now.
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());

  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  // The expectation must exist before WatchMessageType(), since registering
  // the watch is what should trigger delivery of the buffered message.
  EXPECT_CALL(
      mock_callback_,
      OnStreamMessage(endpoint_id_, connection_id_,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result);
      }));
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE(watch);

  ExpectDecodedRequest(decode_result, received_request);
}
227 
// Two different messages sent back-to-back on one stream, before any watch
// for them exists, are expected to both be delivered once watches for both
// message types are registered.
TEST_F(MessageDemuxerTest, WatchAfterMultipleData) {
  // First message: the fixture's PresentationConnectionOpenRequest.
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());

  // Second message on the same stream: a PresentationStartRequest.
  msgs::CborEncodeBuffer buffer;
  msgs::PresentationStartRequest request;
  request.request_id = 2;
  request.url = "https://example.com/recv";
  ASSERT_TRUE(msgs::EncodePresentationStartRequest(request, &buffer));
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer.data(),
                        buffer.size());

  // Declared before the watches below so it outlives them.
  MockMessageCallback mock_init_callback;
  msgs::PresentationConnectionOpenRequest received_request;
  msgs::PresentationStartRequest received_init_request;
  ssize_t decode_result1 = 0;
  ssize_t decode_result2 = 0;

  // Expectations must exist before the watches are registered, since
  // registration is what should deliver the buffered messages.
  EXPECT_CALL(
      mock_callback_,
      OnStreamMessage(endpoint_id_, connection_id_,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .WillOnce(Invoke([&decode_result1, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result1 = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result1);
      }));
  EXPECT_CALL(mock_init_callback,
              OnStreamMessage(endpoint_id_, connection_id_,
                              msgs::Type::kPresentationStartRequest, _, _, _))
      .WillOnce(Invoke([&decode_result2, &received_init_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result2 = msgs::DecodePresentationStartRequest(
            buffer, buffer_size, &received_init_request);
        return ConvertDecodeResult(decode_result2);
      }));

  MessageDemuxer::MessageWatch init_watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationStartRequest, &mock_init_callback);
  ASSERT_TRUE(init_watch);
  MessageDemuxer::MessageWatch watch = demuxer_.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE(watch);

  ExpectDecodedRequest(decode_result1, received_request);
  ASSERT_GT(decode_result2, 0);
  EXPECT_EQ(decode_result2, static_cast<ssize_t>(buffer.size() - 2));
  EXPECT_EQ(request.request_id, received_init_request.request_id);
  EXPECT_EQ(request.url, received_init_request.url);
}
281 
// Data that arrives before any watch exists is expected to be held by the
// demuxer and delivered when a matching default watch is registered.
TEST_F(MessageDemuxerTest, GlobalWatchAfterData) {
  // No watch is registered yet, so this message cannot be delivered now.
  demuxer_.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                        buffer_.size());

  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  // The expectation must exist before SetDefaultMessageTypeWatch(), since
  // registering the watch is what should trigger delivery.
  EXPECT_CALL(
      mock_callback_,
      OnStreamMessage(endpoint_id_, connection_id_,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result);
      }));
  MessageDemuxer::MessageWatch watch = demuxer_.SetDefaultMessageTypeWatch(
      msgs::Type::kPresentationConnectionOpenRequest, &mock_callback_);
  ASSERT_TRUE(watch);

  ExpectDecodedRequest(decode_result, received_request);
}
304 
// With a small buffer limit, a message larger than the limit that arrives
// before any watch is expected NOT to be delivered when the watch is
// registered; a message sent after registration is still delivered.
TEST_F(MessageDemuxerTest, BufferLimit) {
  constexpr size_t kBufferLimit = 10;
  // The test is only meaningful if the encoded message exceeds the limit.
  ASSERT_GT(buffer_.size(), kBufferLimit);
  MessageDemuxer demuxer(FakeClock::now, kBufferLimit);

  demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                       buffer_.size());
  // Registering the watch must not deliver the over-limit message.
  EXPECT_CALL(mock_callback_, OnStreamMessage(_, _, _, _, _, _)).Times(0);
  MessageDemuxer::MessageWatch watch = demuxer.WatchMessageType(
      endpoint_id_, msgs::Type::kPresentationConnectionOpenRequest,
      &mock_callback_);
  ASSERT_TRUE(watch);

  // This newer expectation takes precedence over the Times(0) one above.
  msgs::PresentationConnectionOpenRequest received_request;
  ssize_t decode_result = 0;
  EXPECT_CALL(
      mock_callback_,
      OnStreamMessage(endpoint_id_, connection_id_,
                      msgs::Type::kPresentationConnectionOpenRequest, _, _, _))
      .WillOnce(Invoke([&decode_result, &received_request](
                           uint64_t endpoint_id, uint64_t connection_id,
                           msgs::Type message_type, const uint8_t* buffer,
                           size_t buffer_size, Clock::time_point now) {
        decode_result = msgs::DecodePresentationConnectionOpenRequest(
            buffer, buffer_size, &received_request);
        return ConvertDecodeResult(decode_result);
      }));
  demuxer.OnStreamData(endpoint_id_, connection_id_, buffer_.data(),
                       buffer_.size());
  ExpectDecodedRequest(decode_result, received_request);
}
333 
// Exercises MessageTypeDecoder::DecodeType() on type keys that occupy one and
// two bytes, plus an input that must be rejected.
TEST_F(MessageDemuxerTest, DeserializeMessages) {
  const std::vector<uint8_t> agent_info_response_serialized{0x0B, 0xFF};
  const std::vector<uint8_t> connection_close_event_serialized{0x40, 0x71,
                                                               0x00};
  const std::vector<uint8_t> authentication_request_serialized{0x43, 0xE9,
                                                               0xFF, 0x00};

  size_t used_bytes = 0;

  // ASSERT (not EXPECT) before each .value(): reading the value of an error
  // result must not happen.
  auto agent_info_response_type = MessageTypeDecoder::DecodeType(
      agent_info_response_serialized, &used_bytes);
  ASSERT_FALSE(agent_info_response_type.is_error());
  EXPECT_EQ(used_bytes, size_t{1});
  EXPECT_EQ(agent_info_response_type.value(), msgs::Type::kAgentInfoResponse);

  auto connection_close_event_type = MessageTypeDecoder::DecodeType(
      connection_close_event_serialized, &used_bytes);
  ASSERT_FALSE(connection_close_event_type.is_error());
  EXPECT_EQ(used_bytes, size_t{2});
  EXPECT_EQ(connection_close_event_type.value(),
            msgs::Type::kPresentationConnectionCloseEvent);

  auto authentication_request_type = MessageTypeDecoder::DecodeType(
      authentication_request_serialized, &used_bytes);
  ASSERT_FALSE(authentication_request_type.is_error());
  EXPECT_EQ(used_bytes, size_t{2});
  EXPECT_EQ(authentication_request_type.value(),
            msgs::Type::kAuthenticationRequest);

  auto unknown_type = MessageTypeDecoder::DecodeType({0xFF}, &used_bytes);
  EXPECT_TRUE(unknown_type.is_error());
}
364 
365 }  // namespace osp
366 }  // namespace openscreen
367