1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "cast/common/channel/connection_namespace_handler.h"
6
7 #include <string>
8 #include <utility>
9 #include <vector>
10
11 #include "cast/common/channel/message_util.h"
12 #include "cast/common/channel/testing/fake_cast_socket.h"
13 #include "cast/common/channel/testing/mock_socket_error_handler.h"
14 #include "cast/common/channel/virtual_connection.h"
15 #include "cast/common/channel/virtual_connection_router.h"
16 #include "cast/common/public/cast_socket.h"
17 #include "gmock/gmock.h"
18 #include "gtest/gtest.h"
19 #include "util/json/json_serialization.h"
20 #include "util/json/json_value.h"
21 #include "util/osp_logging.h"
22
23 namespace openscreen {
24 namespace cast {
25 namespace {
26
27 using ::testing::_;
28 using ::testing::Invoke;
29 using ::testing::NiceMock;
30
31 using ::cast::channel::CastMessage;
32 using ::cast::channel::CastMessage_ProtocolVersion;
33
34 class MockVirtualConnectionPolicy
35 : public ConnectionNamespaceHandler::VirtualConnectionPolicy {
36 public:
37 ~MockVirtualConnectionPolicy() override = default;
38
39 MOCK_METHOD(bool,
40 IsConnectionAllowed,
41 (const VirtualConnection& virtual_conn),
42 (const, override));
43 };
44
MakeVersionedConnectMessage(const std::string & source_id,const std::string & destination_id,absl::optional<CastMessage_ProtocolVersion> version,std::vector<CastMessage_ProtocolVersion> version_list)45 CastMessage MakeVersionedConnectMessage(
46 const std::string& source_id,
47 const std::string& destination_id,
48 absl::optional<CastMessage_ProtocolVersion> version,
49 std::vector<CastMessage_ProtocolVersion> version_list) {
50 CastMessage connect_message = MakeConnectMessage(source_id, destination_id);
51 Json::Value message(Json::ValueType::objectValue);
52 message[kMessageKeyType] = kMessageTypeConnect;
53 if (version) {
54 message[kMessageKeyProtocolVersion] = version.value();
55 }
56 if (!version_list.empty()) {
57 Json::Value list(Json::ValueType::arrayValue);
58 for (CastMessage_ProtocolVersion v : version_list) {
59 list.append(v);
60 }
61 message[kMessageKeyProtocolVersionList] = std::move(list);
62 }
63 ErrorOr<std::string> result = json::Stringify(message);
64 OSP_DCHECK(result);
65 connect_message.set_payload_utf8(std::move(result.value()));
66 return connect_message;
67 }
68
VerifyConnectionMessage(const CastMessage & message,const std::string & source_id,const std::string & destination_id)69 void VerifyConnectionMessage(const CastMessage& message,
70 const std::string& source_id,
71 const std::string& destination_id) {
72 EXPECT_EQ(message.source_id(), source_id);
73 EXPECT_EQ(message.destination_id(), destination_id);
74 EXPECT_EQ(message.namespace_(), kConnectionNamespace);
75 ASSERT_EQ(message.payload_type(),
76 ::cast::channel::CastMessage_PayloadType_STRING);
77 }
78
ParseConnectionMessage(const CastMessage & message)79 Json::Value ParseConnectionMessage(const CastMessage& message) {
80 ErrorOr<Json::Value> result = json::Parse(message.payload_utf8());
81 OSP_CHECK(result) << message.payload_utf8();
82 return result.value();
83 }
84
85 } // namespace
86
87 class ConnectionNamespaceHandlerTest : public ::testing::Test {
88 public:
SetUp()89 void SetUp() override {
90 socket_ = fake_cast_socket_pair_.socket.get();
91 router_.TakeSocket(&mock_error_handler_,
92 std::move(fake_cast_socket_pair_.socket));
93
94 ON_CALL(vc_policy_, IsConnectionAllowed(_))
95 .WillByDefault(
96 Invoke([](const VirtualConnection& virtual_conn) { return true; }));
97 }
98
99 protected:
ExpectCloseMessage(MockCastSocketClient * mock_client,const std::string & source_id,const std::string & destination_id)100 void ExpectCloseMessage(MockCastSocketClient* mock_client,
101 const std::string& source_id,
102 const std::string& destination_id) {
103 EXPECT_CALL(*mock_client, OnMessage(_, _))
104 .WillOnce(Invoke([&source_id, &destination_id](CastSocket* socket,
105 CastMessage message) {
106 VerifyConnectionMessage(message, source_id, destination_id);
107 Json::Value value = ParseConnectionMessage(message);
108 absl::optional<absl::string_view> type = MaybeGetString(
109 value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
110 ASSERT_TRUE(type) << message.payload_utf8();
111 EXPECT_EQ(type.value(), kMessageTypeClose) << message.payload_utf8();
112 }));
113 }
114
ExpectConnectedMessage(MockCastSocketClient * mock_client,const std::string & source_id,const std::string & destination_id,absl::optional<CastMessage_ProtocolVersion> version=absl::nullopt)115 void ExpectConnectedMessage(
116 MockCastSocketClient* mock_client,
117 const std::string& source_id,
118 const std::string& destination_id,
119 absl::optional<CastMessage_ProtocolVersion> version = absl::nullopt) {
120 EXPECT_CALL(*mock_client, OnMessage(_, _))
121 .WillOnce(Invoke([&source_id, &destination_id, version](
122 CastSocket* socket, CastMessage message) {
123 VerifyConnectionMessage(message, source_id, destination_id);
124 Json::Value value = ParseConnectionMessage(message);
125 absl::optional<absl::string_view> type = MaybeGetString(
126 value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
127 ASSERT_TRUE(type) << message.payload_utf8();
128 EXPECT_EQ(type.value(), kMessageTypeConnected)
129 << message.payload_utf8();
130 if (version) {
131 absl::optional<int> message_version = MaybeGetInt(
132 value,
133 JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion));
134 ASSERT_TRUE(message_version) << message.payload_utf8();
135 EXPECT_EQ(message_version.value(), version.value());
136 }
137 }));
138 }
139
140 FakeCastSocketPair fake_cast_socket_pair_;
141 MockSocketErrorHandler mock_error_handler_;
142 CastSocket* socket_;
143
144 NiceMock<MockVirtualConnectionPolicy> vc_policy_;
145 VirtualConnectionRouter router_;
146 ConnectionNamespaceHandler connection_namespace_handler_{&router_,
147 &vc_policy_};
148
149 const std::string sender_id_{"sender-5678"};
150 const std::string receiver_id_{"receiver-3245"};
151 };
152
// A plain CONNECT (no protocol version fields) registers the virtual
// connection and must not produce any reply to the peer.
TEST_F(ConnectionNamespaceHandlerTest, Connect) {
  // Register the "no reply" expectation before the message is handled.
  // gMock only matches calls made after EXPECT_CALL, so an expectation set
  // afterwards would verify nothing.
  EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
      .Times(0);

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(
      VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
}
162
// When the policy rejects the connection attempt, the sender is expected to
// receive a CLOSE and the connection must not be registered.
TEST_F(ConnectionNamespaceHandlerTest, PolicyDeniesConnection) {
  EXPECT_CALL(vc_policy_, IsConnectionAllowed(_))
      .WillOnce(::testing::Return(false));
  ExpectCloseMessage(&fake_cast_socket_pair_.mock_peer_client, receiver_id_,
                     sender_id_);

  const VirtualConnection denied_conn{receiver_id_, sender_id_,
                                      socket_->socket_id()};
  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
  EXPECT_FALSE(router_.GetConnectionData(denied_conn));
}
174
// A CONNECT carrying a single protocol version is expected to be answered with
// a CONNECTED reply naming that same version. The connection is registered.
TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersion) {
  constexpr CastMessage_ProtocolVersion kRequestedVersion =
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2;
  ExpectConnectedMessage(&fake_cast_socket_pair_.mock_peer_client,
                         receiver_id_, sender_id_, kRequestedVersion);

  connection_namespace_handler_.OnMessage(
      &router_, socket_,
      MakeVersionedConnectMessage(sender_id_, receiver_id_, kRequestedVersion,
                                  {}));

  const VirtualConnection expected_conn{receiver_id_, sender_id_,
                                        socket_->socket_id()};
  EXPECT_TRUE(router_.GetConnectionData(expected_conn));
}
187
// The CONNECT offers CASTV2_1_2 as its single version plus the list
// {CASTV2_1_3, CASTV2_1_0}. The CONNECTED reply is expected to carry
// CASTV2_1_3, the newest version offered anywhere in the request.
TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersionList) {
  const std::vector<CastMessage_ProtocolVersion> offered_versions = {
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3,
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0};
  ExpectConnectedMessage(
      &fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_,
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3);

  connection_namespace_handler_.OnMessage(
      &router_, socket_,
      MakeVersionedConnectMessage(
          sender_id_, receiver_id_,
          ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2,
          offered_versions));

  const VirtualConnection expected_conn{receiver_id_, sender_id_,
                                        socket_->socket_id()};
  EXPECT_TRUE(router_.GetConnectionData(expected_conn));
}
202
// A CLOSE from the same sender removes a previously established connection.
TEST_F(ConnectionNamespaceHandlerTest, Close) {
  const VirtualConnection conn{receiver_id_, sender_id_, socket_->socket_id()};

  // Establish the connection first.
  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(conn));

  // Then tear it down.
  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeCloseMessage(sender_id_, receiver_id_));
  EXPECT_FALSE(router_.GetConnectionData(conn));
}
214
// A CLOSE naming a sender that never connected must leave the existing
// connection untouched.
TEST_F(ConnectionNamespaceHandlerTest, CloseUnknown) {
  const VirtualConnection conn{receiver_id_, sender_id_, socket_->socket_id()};
  const std::string unknown_sender_id = sender_id_ + "098";

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(conn));

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeCloseMessage(unknown_sender_id, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(conn));
}
226
227 } // namespace cast
228 } // namespace openscreen
229