xref: /aosp_15_r20/external/openscreen/cast/common/channel/connection_namespace_handler_unittest.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "cast/common/channel/connection_namespace_handler.h"
6 
7 #include <string>
8 #include <utility>
9 #include <vector>
10 
11 #include "cast/common/channel/message_util.h"
12 #include "cast/common/channel/testing/fake_cast_socket.h"
13 #include "cast/common/channel/testing/mock_socket_error_handler.h"
14 #include "cast/common/channel/virtual_connection.h"
15 #include "cast/common/channel/virtual_connection_router.h"
16 #include "cast/common/public/cast_socket.h"
17 #include "gmock/gmock.h"
18 #include "gtest/gtest.h"
19 #include "util/json/json_serialization.h"
20 #include "util/json/json_value.h"
21 #include "util/osp_logging.h"
22 
23 namespace openscreen {
24 namespace cast {
25 namespace {
26 
27 using ::testing::_;
28 using ::testing::Invoke;
29 using ::testing::NiceMock;
30 
31 using ::cast::channel::CastMessage;
32 using ::cast::channel::CastMessage_ProtocolVersion;
33 
34 class MockVirtualConnectionPolicy
35     : public ConnectionNamespaceHandler::VirtualConnectionPolicy {
36  public:
37   ~MockVirtualConnectionPolicy() override = default;
38 
39   MOCK_METHOD(bool,
40               IsConnectionAllowed,
41               (const VirtualConnection& virtual_conn),
42               (const, override));
43 };
44 
MakeVersionedConnectMessage(const std::string & source_id,const std::string & destination_id,absl::optional<CastMessage_ProtocolVersion> version,std::vector<CastMessage_ProtocolVersion> version_list)45 CastMessage MakeVersionedConnectMessage(
46     const std::string& source_id,
47     const std::string& destination_id,
48     absl::optional<CastMessage_ProtocolVersion> version,
49     std::vector<CastMessage_ProtocolVersion> version_list) {
50   CastMessage connect_message = MakeConnectMessage(source_id, destination_id);
51   Json::Value message(Json::ValueType::objectValue);
52   message[kMessageKeyType] = kMessageTypeConnect;
53   if (version) {
54     message[kMessageKeyProtocolVersion] = version.value();
55   }
56   if (!version_list.empty()) {
57     Json::Value list(Json::ValueType::arrayValue);
58     for (CastMessage_ProtocolVersion v : version_list) {
59       list.append(v);
60     }
61     message[kMessageKeyProtocolVersionList] = std::move(list);
62   }
63   ErrorOr<std::string> result = json::Stringify(message);
64   OSP_DCHECK(result);
65   connect_message.set_payload_utf8(std::move(result.value()));
66   return connect_message;
67 }
68 
VerifyConnectionMessage(const CastMessage & message,const std::string & source_id,const std::string & destination_id)69 void VerifyConnectionMessage(const CastMessage& message,
70                              const std::string& source_id,
71                              const std::string& destination_id) {
72   EXPECT_EQ(message.source_id(), source_id);
73   EXPECT_EQ(message.destination_id(), destination_id);
74   EXPECT_EQ(message.namespace_(), kConnectionNamespace);
75   ASSERT_EQ(message.payload_type(),
76             ::cast::channel::CastMessage_PayloadType_STRING);
77 }
78 
ParseConnectionMessage(const CastMessage & message)79 Json::Value ParseConnectionMessage(const CastMessage& message) {
80   ErrorOr<Json::Value> result = json::Parse(message.payload_utf8());
81   OSP_CHECK(result) << message.payload_utf8();
82   return result.value();
83 }
84 
85 }  // namespace
86 
87 class ConnectionNamespaceHandlerTest : public ::testing::Test {
88  public:
SetUp()89   void SetUp() override {
90     socket_ = fake_cast_socket_pair_.socket.get();
91     router_.TakeSocket(&mock_error_handler_,
92                        std::move(fake_cast_socket_pair_.socket));
93 
94     ON_CALL(vc_policy_, IsConnectionAllowed(_))
95         .WillByDefault(
96             Invoke([](const VirtualConnection& virtual_conn) { return true; }));
97   }
98 
99  protected:
ExpectCloseMessage(MockCastSocketClient * mock_client,const std::string & source_id,const std::string & destination_id)100   void ExpectCloseMessage(MockCastSocketClient* mock_client,
101                           const std::string& source_id,
102                           const std::string& destination_id) {
103     EXPECT_CALL(*mock_client, OnMessage(_, _))
104         .WillOnce(Invoke([&source_id, &destination_id](CastSocket* socket,
105                                                        CastMessage message) {
106           VerifyConnectionMessage(message, source_id, destination_id);
107           Json::Value value = ParseConnectionMessage(message);
108           absl::optional<absl::string_view> type = MaybeGetString(
109               value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
110           ASSERT_TRUE(type) << message.payload_utf8();
111           EXPECT_EQ(type.value(), kMessageTypeClose) << message.payload_utf8();
112         }));
113   }
114 
ExpectConnectedMessage(MockCastSocketClient * mock_client,const std::string & source_id,const std::string & destination_id,absl::optional<CastMessage_ProtocolVersion> version=absl::nullopt)115   void ExpectConnectedMessage(
116       MockCastSocketClient* mock_client,
117       const std::string& source_id,
118       const std::string& destination_id,
119       absl::optional<CastMessage_ProtocolVersion> version = absl::nullopt) {
120     EXPECT_CALL(*mock_client, OnMessage(_, _))
121         .WillOnce(Invoke([&source_id, &destination_id, version](
122                              CastSocket* socket, CastMessage message) {
123           VerifyConnectionMessage(message, source_id, destination_id);
124           Json::Value value = ParseConnectionMessage(message);
125           absl::optional<absl::string_view> type = MaybeGetString(
126               value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType));
127           ASSERT_TRUE(type) << message.payload_utf8();
128           EXPECT_EQ(type.value(), kMessageTypeConnected)
129               << message.payload_utf8();
130           if (version) {
131             absl::optional<int> message_version = MaybeGetInt(
132                 value,
133                 JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion));
134             ASSERT_TRUE(message_version) << message.payload_utf8();
135             EXPECT_EQ(message_version.value(), version.value());
136           }
137         }));
138   }
139 
140   FakeCastSocketPair fake_cast_socket_pair_;
141   MockSocketErrorHandler mock_error_handler_;
142   CastSocket* socket_;
143 
144   NiceMock<MockVirtualConnectionPolicy> vc_policy_;
145   VirtualConnectionRouter router_;
146   ConnectionNamespaceHandler connection_namespace_handler_{&router_,
147                                                            &vc_policy_};
148 
149   const std::string sender_id_{"sender-5678"};
150   const std::string receiver_id_{"receiver-3245"};
151 };
152 
// A CONNECT without protocol-version information registers the virtual
// connection and sends no reply to the peer.
TEST_F(ConnectionNamespaceHandlerTest, Connect) {
  // gMock expectations only constrain calls made after they are set, so the
  // "no reply" expectation must precede the message under test.
  EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _))
      .Times(0);

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(
      VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()}));
}
162 
// When the policy rejects the connection, the handler replies with CLOSE and
// the virtual connection is not registered.
TEST_F(ConnectionNamespaceHandlerTest, PolicyDeniesConnection) {
  EXPECT_CALL(vc_policy_, IsConnectionAllowed(_))
      .WillOnce(::testing::Return(false));
  ExpectCloseMessage(&fake_cast_socket_pair_.mock_peer_client, receiver_id_,
                     sender_id_);

  CastMessage connect = MakeConnectMessage(sender_id_, receiver_id_);
  connection_namespace_handler_.OnMessage(&router_, socket_,
                                          std::move(connect));

  const VirtualConnection denied{receiver_id_, sender_id_,
                                 socket_->socket_id()};
  EXPECT_FALSE(router_.GetConnectionData(denied));
}
174 
// A CONNECT carrying a single protocol version is answered with CONNECTED
// echoing that version, and the connection is registered.
TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersion) {
  const auto requested =
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2;

  ExpectConnectedMessage(&fake_cast_socket_pair_.mock_peer_client,
                         receiver_id_, sender_id_, requested);

  CastMessage connect =
      MakeVersionedConnectMessage(sender_id_, receiver_id_, requested, {});
  connection_namespace_handler_.OnMessage(&router_, socket_,
                                          std::move(connect));

  const VirtualConnection established{receiver_id_, sender_id_,
                                      socket_->socket_id()};
  EXPECT_TRUE(router_.GetConnectionData(established));
}
187 
// A CONNECT carrying both a version and a version list: the CONNECTED reply
// is expected to report CASTV2_1_3, and the connection is registered.
TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersionList) {
  const auto expected_version =
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3;
  const auto single_version =
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2;
  std::vector<CastMessage_ProtocolVersion> supported = {
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3,
      ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0};

  ExpectConnectedMessage(&fake_cast_socket_pair_.mock_peer_client,
                         receiver_id_, sender_id_, expected_version);

  CastMessage connect = MakeVersionedConnectMessage(
      sender_id_, receiver_id_, single_version, std::move(supported));
  connection_namespace_handler_.OnMessage(&router_, socket_,
                                          std::move(connect));

  const VirtualConnection established{receiver_id_, sender_id_,
                                      socket_->socket_id()};
  EXPECT_TRUE(router_.GetConnectionData(established));
}
202 
// A CLOSE for an existing virtual connection removes it from the router.
TEST_F(ConnectionNamespaceHandlerTest, Close) {
  const VirtualConnection conn{receiver_id_, sender_id_,
                               socket_->socket_id()};

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(conn));

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeCloseMessage(sender_id_, receiver_id_));
  EXPECT_FALSE(router_.GetConnectionData(conn));
}
214 
// A CLOSE naming a sender that has no connection leaves the existing
// connection untouched.
TEST_F(ConnectionNamespaceHandlerTest, CloseUnknown) {
  const VirtualConnection conn{receiver_id_, sender_id_,
                               socket_->socket_id()};
  const std::string unknown_sender = sender_id_ + "098";

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(conn));

  connection_namespace_handler_.OnMessage(
      &router_, socket_, MakeCloseMessage(unknown_sender, receiver_id_));
  EXPECT_TRUE(router_.GetConnectionData(conn));
}
226 
227 }  // namespace cast
228 }  // namespace openscreen
229