xref: /aosp_15_r20/external/openscreen/cast/common/channel/virtual_connection_router_unittest.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "cast/common/channel/virtual_connection_router.h"
6 
7 #include <utility>
8 
9 #include "cast/common/channel/connection_namespace_handler.h"
10 #include "cast/common/channel/message_util.h"
11 #include "cast/common/channel/proto/cast_channel.pb.h"
12 #include "cast/common/channel/testing/fake_cast_socket.h"
13 #include "cast/common/channel/testing/mock_cast_message_handler.h"
14 #include "cast/common/channel/testing/mock_socket_error_handler.h"
15 #include "cast/common/public/cast_socket.h"
16 #include "gtest/gtest.h"
17 
18 namespace openscreen {
19 namespace cast {
20 namespace {
21 
22 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 ==
23                   static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0),
24               "V2 1.0 constants must be equal");
25 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 ==
26                   static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1),
27               "V2 1.1 constants must be equal");
28 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 ==
29                   static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2),
30               "V2 1.2 constants must be equal");
31 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 ==
32                   static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3),
33               "V2 1.3 constants must be equal");
34 
35 using ::cast::channel::CastMessage;
36 using ::testing::_;
37 using ::testing::Invoke;
38 using ::testing::SaveArg;
39 using ::testing::WithArg;
40 
41 class VirtualConnectionRouterTest : public ::testing::Test {
42  public:
SetUp()43   void SetUp() override {
44     local_socket_ = fake_cast_socket_pair_.socket.get();
45     local_router_.TakeSocket(&mock_error_handler_,
46                              std::move(fake_cast_socket_pair_.socket));
47 
48     remote_socket_ = fake_cast_socket_pair_.peer_socket.get();
49     remote_router_.TakeSocket(&mock_error_handler_,
50                               std::move(fake_cast_socket_pair_.peer_socket));
51   }
52 
53  protected:
54   FakeCastSocketPair fake_cast_socket_pair_;
55   CastSocket* local_socket_;
56   CastSocket* remote_socket_;
57 
58   MockSocketErrorHandler mock_error_handler_;
59 
60   VirtualConnectionRouter local_router_;
61   VirtualConnectionRouter remote_router_;
62 
63   VirtualConnection vc1_{"local1", "peer1", 75};
64   VirtualConnection vc2_{"local2", "peer2", 76};
65   VirtualConnection vc3_{"local1", "peer3", 75};
66 };
67 
68 }  // namespace
69 
// Before anything is added, GetConnectionData() reports nothing for any of the
// fixture's connections.
TEST_F(VirtualConnectionRouterTest, NoConnections) {
  const VirtualConnection* const connections[] = {&vc1_, &vc2_, &vc3_};
  for (const VirtualConnection* connection : connections) {
    EXPECT_FALSE(local_router_.GetConnectionData(*connection));
  }
}
75 
// Connections become visible through GetConnectionData() one at a time as they
// are added; connections not yet added stay invisible.
TEST_F(VirtualConnectionRouterTest, AddConnections) {
  local_router_.AddConnection(vc1_, {});
  EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc3_));

  local_router_.AddConnection(vc2_, {});
  EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc3_));

  local_router_.AddConnection(vc3_, {});
  EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc3_));
}
96 
// RemoveConnection() hides exactly the connection it is given and returns true.
// Removing a connection that is no longer present returns false.
TEST_F(VirtualConnectionRouterTest, RemoveConnections) {
  constexpr auto kReason = VirtualConnection::CloseReason::kClosedBySelf;

  local_router_.AddConnection(vc1_, {});
  local_router_.AddConnection(vc2_, {});
  local_router_.AddConnection(vc3_, {});

  EXPECT_TRUE(local_router_.RemoveConnection(vc1_, kReason));
  EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc3_));

  EXPECT_TRUE(local_router_.RemoveConnection(vc2_, kReason));
  EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc3_));

  EXPECT_TRUE(local_router_.RemoveConnection(vc3_, kReason));
  EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc3_));

  // Everything is gone now, so each second removal finds nothing.
  EXPECT_FALSE(local_router_.RemoveConnection(vc1_, kReason));
  EXPECT_FALSE(local_router_.RemoveConnection(vc2_, kReason));
  EXPECT_FALSE(local_router_.RemoveConnection(vc3_, kReason));
}
131 
// Bulk removal by local id or by socket id drops every connection matching
// that id and leaves the others in place.
TEST_F(VirtualConnectionRouterTest, RemoveConnectionsByIds) {
  const auto add_all = [this] {
    local_router_.AddConnection(vc1_, {});
    local_router_.AddConnection(vc2_, {});
    local_router_.AddConnection(vc3_, {});
  };

  // vc1_ and vc3_ both use local id "local1".
  add_all();
  local_router_.RemoveConnectionsByLocalId("local1");
  EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc2_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc3_));

  // Only vc2_ is on socket 76; vc1_ and vc3_ are on socket 75.
  add_all();
  local_router_.RemoveConnectionsBySocketId(76);
  EXPECT_TRUE(local_router_.GetConnectionData(vc1_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
  EXPECT_TRUE(local_router_.GetConnectionData(vc3_));

  local_router_.RemoveConnectionsBySocketId(75);
  EXPECT_FALSE(local_router_.GetConnectionData(vc1_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc2_));
  EXPECT_FALSE(local_router_.GetConnectionData(vc3_));
}
162 
// Messages arriving on the local socket are delivered to the handler registered
// for their destination id, and are not delivered to it when addressed to some
// other id.
TEST_F(VirtualConnectionRouterTest, LocalIdHandler) {
  MockCastMessageHandler handler;
  local_router_.AddHandlerForLocalId("receiver-1234", &handler);
  const VirtualConnection connection{"receiver-1234", "sender-9873",
                                     local_socket_->socket_id()};
  local_router_.AddConnection(connection, {});

  CastMessage message;
  message.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  message.set_namespace_("zrqvn");
  message.set_source_id("sender-9873");
  message.set_destination_id("receiver-1234");
  message.set_payload_type(CastMessage::STRING);
  message.set_payload_utf8("cnlybnq");

  // Each of two consecutive sends to the registered id reaches the handler.
  for (int attempt = 0; attempt < 2; ++attempt) {
    EXPECT_CALL(handler, OnMessage(_, local_socket_, _));
    EXPECT_TRUE(remote_socket_->Send(message).ok());
  }

  // A message for a different local id must not reach this handler.
  message.set_destination_id("receiver-4321");
  EXPECT_CALL(handler, OnMessage(_, _, _)).Times(0);
  EXPECT_TRUE(remote_socket_->Send(message).ok());

  local_router_.RemoveHandlerForLocalId("receiver-1234");
}
190 
// Once the handler for a local id has been removed, messages addressed to that
// id are no longer delivered to it. The test also removes the same id a second
// time.
TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) {
  constexpr char kReceiverId[] = "receiver-1234";

  MockCastMessageHandler handler;
  local_router_.AddHandlerForLocalId(kReceiverId, &handler);
  local_router_.AddConnection(
      VirtualConnection{kReceiverId, "sender-9873", local_socket_->socket_id()},
      {});

  CastMessage message;
  message.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  message.set_namespace_("zrqvn");
  message.set_source_id("sender-9873");
  message.set_destination_id(kReceiverId);
  message.set_payload_type(CastMessage::STRING);
  message.set_payload_utf8("cnlybnq");

  // While registered, the handler receives the message.
  EXPECT_CALL(handler, OnMessage(_, local_socket_, _));
  EXPECT_TRUE(remote_socket_->Send(message).ok());

  // After removal, the same message is not delivered to it.
  local_router_.RemoveHandlerForLocalId(kReceiverId);
  EXPECT_CALL(handler, OnMessage(_, local_socket_, _)).Times(0);
  EXPECT_TRUE(remote_socket_->Send(message).ok());

  // Remove the already-removed id once more.
  local_router_.RemoveHandlerForLocalId(kReceiverId);
}
216 
// A message sent through the local router on a known virtual connection
// arrives at the remote handler registered for its destination id. It arrives
// with the same serialized bytes as the message that was sent.
TEST_F(VirtualConnectionRouterTest, SendMessage) {
  const VirtualConnection local_to_remote{"receiver-1234", "sender-4321",
                                          local_socket_->socket_id()};
  local_router_.AddConnection(local_to_remote, {});

  MockCastMessageHandler destination;
  remote_router_.AddHandlerForLocalId("sender-4321", &destination);
  remote_router_.AddConnection(
      VirtualConnection{"sender-4321", "receiver-1234",
                        remote_socket_->socket_id()},
      {});

  CastMessage message;
  message.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  message.set_namespace_("zrqvn");
  message.set_source_id("receiver-1234");
  message.set_destination_id("sender-4321");
  message.set_payload_type(CastMessage::STRING);
  message.set_payload_utf8("cnlybnq");
  ASSERT_TRUE(message.IsInitialized());

  // Runs on the remote side with the message handed to |destination|.
  const auto check_delivered = [&message](const CastMessage& delivered) {
    ASSERT_TRUE(delivered.IsInitialized());
    EXPECT_EQ(message.SerializeAsString(), delivered.SerializeAsString());
  };
  EXPECT_CALL(destination, OnMessage(&remote_router_, remote_socket_, _))
      .WillOnce(WithArg<2>(Invoke(check_delivered)));
  local_router_.Send(local_to_remote, message);
}
249 
// Closing a socket through the router notifies the error handler once. After
// the close, a virtual connection on that socket is no longer found.
TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) {
  const int socket_id = local_socket_->socket_id();
  const VirtualConnection connection{"receiver-1234", "sender-4321",
                                     socket_id};
  local_router_.AddConnection(connection, {});

  EXPECT_CALL(mock_error_handler_, OnClose(local_socket_)).Times(1);

  local_router_.CloseSocket(socket_id);
  EXPECT_FALSE(local_router_.GetConnectionData(connection));
}
262 
263 // Tests that VirtualConnectionRouter::Send() broadcasts a message from a local
264 // source to both: 1) all other local peers; and 2) all remote peers.
TEST_F(VirtualConnectionRouterTest,BroadcastsFromLocalSource)265 TEST_F(VirtualConnectionRouterTest, BroadcastsFromLocalSource) {
266   // Local peers.
267   MockCastMessageHandler alice, bob;
268   local_router_.AddHandlerForLocalId("alice", &alice);
269   local_router_.AddHandlerForLocalId("bob", &bob);
270 
271   // Remote peers.
272   MockCastMessageHandler charlie, dave, eve;
273   remote_router_.AddHandlerForLocalId("charlie", &charlie);
274   remote_router_.AddHandlerForLocalId("dave", &dave);
275   remote_router_.AddHandlerForLocalId("eve", &eve);
276 
277   // The local broadcaster, which should never receive her own messages.
278   MockCastMessageHandler wendy;
279   local_router_.AddHandlerForLocalId("wendy", &wendy);
280   EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0);
281 
282   CastMessage message;
283   message.set_protocol_version(
284       ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
285   message.set_namespace_("zrqvn");
286   message.set_payload_type(CastMessage::STRING);
287   message.set_payload_utf8("cnlybnq");
288 
289   CastMessage message_alice_got, message_bob_got, message_charlie_got,
290       message_dave_got, message_eve_got;
291   EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _))
292       .WillOnce(SaveArg<2>(&message_alice_got))
293       .RetiresOnSaturation();
294   EXPECT_CALL(bob, OnMessage(&local_router_, nullptr, _))
295       .WillOnce(SaveArg<2>(&message_bob_got))
296       .RetiresOnSaturation();
297   EXPECT_CALL(charlie, OnMessage(&remote_router_, remote_socket_, _))
298       .WillOnce(SaveArg<2>(&message_charlie_got))
299       .RetiresOnSaturation();
300   EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _))
301       .WillOnce(SaveArg<2>(&message_dave_got))
302       .RetiresOnSaturation();
303   EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _))
304       .WillOnce(SaveArg<2>(&message_eve_got))
305       .RetiresOnSaturation();
306   ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", message).ok());
307 
308   // Confirm message data is correct.
309   message.set_source_id("wendy");
310   message.set_destination_id(kBroadcastId);
311   ASSERT_TRUE(message.IsInitialized());
312   ASSERT_TRUE(message_alice_got.IsInitialized());
313   EXPECT_EQ(message.SerializeAsString(), message_alice_got.SerializeAsString());
314   ASSERT_TRUE(message_bob_got.IsInitialized());
315   EXPECT_EQ(message.SerializeAsString(), message_bob_got.SerializeAsString());
316   ASSERT_TRUE(message_charlie_got.IsInitialized());
317   EXPECT_EQ(message.SerializeAsString(),
318             message_charlie_got.SerializeAsString());
319   ASSERT_TRUE(message_dave_got.IsInitialized());
320   EXPECT_EQ(message.SerializeAsString(), message_dave_got.SerializeAsString());
321   ASSERT_TRUE(message_eve_got.IsInitialized());
322   EXPECT_EQ(message.SerializeAsString(), message_eve_got.SerializeAsString());
323 
324   // Remove one local peer and one remote peer, and confirm only the correct
325   // entities receive a broadcast message.
326   local_router_.RemoveHandlerForLocalId("bob");
327   remote_router_.RemoveHandlerForLocalId("charlie");
328   EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _)).Times(1);
329   EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0);
330   EXPECT_CALL(charlie, OnMessage(_, _, _)).Times(0);
331   EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _)).Times(1);
332   EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _)).Times(1);
333   ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", message).ok());
334 }
335 
336 // Tests that VirtualConnectionRouter::OnMessage() broadcasts a message from a
337 // remote source to all local peers.
TEST_F(VirtualConnectionRouterTest,BroadcastsFromRemoteSource)338 TEST_F(VirtualConnectionRouterTest, BroadcastsFromRemoteSource) {
339   // Local peers.
340   MockCastMessageHandler alice, bob, charlie;
341   local_router_.AddHandlerForLocalId("alice", &alice);
342   local_router_.AddHandlerForLocalId("bob", &bob);
343   local_router_.AddHandlerForLocalId("charlie", &charlie);
344 
345   // The remote broadcaster, which should never receive her own messages.
346   MockCastMessageHandler wendy;
347   remote_router_.AddHandlerForLocalId("wendy", &wendy);
348   EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0);
349 
350   CastMessage message;
351   message.set_protocol_version(
352       ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
353   message.set_namespace_("zrqvn");
354   message.set_payload_type(CastMessage::STRING);
355   message.set_payload_utf8("cnlybnq");
356 
357   CastMessage message_alice_got, message_bob_got, message_charlie_got;
358   EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _))
359       .WillOnce(SaveArg<2>(&message_alice_got))
360       .RetiresOnSaturation();
361   EXPECT_CALL(bob, OnMessage(&local_router_, local_socket_, _))
362       .WillOnce(SaveArg<2>(&message_bob_got))
363       .RetiresOnSaturation();
364   EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _))
365       .WillOnce(SaveArg<2>(&message_charlie_got))
366       .RetiresOnSaturation();
367   ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", message).ok());
368 
369   // Confirm message data is correct.
370   message.set_source_id("wendy");
371   message.set_destination_id(kBroadcastId);
372   ASSERT_TRUE(message.IsInitialized());
373   ASSERT_TRUE(message_alice_got.IsInitialized());
374   EXPECT_EQ(message.SerializeAsString(), message_alice_got.SerializeAsString());
375   ASSERT_TRUE(message_bob_got.IsInitialized());
376   EXPECT_EQ(message.SerializeAsString(), message_bob_got.SerializeAsString());
377   ASSERT_TRUE(message_charlie_got.IsInitialized());
378   EXPECT_EQ(message.SerializeAsString(),
379             message_charlie_got.SerializeAsString());
380 
381   // Remove one local peer, and confirm only the two remaining local peers
382   // receive a broadcast message from the remote source.
383   local_router_.RemoveHandlerForLocalId("bob");
384   EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _)).Times(1);
385   EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0);
386   EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _)).Times(1);
387   ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", message).ok());
388 }
389 
390 // Tests that the VirtualConnectionRouter treats kConnectionNamespace messages
391 // as a special case. The details of this are described in the implementation of
392 // VirtualConnectionRouter::OnMessage().
TEST_F(VirtualConnectionRouterTest,HandlesConnectionMessagesAsSpecialCase)393 TEST_F(VirtualConnectionRouterTest, HandlesConnectionMessagesAsSpecialCase) {
394   class MockConnectionNamespaceHandler final
395       : public ConnectionNamespaceHandler,
396         public ConnectionNamespaceHandler::VirtualConnectionPolicy {
397    public:
398     explicit MockConnectionNamespaceHandler(VirtualConnectionRouter* vc_router)
399         : ConnectionNamespaceHandler(vc_router, this) {}
400     ~MockConnectionNamespaceHandler() final = default;
401     MOCK_METHOD(void,
402                 OnMessage,
403                 (VirtualConnectionRouter * router,
404                  CastSocket* socket,
405                  ::cast::channel::CastMessage message),
406                 (final));
407     bool IsConnectionAllowed(
408         const VirtualConnection& virtual_conn) const final {
409       return true;
410     }
411   };
412   MockConnectionNamespaceHandler connection_handler(&local_router_);
413 
414   MockCastMessageHandler alice;
415   local_router_.AddHandlerForLocalId("alice", &alice);
416 
417   CastMessage message;
418   message.set_protocol_version(
419       ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
420   message.set_source_id(kPlatformSenderId);
421   message.set_destination_id("alice");
422   message.set_namespace_(kConnectionNamespace);
423 
424   CastMessage message_received;
425   EXPECT_CALL(connection_handler, OnMessage(&local_router_, local_socket_, _))
426       .WillOnce(SaveArg<2>(&message_received))
427       .RetiresOnSaturation();
428   EXPECT_CALL(alice, OnMessage(_, _, _)).Times(0);
429   local_router_.OnMessage(local_socket_, message);
430 
431   EXPECT_EQ(kPlatformSenderId, message.source_id());
432   EXPECT_EQ("alice", message.destination_id());
433   EXPECT_EQ(kConnectionNamespace, message.namespace_());
434 }
435 
436 }  // namespace cast
437 }  // namespace openscreen
438