1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "cast/common/channel/virtual_connection_router.h"
6
7 #include <utility>
8
9 #include "cast/common/channel/connection_namespace_handler.h"
10 #include "cast/common/channel/message_util.h"
11 #include "cast/common/channel/proto/cast_channel.pb.h"
12 #include "cast/common/channel/testing/fake_cast_socket.h"
13 #include "cast/common/channel/testing/mock_cast_message_handler.h"
14 #include "cast/common/channel/testing/mock_socket_error_handler.h"
15 #include "cast/common/public/cast_socket.h"
16 #include "gtest/gtest.h"
17
18 namespace openscreen {
19 namespace cast {
20 namespace {
21
22 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0 ==
23 static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_0),
24 "V2 1.0 constants must be equal");
25 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_1 ==
26 static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_1),
27 "V2 1.1 constants must be equal");
28 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2 ==
29 static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_2),
30 "V2 1.2 constants must be equal");
31 static_assert(::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3 ==
32 static_cast<int>(VirtualConnection::ProtocolVersion::kV2_1_3),
33 "V2 1.3 constants must be equal");
34
35 using ::cast::channel::CastMessage;
36 using ::testing::_;
37 using ::testing::Invoke;
38 using ::testing::SaveArg;
39 using ::testing::WithArg;
40
41 class VirtualConnectionRouterTest : public ::testing::Test {
42 public:
SetUp()43 void SetUp() override {
44 local_socket_ = fake_cast_socket_pair_.socket.get();
45 local_router_.TakeSocket(&mock_error_handler_,
46 std::move(fake_cast_socket_pair_.socket));
47
48 remote_socket_ = fake_cast_socket_pair_.peer_socket.get();
49 remote_router_.TakeSocket(&mock_error_handler_,
50 std::move(fake_cast_socket_pair_.peer_socket));
51 }
52
53 protected:
54 FakeCastSocketPair fake_cast_socket_pair_;
55 CastSocket* local_socket_;
56 CastSocket* remote_socket_;
57
58 MockSocketErrorHandler mock_error_handler_;
59
60 VirtualConnectionRouter local_router_;
61 VirtualConnectionRouter remote_router_;
62
63 VirtualConnection vc1_{"local1", "peer1", 75};
64 VirtualConnection vc2_{"local2", "peer2", 76};
65 VirtualConnection vc3_{"local1", "peer3", 75};
66 };
67
68 } // namespace
69
// A freshly constructed router knows about none of the fixture's connections.
TEST_F(VirtualConnectionRouterTest, NoConnections) {
  for (const VirtualConnection* vc : {&vc1_, &vc2_, &vc3_}) {
    EXPECT_FALSE(local_router_.GetConnectionData(*vc));
  }
}
75
// Adding connections one at a time makes exactly the added ones visible.
TEST_F(VirtualConnectionRouterTest, AddConnections) {
  // Checks which of vc1_/vc2_/vc3_ the router currently reports.
  const auto expect_present = [this](const char* stage, bool has_vc1,
                                     bool has_vc2, bool has_vc3) {
    SCOPED_TRACE(stage);
    EXPECT_EQ(has_vc1,
              static_cast<bool>(local_router_.GetConnectionData(vc1_)));
    EXPECT_EQ(has_vc2,
              static_cast<bool>(local_router_.GetConnectionData(vc2_)));
    EXPECT_EQ(has_vc3,
              static_cast<bool>(local_router_.GetConnectionData(vc3_)));
  };

  local_router_.AddConnection(vc1_, VirtualConnection::AssociatedData{});
  expect_present("after adding vc1", true, false, false);

  local_router_.AddConnection(vc2_, VirtualConnection::AssociatedData{});
  expect_present("after adding vc2", true, true, false);

  local_router_.AddConnection(vc3_, VirtualConnection::AssociatedData{});
  expect_present("after adding vc3", true, true, true);
}
96
// Removing connections one at a time hides exactly the removed ones, and a
// second removal of the same connection reports failure.
TEST_F(VirtualConnectionRouterTest, RemoveConnections) {
  // Checks which of vc1_/vc2_/vc3_ the router currently reports.
  const auto expect_present = [this](const char* stage, bool has_vc1,
                                     bool has_vc2, bool has_vc3) {
    SCOPED_TRACE(stage);
    EXPECT_EQ(has_vc1,
              static_cast<bool>(local_router_.GetConnectionData(vc1_)));
    EXPECT_EQ(has_vc2,
              static_cast<bool>(local_router_.GetConnectionData(vc2_)));
    EXPECT_EQ(has_vc3,
              static_cast<bool>(local_router_.GetConnectionData(vc3_)));
  };
  // Removes |vc| as if closed locally and forwards the router's result.
  const auto remove = [this](const VirtualConnection& vc) {
    return local_router_.RemoveConnection(
        vc, VirtualConnection::CloseReason::kClosedBySelf);
  };

  local_router_.AddConnection(vc1_, {});
  local_router_.AddConnection(vc2_, {});
  local_router_.AddConnection(vc3_, {});

  EXPECT_TRUE(remove(vc1_));
  expect_present("after removing vc1", false, true, true);

  EXPECT_TRUE(remove(vc2_));
  expect_present("after removing vc2", false, false, true);

  EXPECT_TRUE(remove(vc3_));
  expect_present("after removing vc3", false, false, false);

  // Nothing is left to remove.
  EXPECT_FALSE(remove(vc1_));
  EXPECT_FALSE(remove(vc2_));
  EXPECT_FALSE(remove(vc3_));
}
131
// Bulk removal by local id and by socket id removes exactly the matching
// connections.
TEST_F(VirtualConnectionRouterTest, RemoveConnectionsByIds) {
  // Registers all three fixture connections with empty associated data.
  const auto add_all = [this] {
    local_router_.AddConnection(vc1_, {});
    local_router_.AddConnection(vc2_, {});
    local_router_.AddConnection(vc3_, {});
  };
  // Checks which of vc1_/vc2_/vc3_ the router currently reports.
  const auto expect_present = [this](const char* stage, bool has_vc1,
                                     bool has_vc2, bool has_vc3) {
    SCOPED_TRACE(stage);
    EXPECT_EQ(has_vc1,
              static_cast<bool>(local_router_.GetConnectionData(vc1_)));
    EXPECT_EQ(has_vc2,
              static_cast<bool>(local_router_.GetConnectionData(vc2_)));
    EXPECT_EQ(has_vc3,
              static_cast<bool>(local_router_.GetConnectionData(vc3_)));
  };

  // vc1_ and vc3_ use local id "local1"; vc2_ does not.
  add_all();
  local_router_.RemoveConnectionsByLocalId("local1");
  expect_present("after removing local id local1", false, true, false);

  // vc2_ is the only connection on socket 76.
  add_all();
  local_router_.RemoveConnectionsBySocketId(76);
  expect_present("after removing socket 76", true, false, true);

  // vc1_ and vc3_ both sit on socket 75.
  local_router_.RemoveConnectionsBySocketId(75);
  expect_present("after removing socket 75", false, false, false);
}
162
// Messages arriving over the socket for a registered local id are dispatched
// to that id's handler; messages for other ids are not.
TEST_F(VirtualConnectionRouterTest, LocalIdHandler) {
  constexpr char kReceiverId[] = "receiver-1234";
  constexpr char kSenderId[] = "sender-9873";

  MockCastMessageHandler handler;
  local_router_.AddHandlerForLocalId(kReceiverId, &handler);
  local_router_.AddConnection(
      VirtualConnection{kReceiverId, kSenderId, local_socket_->socket_id()},
      {});

  CastMessage msg;
  msg.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  msg.set_namespace_("zrqvn");
  msg.set_source_id(kSenderId);
  msg.set_destination_id(kReceiverId);
  msg.set_payload_type(CastMessage::STRING);
  msg.set_payload_utf8("cnlybnq");

  // Each of two sends reaches the handler through |local_socket_|.
  for (int round = 0; round < 2; ++round) {
    EXPECT_CALL(handler, OnMessage(_, local_socket_, _));
    EXPECT_TRUE(remote_socket_->Send(msg).ok());
  }

  // A message for an id without a handler must not reach |handler|.
  msg.set_destination_id("receiver-4321");
  EXPECT_CALL(handler, OnMessage(_, _, _)).Times(0);
  EXPECT_TRUE(remote_socket_->Send(msg).ok());

  local_router_.RemoveHandlerForLocalId(kReceiverId);
}
190
// Once a local id's handler is removed, messages for that id no longer reach
// it.
TEST_F(VirtualConnectionRouterTest, RemoveLocalIdHandler) {
  constexpr char kReceiverId[] = "receiver-1234";
  constexpr char kSenderId[] = "sender-9873";

  MockCastMessageHandler handler;
  local_router_.AddHandlerForLocalId(kReceiverId, &handler);
  local_router_.AddConnection(
      VirtualConnection{kReceiverId, kSenderId, local_socket_->socket_id()},
      {});

  CastMessage msg;
  msg.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  msg.set_namespace_("zrqvn");
  msg.set_source_id(kSenderId);
  msg.set_destination_id(kReceiverId);
  msg.set_payload_type(CastMessage::STRING);
  msg.set_payload_utf8("cnlybnq");

  // While registered, the handler receives the message.
  EXPECT_CALL(handler, OnMessage(_, local_socket_, _));
  EXPECT_TRUE(remote_socket_->Send(msg).ok());

  // After unregistering, the same message is no longer delivered to it.
  local_router_.RemoveHandlerForLocalId(kReceiverId);
  EXPECT_CALL(handler, OnMessage(_, local_socket_, _)).Times(0);
  EXPECT_TRUE(remote_socket_->Send(msg).ok());

  // Removing the already-removed id again must be harmless.
  local_router_.RemoveHandlerForLocalId(kReceiverId);
}
216
// A message sent over a virtual connection on |local_router_| arrives,
// unchanged, at the handler registered for the peer id on |remote_router_|.
TEST_F(VirtualConnectionRouterTest, SendMessage) {
  local_router_.AddConnection(VirtualConnection{"receiver-1234", "sender-4321",
                                                local_socket_->socket_id()},
                              {});

  MockCastMessageHandler destination;
  remote_router_.AddHandlerForLocalId("sender-4321", &destination);
  remote_router_.AddConnection(VirtualConnection{"sender-4321", "receiver-1234",
                                                 remote_socket_->socket_id()},
                               {});

  CastMessage message;
  message.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  message.set_namespace_("zrqvn");
  message.set_source_id("receiver-1234");
  message.set_destination_id("sender-4321");
  message.set_payload_type(CastMessage::STRING);
  message.set_payload_utf8("cnlybnq");
  ASSERT_TRUE(message.IsInitialized());

  EXPECT_CALL(destination, OnMessage(&remote_router_, remote_socket_, _))
      .WillOnce(WithArg<2>(
          Invoke([&message](const CastMessage& message_at_destination) {
            ASSERT_TRUE(message_at_destination.IsInitialized());
            EXPECT_EQ(message.SerializeAsString(),
                      message_at_destination.SerializeAsString());
          })));

  // Check the Error returned by Send() so a routing failure is reported
  // directly, not only as an unsatisfied OnMessage() expectation.
  EXPECT_TRUE(local_router_
                  .Send(VirtualConnection{"receiver-1234", "sender-4321",
                                          local_socket_->socket_id()},
                        message)
                  .ok());

  // |destination| is local to this test; unregister it so the fixture-owned
  // |remote_router_| does not keep a dangling handler pointer.
  remote_router_.RemoveHandlerForLocalId("sender-4321");
}
249
// Closing a socket reports the close to the error handler and drops every
// virtual connection that lived on that socket.
TEST_F(VirtualConnectionRouterTest, CloseSocketRemovesVirtualConnections) {
  // Read the id up front; the socket object is not touched after the close.
  const int socket_id = local_socket_->socket_id();
  const VirtualConnection connection{"receiver-1234", "sender-4321",
                                     socket_id};
  local_router_.AddConnection(connection, {});

  EXPECT_CALL(mock_error_handler_, OnClose(local_socket_)).Times(1);
  local_router_.CloseSocket(socket_id);

  EXPECT_FALSE(local_router_.GetConnectionData(connection));
}
262
// Tests that VirtualConnectionRouter::BroadcastFromLocalPeer() sends a message
// from a local source to both: 1) all other local peers; and 2) all remote
// peers.
TEST_F(VirtualConnectionRouterTest, BroadcastsFromLocalSource) {
  // Peers registered on the same router as the broadcaster.
  MockCastMessageHandler alice;
  MockCastMessageHandler bob;
  local_router_.AddHandlerForLocalId("alice", &alice);
  local_router_.AddHandlerForLocalId("bob", &bob);

  // Peers registered on the router at the other end of the socket.
  MockCastMessageHandler charlie;
  MockCastMessageHandler dave;
  MockCastMessageHandler eve;
  remote_router_.AddHandlerForLocalId("charlie", &charlie);
  remote_router_.AddHandlerForLocalId("dave", &dave);
  remote_router_.AddHandlerForLocalId("eve", &eve);

  // The broadcaster itself must never be handed her own broadcast.
  MockCastMessageHandler wendy;
  local_router_.AddHandlerForLocalId("wendy", &wendy);
  EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0);

  CastMessage broadcast;
  broadcast.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  broadcast.set_namespace_("zrqvn");
  broadcast.set_payload_type(CastMessage::STRING);
  broadcast.set_payload_utf8("cnlybnq");

  // Local peers are matched with a null socket; remote peers receive the
  // message through |remote_socket_|. Each captures the copy it was given.
  CastMessage alice_got;
  CastMessage bob_got;
  CastMessage charlie_got;
  CastMessage dave_got;
  CastMessage eve_got;
  EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _))
      .WillOnce(SaveArg<2>(&alice_got))
      .RetiresOnSaturation();
  EXPECT_CALL(bob, OnMessage(&local_router_, nullptr, _))
      .WillOnce(SaveArg<2>(&bob_got))
      .RetiresOnSaturation();
  EXPECT_CALL(charlie, OnMessage(&remote_router_, remote_socket_, _))
      .WillOnce(SaveArg<2>(&charlie_got))
      .RetiresOnSaturation();
  EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _))
      .WillOnce(SaveArg<2>(&dave_got))
      .RetiresOnSaturation();
  EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _))
      .WillOnce(SaveArg<2>(&eve_got))
      .RetiresOnSaturation();
  ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", broadcast).ok());

  // Every recipient must have received the broadcast stamped with wendy as
  // the source and kBroadcastId as the destination.
  broadcast.set_source_id("wendy");
  broadcast.set_destination_id(kBroadcastId);
  ASSERT_TRUE(broadcast.IsInitialized());
  const auto expected = broadcast.SerializeAsString();
  for (const CastMessage* received :
       {&alice_got, &bob_got, &charlie_got, &dave_got, &eve_got}) {
    ASSERT_TRUE(received->IsInitialized());
    EXPECT_EQ(expected, received->SerializeAsString());
  }

  // With bob (local) and charlie (remote) unregistered, only alice, dave and
  // eve receive the next broadcast.
  local_router_.RemoveHandlerForLocalId("bob");
  remote_router_.RemoveHandlerForLocalId("charlie");
  EXPECT_CALL(alice, OnMessage(&local_router_, nullptr, _)).Times(1);
  EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0);
  EXPECT_CALL(charlie, OnMessage(_, _, _)).Times(0);
  EXPECT_CALL(dave, OnMessage(&remote_router_, remote_socket_, _)).Times(1);
  EXPECT_CALL(eve, OnMessage(&remote_router_, remote_socket_, _)).Times(1);
  ASSERT_TRUE(local_router_.BroadcastFromLocalPeer("wendy", broadcast).ok());
}
335
336 // Tests that VirtualConnectionRouter::OnMessage() broadcasts a message from a
337 // remote source to all local peers.
TEST_F(VirtualConnectionRouterTest, BroadcastsFromRemoteSource) {
  // Peers registered on the receiving (local) router.
  MockCastMessageHandler alice;
  MockCastMessageHandler bob;
  MockCastMessageHandler charlie;
  local_router_.AddHandlerForLocalId("alice", &alice);
  local_router_.AddHandlerForLocalId("bob", &bob);
  local_router_.AddHandlerForLocalId("charlie", &charlie);

  // The broadcaster is registered on the remote router and must never be
  // handed her own broadcast.
  MockCastMessageHandler wendy;
  remote_router_.AddHandlerForLocalId("wendy", &wendy);
  EXPECT_CALL(wendy, OnMessage(_, _, _)).Times(0);

  CastMessage broadcast;
  broadcast.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  broadcast.set_namespace_("zrqvn");
  broadcast.set_payload_type(CastMessage::STRING);
  broadcast.set_payload_utf8("cnlybnq");

  // Each local peer should get one copy, delivered through |local_socket_|.
  CastMessage alice_got;
  CastMessage bob_got;
  CastMessage charlie_got;
  EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _))
      .WillOnce(SaveArg<2>(&alice_got))
      .RetiresOnSaturation();
  EXPECT_CALL(bob, OnMessage(&local_router_, local_socket_, _))
      .WillOnce(SaveArg<2>(&bob_got))
      .RetiresOnSaturation();
  EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _))
      .WillOnce(SaveArg<2>(&charlie_got))
      .RetiresOnSaturation();
  ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", broadcast).ok());

  // The copies must match the broadcast stamped with wendy as the source and
  // kBroadcastId as the destination.
  broadcast.set_source_id("wendy");
  broadcast.set_destination_id(kBroadcastId);
  ASSERT_TRUE(broadcast.IsInitialized());
  const auto expected = broadcast.SerializeAsString();
  for (const CastMessage* received : {&alice_got, &bob_got, &charlie_got}) {
    ASSERT_TRUE(received->IsInitialized());
    EXPECT_EQ(expected, received->SerializeAsString());
  }

  // After unregistering bob, only alice and charlie receive the next
  // broadcast from the remote source.
  local_router_.RemoveHandlerForLocalId("bob");
  EXPECT_CALL(alice, OnMessage(&local_router_, local_socket_, _)).Times(1);
  EXPECT_CALL(bob, OnMessage(_, _, _)).Times(0);
  EXPECT_CALL(charlie, OnMessage(&local_router_, local_socket_, _)).Times(1);
  ASSERT_TRUE(remote_router_.BroadcastFromLocalPeer("wendy", broadcast).ok());
}
389
390 // Tests that the VirtualConnectionRouter treats kConnectionNamespace messages
391 // as a special case. The details of this are described in the implementation of
392 // VirtualConnectionRouter::OnMessage().
TEST_F(VirtualConnectionRouterTest, HandlesConnectionMessagesAsSpecialCase) {
  // A ConnectionNamespaceHandler whose OnMessage() is mocked and which acts as
  // its own VirtualConnectionPolicy, allowing every virtual connection.
  class MockConnectionNamespaceHandler final
      : public ConnectionNamespaceHandler,
        public ConnectionNamespaceHandler::VirtualConnectionPolicy {
   public:
    explicit MockConnectionNamespaceHandler(VirtualConnectionRouter* vc_router)
        : ConnectionNamespaceHandler(vc_router, this) {}
    ~MockConnectionNamespaceHandler() final = default;
    MOCK_METHOD(void,
                OnMessage,
                (VirtualConnectionRouter * router,
                 CastSocket* socket,
                 ::cast::channel::CastMessage message),
                (final));
    bool IsConnectionAllowed(
        const VirtualConnection& virtual_conn) const final {
      return true;
    }
  };
  MockConnectionNamespaceHandler connection_handler(&local_router_);

  // "alice" is the message's destination id, but must not be handed a
  // connection-namespace message.
  MockCastMessageHandler alice;
  local_router_.AddHandlerForLocalId("alice", &alice);

  CastMessage message;
  message.set_protocol_version(
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0);
  message.set_source_id(kPlatformSenderId);
  message.set_destination_id("alice");
  message.set_namespace_(kConnectionNamespace);

  CastMessage message_received;
  EXPECT_CALL(connection_handler, OnMessage(&local_router_, local_socket_, _))
      .WillOnce(SaveArg<2>(&message_received))
      .RetiresOnSaturation();
  EXPECT_CALL(alice, OnMessage(_, _, _)).Times(0);
  local_router_.OnMessage(local_socket_, message);

  // Inspect the copy the connection handler actually received. Checking the
  // local |message| instead would only re-read values this test set itself.
  EXPECT_EQ(kPlatformSenderId, message_received.source_id());
  EXPECT_EQ("alice", message_received.destination_id());
  EXPECT_EQ(kConnectionNamespace, message_received.namespace_());

  // |alice| is local to this test; do not leave it registered on the
  // fixture-owned router.
  local_router_.RemoveHandlerForLocalId("alice");
}
435
436 } // namespace cast
437 } // namespace openscreen
438