1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <pw_allocator/allocator.h>
#include <pw_allocator/capability.h>
#include <pw_allocator/unique_ptr.h>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <optional>
#include <utility>

#include "chre/util/dynamic_vector.h"
#include "chre/util/system/message_common.h"
#include "chre/util/system/message_router.h"
#include "chre/util/system/message_router_callback_allocator.h"
#include "chre_api/chre.h"
#include "gtest/gtest.h"
31
32 namespace chre::message {
33 namespace {
34
35 constexpr size_t kMaxMessageHubs = 3;
36 constexpr size_t kMaxSessions = 10;
37 constexpr size_t kMaxFreeCallbackRecords = kMaxSessions * 2;
38 constexpr size_t kNumEndpoints = 3;
39
40 const EndpointInfo kEndpointInfos[kNumEndpoints] = {
41 EndpointInfo(/* id= */ 1, /* name= */ "endpoint1", /* version= */ 1,
42 EndpointType::NANOAPP, CHRE_MESSAGE_PERMISSION_NONE),
43 EndpointInfo(/* id= */ 2, /* name= */ "endpoint2", /* version= */ 10,
44 EndpointType::HOST_NATIVE, CHRE_MESSAGE_PERMISSION_BLE),
45 EndpointInfo(/* id= */ 3, /* name= */ "endpoint3", /* version= */ 100,
46 EndpointType::GENERIC, CHRE_MESSAGE_PERMISSION_AUDIO)};
47
48 class TestAllocator : public pw::Allocator {
49 public:
50 static constexpr Capabilities kCapabilities = 0;
51
TestAllocator()52 TestAllocator() : pw::Allocator(kCapabilities) {}
53
DoAllocate(Layout layout)54 virtual void *DoAllocate(Layout layout) override {
55 if (layout.alignment() > alignof(std::max_align_t)) {
56 void *ptr;
57 return posix_memalign(&ptr, layout.alignment(), layout.size()) == 0
58 ? ptr
59 : nullptr;
60 } else {
61 return malloc(layout.size());
62 }
63 }
64
DoDeallocate(void * ptr)65 virtual void DoDeallocate(void *ptr) override {
66 free(ptr);
67 }
68 };
69
70 class MessageRouterTest : public ::testing::Test {
71 protected:
SetUp()72 void SetUp() override {}
73
74 TestAllocator mAllocator;
75 };
76
77 //! Base class for MessageHubCallbacks used in tests
78 class MessageHubCallbackBase : public MessageRouter::MessageHubCallback {
79 public:
forEachEndpoint(const pw::Function<bool (const EndpointInfo &)> & function)80 void forEachEndpoint(
81 const pw::Function<bool(const EndpointInfo &)> &function) override {
82 for (const EndpointInfo &endpointInfo : kEndpointInfos) {
83 if (function(endpointInfo)) {
84 return;
85 }
86 }
87 }
88
getEndpointInfo(EndpointId endpointId)89 std::optional<EndpointInfo> getEndpointInfo(EndpointId endpointId) override {
90 for (const EndpointInfo &endpointInfo : kEndpointInfos) {
91 if (endpointInfo.id == endpointId) {
92 return endpointInfo;
93 }
94 }
95 return std::nullopt;
96 }
97 };
98
99 //! MessageHubCallback that stores the data passed to onMessageReceived and
100 //! onSessionClosed
101 class MessageHubCallbackStoreData : public MessageHubCallbackBase {
102 public:
MessageHubCallbackStoreData(Message * message,Session * session)103 MessageHubCallbackStoreData(Message *message, Session *session)
104 : mMessage(message), mSession(session) {}
105
onMessageReceived(pw::UniquePtr<std::byte[]> && data,size_t length,uint32_t messageType,uint32_t messagePermissions,const Session & session,bool sentBySessionInitiator)106 bool onMessageReceived(pw::UniquePtr<std::byte[]> &&data, size_t length,
107 uint32_t messageType, uint32_t messagePermissions,
108 const Session &session,
109 bool sentBySessionInitiator) override {
110 if (mMessage != nullptr) {
111 mMessage->sender = sentBySessionInitiator ? session.initiator
112 : session.peer;
113 mMessage->recipient =
114 sentBySessionInitiator ? session.peer : session.initiator;
115 mMessage->sessionId = session.sessionId;
116 mMessage->data = std::move(data);
117 mMessage->length = length;
118 mMessage->messageType = messageType;
119 mMessage->messagePermissions = messagePermissions;
120 }
121 return true;
122 }
123
onSessionClosed(const Session & session)124 void onSessionClosed(const Session &session) override {
125 if (mSession != nullptr) {
126 *mSession = session;
127 }
128 }
129
130 private:
131 Message *mMessage;
132 Session *mSession;
133 };
134
135 //! MessageHubCallback that always fails to process messages
136 class MessageHubCallbackAlwaysFails : public MessageHubCallbackBase {
137 public:
MessageHubCallbackAlwaysFails(bool * wasMessageReceivedCalled,bool * wasSessionClosedCalled)138 MessageHubCallbackAlwaysFails(bool *wasMessageReceivedCalled,
139 bool *wasSessionClosedCalled)
140 : mWasMessageReceivedCalled(wasMessageReceivedCalled),
141 mWasSessionClosedCalled(wasSessionClosedCalled) {}
142
onMessageReceived(pw::UniquePtr<std::byte[]> &&,size_t,uint32_t,uint32_t,const Session &,bool)143 bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
144 size_t /* length */, uint32_t /* messageType */,
145 uint32_t /* messagePermissions */,
146 const Session & /* session */,
147 bool /* sentBySessionInitiator */) override {
148 if (mWasMessageReceivedCalled != nullptr) {
149 *mWasMessageReceivedCalled = true;
150 }
151 return false;
152 }
153
onSessionClosed(const Session &)154 void onSessionClosed(const Session & /* session */) override {
155 if (mWasSessionClosedCalled != nullptr) {
156 *mWasSessionClosedCalled = true;
157 }
158 }
159
160 private:
161 bool *mWasMessageReceivedCalled;
162 bool *mWasSessionClosedCalled;
163 };
164
165 //! MessageHubCallback that calls MessageHub APIs during callbacks
166 class MessageHubCallbackCallsMessageHubApisDuringCallback
167 : public MessageHubCallbackBase {
168 public:
onMessageReceived(pw::UniquePtr<std::byte[]> &&,size_t,uint32_t,uint32_t,const Session &,bool)169 bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
170 size_t /* length */, uint32_t /* messageType */,
171 uint32_t /* messagePermissions */,
172 const Session & /* session */,
173 bool /* sentBySessionInitiator */) override {
174 if (mMessageHub != nullptr) {
175 // Call a function that locks the MessageRouter mutex
176 mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
177 kEndpointInfos[1].id);
178 }
179 return true;
180 }
181
onSessionClosed(const Session &)182 void onSessionClosed(const Session & /* session */) override {
183 if (mMessageHub != nullptr) {
184 // Call a function that locks the MessageRouter mutex
185 mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
186 kEndpointInfos[1].id);
187 }
188 }
189
setMessageHub(MessageRouter::MessageHub * messageHub)190 void setMessageHub(MessageRouter::MessageHub *messageHub) {
191 mMessageHub = messageHub;
192 }
193
194 private:
195 MessageRouter::MessageHub *mMessageHub = nullptr;
196 };
197
//! Registering a hub under a name that is already taken must fail.
TEST_F(MessageRouterTest, RegisterMessageHubNameIsUnique) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  EXPECT_TRUE(messageHub2.has_value());

  // Reuse the name "hub1" with an unused ID (and spare capacity) so the only
  // possible reason for rejection is the duplicate name. Reusing id 1 here
  // would let the ID-uniqueness check mask a missing name check.
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub1", /* id= */ 3, callback);
  EXPECT_FALSE(messageHub3.has_value());
}
214
//! Registering a hub whose ID is already taken must fail, even under a fresh
//! name.
TEST_F(MessageRouterTest, RegisterMessageHubIdIsUnique) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData noopCallback(/* message= */ nullptr,
                                           /* session= */ nullptr);

  // Two hubs with distinct names and IDs register successfully.
  std::optional<MessageRouter::MessageHub> firstHub =
      router.registerMessageHub("hub1", /* id= */ 1, noopCallback);
  EXPECT_TRUE(firstHub.has_value());
  std::optional<MessageRouter::MessageHub> secondHub =
      router.registerMessageHub("hub2", /* id= */ 2, noopCallback);
  EXPECT_TRUE(secondHub.has_value());

  // A new name does not make a duplicated ID acceptable.
  std::optional<MessageRouter::MessageHub> duplicateIdHub =
      router.registerMessageHub("hub3", /* id= */ 1, noopCallback);
  EXPECT_FALSE(duplicateIdHub.has_value());
}
231
//! forEachMessageHub() visits every registered hub in registration order.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // ASSERT (not EXPECT): the hubs are dereferenced below, which would be
  // undefined behavior on an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  // ASSERT so the indexing below never reads past the end of the vector.
  ASSERT_EQ(messageHubs.size(), 3u);
  // Names are C strings: EXPECT_EQ would compare pointers, which only passes
  // if the compiler happens to merge identical string literals.
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub2");
  EXPECT_STREQ(messageHubs[2].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());
}
264
//! Destroying a MessageHub removes it from the list reported by
//! forEachMessageHub(); the remaining hubs keep their order.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubsWithUnregister) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  // ASSERT so the indexing below never reads past the end of the vector.
  ASSERT_EQ(messageHubs.size(), 3u);
  // Names are C strings: compare contents, not pointers.
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub2");
  EXPECT_STREQ(messageHubs[2].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());

  // Clear messageHubs and reset messageHub2
  messageHubs.clear();
  messageHub2.reset();

  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });
  ASSERT_EQ(messageHubs.size(), 2u);
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub3->getId());
}
314
//! Once kMaxMessageHubs hubs are registered, further registrations fail.
TEST_F(MessageRouterTest, RegisterMessageHubTooManyFails) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  static_assert(kMaxMessageHubs == 3);
  // Sized by kMaxMessageHubs so the table cannot silently drift from it.
  constexpr const char *kNames[kMaxMessageHubs] = {"hub1", "hub2", "hub3"};

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // Keep every registered hub alive so the router stays at capacity.
  MessageRouter::MessageHub messageHubs[kMaxMessageHubs];
  for (size_t i = 0; i < kMaxMessageHubs; ++i) {
    // IDs start at 1, matching the other tests.
    std::optional<MessageRouter::MessageHub> messageHub =
        router.registerMessageHub(kNames[i], /* id= */ i + 1, callback);
    // ASSERT: dereferencing an empty optional below would be undefined
    // behavior.
    ASSERT_TRUE(messageHub.has_value());
    messageHubs[i] = std::move(*messageHub);
  }

  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("shouldfail", /* id= */ kMaxMessageHubs * 2,
                                callback);
  EXPECT_FALSE(messageHub.has_value());
}
335
//! MessageRouter::getEndpointInfo() returns each endpoint of every registered
//! hub.
TEST_F(MessageRouterTest, GetEndpointInfo) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  for (size_t i = 0; i < kNumEndpoints; ++i) {
    EXPECT_EQ(
        router.getEndpointInfo(messageHub1->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub2->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub3->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
  }
}
363
//! A session opened between two hubs is visible from both, and closing it
//! notifies both hubs' callbacks and removes it from both.
TEST_F(MessageRouterTest, RegisterSessionTwoDifferentMessageHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // ASSERT (not EXPECT) wherever an optional is dereferenced afterwards.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
410
//! Destroying one end's MessageHub closes its sessions and notifies the other
//! hub's callback.
TEST_F(MessageRouterTest, UnregisterMessageHubCausesSessionClosed) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // ASSERT (not EXPECT) wherever an optional is dereferenced afterwards.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Unregister messageHub2 and verify the session is closed on the other hub
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  messageHub2.reset();
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
}
454
//! A hub cannot open a session to itself, whether the endpoints are the same
//! or different.
TEST_F(MessageRouterTest, RegisterSessionSameMessageHubInvalid) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // ASSERT (not EXPECT): messageHub is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:2 to messageHub:2 (same hub, same endpoint)
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[1].id, messageHub->getId(), kEndpointInfos[1].id);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub:1 to messageHub:3 (same hub, different
  // endpoints)
  sessionId = messageHub->openSession(kEndpointInfos[0].id, messageHub->getId(),
                                      kEndpointInfos[2].id);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
481
//! Endpoints with the same ID on two different hubs may open a session.
TEST_F(MessageRouterTest, RegisterSessionDifferentMessageHubsSameEndpoints) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // ASSERT (not EXPECT): both hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:1
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
}
503
//! Opening a session to an endpoint the peer hub does not expose fails.
TEST_F(MessageRouterTest,
       RegisterSessionTwoDifferentMessageHubsInvalidEndpoint) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        /* session= */ nullptr);

  // ASSERT (not EXPECT): both hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub with other non-registered endpoint - not
  // valid
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), /* toEndpointId= */ 10);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
525
//! A hub that is not part of a session can neither see nor close it, and its
//! failed attempts leave the session intact.
TEST_F(MessageRouterTest, ThirdMessageHubTriesToFindOthersSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);
  MessageHubCallbackStoreData callback3(/* message= */ nullptr,
                                        &sessionFromCallback3);

  // ASSERT (not EXPECT) wherever an optional is dereferenced afterwards.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Third message hub tries to find the session - not found
  EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
  // Third message hub tries to close the session - not found
  EXPECT_FALSE(messageHub3->closeSession(sessionId));

  // Get session from messageHub and compare it with messageHub2 again
  sessionAfterRegistering = messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  sessionAfterRegistering2 = messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback3);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
598
//! Three hubs with one session per pair: each session is visible only to its
//! two participants, and closing from the peer side works.
TEST_F(MessageRouterTest, ThreeMessageHubsAndThreeSessions) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        /* session= */ nullptr);
  MessageHubCallbackStoreData callback3(/* message= */ nullptr,
                                        /* session= */ nullptr);

  // ASSERT (not EXPECT) wherever an optional is dereferenced afterwards.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  ASSERT_NE(sessionId2, SESSION_ID_INVALID);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  ASSERT_NE(sessionId3, SESSION_ID_INVALID);

  // Get sessions and compare
  // Find session: MessageHub1:1 -> MessageHub2:2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Find session: MessageHub2:2 -> MessageHub3:3
  sessionAfterRegistering = messageHub2->getSessionWithId(sessionId2);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  sessionAfterRegistering2 = messageHub3->getSessionWithId(sessionId2);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Find session: MessageHub3:3 -> MessageHub1:1
  sessionAfterRegistering = messageHub3->getSessionWithId(sessionId3);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  sessionAfterRegistering2 = messageHub->getSessionWithId(sessionId3);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId3).has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close sessions from receivers and verify they are closed on all hubs
  EXPECT_TRUE(messageHub2->closeSession(sessionId));
  EXPECT_TRUE(messageHub3->closeSession(sessionId2));
  EXPECT_TRUE(messageHub->closeSession(sessionId3));
  for (SessionId id : {sessionId, sessionId2, sessionId3}) {
    EXPECT_FALSE(messageHub->getSessionWithId(id).has_value());
    EXPECT_FALSE(messageHub2->getSessionWithId(id).has_value());
    EXPECT_FALSE(messageHub3->getSessionWithId(id).has_value());
  }
}
670
//! Messages sent over a session are delivered to the opposite side with the
//! sender/recipient, type, permissions and payload intact, in both directions.
TEST_F(MessageRouterTest, SendMessageToSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  pw::UniquePtr<std::byte[]> messageData =
      mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }

  Message messageFromCallback1;
  Message messageFromCallback2;
  Message messageFromCallback3;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  MessageHubCallbackStoreData callback(&messageFromCallback1,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(&messageFromCallback2,
                                        &sessionFromCallback2);
  MessageHubCallbackStoreData callback3(&messageFromCallback3,
                                        &sessionFromCallback3);

  // ASSERT (not EXPECT): the hubs are dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  ASSERT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);

  // Send message from messageHub:1 to messageHub2:2
  ASSERT_TRUE(messageHub->sendMessage(std::move(messageData), kMessageSize,
                                      /* messageType= */ 1,
                                      /* messagePermissions= */ 0, sessionId));
  EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback2.messageType, 1);
  EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
  EXPECT_EQ(messageFromCallback2.length, kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
  }

  // The previous buffer was moved into the router; allocate a fresh one.
  messageData = mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }

  // Send message from messageHub2:2 to messageHub:1
  ASSERT_TRUE(messageHub2->sendMessage(std::move(messageData), kMessageSize,
                                       /* messageType= */ 2,
                                       /* messagePermissions= */ 3, sessionId));
  EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback1.messageType, 2);
  EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
  EXPECT_EQ(messageFromCallback1.length, kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
  }
}
755
// Sends messages whose payload is a caller-owned buffer wrapped by
// MessageRouterCallbackAllocator. Checks that the receiving hub sees the very
// same pointer (no copy) and that the free callback runs only once the
// receiver resets its UniquePtr, receiving the original pointer and length.
TEST_F(MessageRouterTest, SendMessageToSessionUsingPointerAndFreeCallback) {
  // Context handed to the free callback so it can check which buffer it is
  // asked to release and record that it ran.
  struct FreeCallbackContext {
    bool *freeCallbackCalled;
    std::byte *message;
    size_t length;
  };

  pw::Vector<
      MessageRouterCallbackAllocator<FreeCallbackContext>::FreeCallbackRecord,
      10>
      freeCallbackRecords;
  MessageRouterCallbackAllocator<FreeCallbackContext> allocator(
      [](std::byte *message, size_t length,
         const FreeCallbackContext &context) {
        // Only report success when the callback gets exactly the buffer and
        // length that were wrapped.
        *context.freeCallbackCalled =
            message == context.message && length == context.length;
      },
      freeCallbackRecords);

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  // Stack buffer that is wrapped (not copied) by the callback allocator below.
  std::byte messageData[kMessageSize];
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }

  Message messageFromCallback1;
  Message messageFromCallback2;
  Message messageFromCallback3;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  MessageHubCallbackStoreData callback(&messageFromCallback1,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(&messageFromCallback2,
                                        &sessionFromCallback2);
  MessageHubCallbackStoreData callback3(&messageFromCallback3,
                                        &sessionFromCallback3);

  // ASSERT rather than EXPECT: each hub is dereferenced below, which would be
  // undefined behavior on an empty std::optional.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);

  // Send message from messageHub:1 to messageHub2:2
  bool freeCallbackCalled = false;
  FreeCallbackContext freeCallbackContext = {
      .freeCallbackCalled = &freeCallbackCalled,
      .message = messageData,
      .length = kMessageSize,
  };
  pw::UniquePtr<std::byte[]> data = allocator.MakeUniqueArrayWithCallback(
      messageData, kMessageSize, std::move(freeCallbackContext));
  ASSERT_NE(data.get(), nullptr);

  ASSERT_TRUE(messageHub->sendMessage(std::move(data), kMessageSize,
                                      /* messageType= */ 1,
                                      /* messagePermissions= */ 0, sessionId));
  EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback2.messageType, 1);
  EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
  EXPECT_EQ(messageFromCallback2.length, kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
  }

  // The free callback must not run until the receiver drops its reference.
  EXPECT_FALSE(freeCallbackCalled);
  EXPECT_EQ(messageFromCallback2.data.get(), messageData);
  messageFromCallback2.data.Reset();
  EXPECT_TRUE(freeCallbackCalled);

  // Send message from messageHub2:2 to messageHub:1
  freeCallbackCalled = false;
  FreeCallbackContext freeCallbackContext2 = {
      .freeCallbackCalled = &freeCallbackCalled,
      .message = messageData,
      .length = kMessageSize,
  };
  data = allocator.MakeUniqueArrayWithCallback(messageData, kMessageSize,
                                               std::move(freeCallbackContext2));
  ASSERT_NE(data.get(), nullptr);

  ASSERT_TRUE(messageHub2->sendMessage(std::move(data), kMessageSize,
                                       /* messageType= */ 2,
                                       /* messagePermissions= */ 3, sessionId));
  EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
  EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
  EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
  EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
  EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
  EXPECT_EQ(messageFromCallback1.messageType, 2);
  EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
  EXPECT_EQ(messageFromCallback1.length, kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
  }

  // The free callback must not run until the receiver drops its reference.
  EXPECT_FALSE(freeCallbackCalled);
  EXPECT_EQ(messageFromCallback1.data.get(), messageData);
  messageFromCallback1.data.Reset();
  EXPECT_TRUE(freeCallbackCalled);
}
884
// Checks that a hub cannot send on a session it is not a party to: each hub
// attempts to send on the session that connects the other two hubs.
TEST_F(MessageRouterTest, SendMessageToSessionInvalidHubAndSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;

  // Builds a fresh, populated payload. Each sendMessage() call is given its
  // own buffer so no call is handed a moved-from pointer; that way a failure
  // can only come from the hub/session mismatch being tested.
  auto makeMessageData = [this]() {
    pw::UniquePtr<std::byte[]> data =
        mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
    for (size_t i = 0; i < kMessageSize; ++i) {
      data[i] = static_cast<std::byte>(i + 1);
    }
    return data;
  };

  Message messageFromCallback1;
  Message messageFromCallback2;
  Message messageFromCallback3;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  Session sessionFromCallback3;
  MessageHubCallbackStoreData callback(&messageFromCallback1,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(&messageFromCallback2,
                                        &sessionFromCallback2);
  MessageHubCallbackStoreData callback3(&messageFromCallback3,
                                        &sessionFromCallback3);

  // ASSERT rather than EXPECT: each hub is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);

  // messageHub:1 is not part of sessionId2 (messageHub2:2 <-> messageHub3:3).
  EXPECT_FALSE(messageHub->sendMessage(makeMessageData(), kMessageSize,
                                       /* messageType= */ 1,
                                       /* messagePermissions= */ 0,
                                       sessionId2));
  // messageHub2:2 is not part of sessionId3 (messageHub3:3 <-> messageHub:1).
  EXPECT_FALSE(messageHub2->sendMessage(makeMessageData(), kMessageSize,
                                        /* messageType= */ 2,
                                        /* messagePermissions= */ 3,
                                        sessionId3));
  // messageHub3:3 is not part of sessionId (messageHub:1 <-> messageHub2:2).
  EXPECT_FALSE(messageHub3->sendMessage(makeMessageData(), kMessageSize,
                                        /* messageType= */ 2,
                                        /* messagePermissions= */ 3,
                                        sessionId));
}
946
// Verifies that when the receiving hub's message callback reports failure,
// sendMessage() fails and the session is removed from every hub, so later
// sends on that session fail without reaching any callback.
TEST_F(MessageRouterTest, SendMessageToSessionCallbackFailureClosesSession) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;

  // Builds a fresh, populated payload for each send; sendMessage() is handed
  // the pointer by std::move, so a buffer is never reused across calls.
  auto makeMessageData = [this]() {
    pw::UniquePtr<std::byte[]> data =
        mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
    for (size_t i = 0; i < kMessageSize; ++i) {
      data[i] = static_cast<std::byte>(i + 1);
    }
    return data;
  };

  bool wasMessageReceivedCalled1 = false;
  bool wasMessageReceivedCalled2 = false;
  bool wasMessageReceivedCalled3 = false;
  MessageHubCallbackAlwaysFails callback1(
      &wasMessageReceivedCalled1,
      /* wasSessionClosedCalled= */ nullptr);
  MessageHubCallbackAlwaysFails callback2(
      &wasMessageReceivedCalled2,
      /* wasSessionClosedCalled= */ nullptr);
  MessageHubCallbackAlwaysFails callback3(
      &wasMessageReceivedCalled3,
      /* wasSessionClosedCalled= */ nullptr);

  // ASSERT rather than EXPECT: each hub is dereferenced below.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback1);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);

  // Before sending: no callbacks yet, and sessionId2 is visible only to its
  // two participants.
  EXPECT_FALSE(wasMessageReceivedCalled1);
  EXPECT_FALSE(wasMessageReceivedCalled2);
  EXPECT_FALSE(wasMessageReceivedCalled3);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
  EXPECT_TRUE(messageHub2->getSessionWithId(sessionId2).has_value());
  EXPECT_TRUE(messageHub3->getSessionWithId(sessionId2).has_value());

  // Send message from messageHub2:2 to messageHub3:3; the receiver's callback
  // fails, which must fail the send and close the session everywhere.
  EXPECT_FALSE(messageHub2->sendMessage(makeMessageData(), kMessageSize,
                                        /* messageType= */ 1,
                                        /* messagePermissions= */ 0,
                                        sessionId2));
  EXPECT_FALSE(wasMessageReceivedCalled1);
  EXPECT_FALSE(wasMessageReceivedCalled2);
  EXPECT_TRUE(wasMessageReceivedCalled3);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId2).has_value());
  EXPECT_FALSE(messageHub3->getSessionWithId(sessionId2).has_value());

  // Try to send a message on the same session from either side - should fail
  // without invoking any callback.
  wasMessageReceivedCalled1 = false;
  wasMessageReceivedCalled2 = false;
  wasMessageReceivedCalled3 = false;
  EXPECT_FALSE(messageHub2->sendMessage(makeMessageData(), kMessageSize,
                                        /* messageType= */ 1,
                                        /* messagePermissions= */ 0,
                                        sessionId2));
  EXPECT_FALSE(messageHub3->sendMessage(makeMessageData(), kMessageSize,
                                        /* messageType= */ 1,
                                        /* messagePermissions= */ 0,
                                        sessionId2));
  EXPECT_FALSE(wasMessageReceivedCalled1);
  EXPECT_FALSE(wasMessageReceivedCalled2);
  EXPECT_FALSE(wasMessageReceivedCalled3);
}
1037
// Exercises send/close paths with callbacks that call back into MessageHub
// APIs from inside the callback. The test only needs to complete: if the
// router held its lock while invoking callbacks, it would hang/time out.
TEST_F(MessageRouterTest, MessageHubCallbackCanCallOtherMessageHubAPIs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  constexpr size_t kMessageSize = 5;
  pw::UniquePtr<std::byte[]> messageData =
      mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }

  MessageHubCallbackCallsMessageHubApisDuringCallback callback;
  MessageHubCallbackCallsMessageHubApisDuringCallback callback2;
  MessageHubCallbackCallsMessageHubApisDuringCallback callback3;

  // ASSERT rather than EXPECT: value() and operator-> on an empty
  // std::optional would throw or be undefined behavior.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  callback.setMessageHub(&messageHub.value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());
  callback2.setMessageHub(&messageHub2.value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback3);
  ASSERT_TRUE(messageHub3.has_value());
  callback3.setMessageHub(&messageHub3.value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Open session from messageHub2:2 to messageHub3:3
  SessionId sessionId2 = messageHub2->openSession(
      kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
  EXPECT_NE(sessionId2, SESSION_ID_INVALID);

  // Open session from messageHub3:3 to messageHub1:1
  SessionId sessionId3 = messageHub3->openSession(
      kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId3, SESSION_ID_INVALID);

  // Send message from messageHub:1 to messageHub2:2
  EXPECT_TRUE(messageHub->sendMessage(std::move(messageData), kMessageSize,
                                      /* messageType= */ 1,
                                      /* messagePermissions= */ 0, sessionId));

  // Send message from messageHub2:2 to messageHub:1 with a fresh buffer, since
  // the previous one was moved into the first send.
  messageData = mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
  for (size_t i = 0; i < kMessageSize; ++i) {
    messageData[i] = static_cast<std::byte>(i + 1);
  }
  EXPECT_TRUE(messageHub2->sendMessage(std::move(messageData), kMessageSize,
                                       /* messageType= */ 2,
                                       /* messagePermissions= */ 3, sessionId));

  // Close all sessions
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_TRUE(messageHub2->closeSession(sessionId2));
  EXPECT_TRUE(messageHub3->closeSession(sessionId3));

  // Reaching this point means the callbacks could re-enter MessageHub APIs.
  // If the router held its lock during the callbacks, this test would time
  // out instead.
}
1101
// Verifies forEachEndpointOfHub() visits every endpoint of a registered hub,
// in the same order and with the same fields as kEndpointInfos.
TEST_F(MessageRouterTest, ForEachEndpointOfHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub.has_value());

  DynamicVector<EndpointInfo> endpoints;
  EXPECT_TRUE(router.forEachEndpointOfHub(
      /* messageHubId= */ 1, [&endpoints](const EndpointInfo &info) {
        endpoints.push_back(info);
        // The test expects every endpoint to be collected, so returning false
        // must not stop the iteration.
        return false;
      }));

  // ASSERT rather than EXPECT: the loop below indexes kEndpointInfos by the
  // collected count, which must not exceed kNumEndpoints.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1126
// Verifies forEachEndpoint() reports every endpoint of the single registered
// hub, paired with that hub's info, in kEndpointInfos order.
TEST_F(MessageRouterTest, ForEachEndpoint) {
  const char *kHubName = "hub1";
  constexpr MessageHubId kHubId = 1;

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub(kHubName, kHubId, callback);
  EXPECT_TRUE(messageHub.has_value());

  DynamicVector<std::pair<MessageHubInfo, EndpointInfo>> endpoints;
  router.forEachEndpoint(
      [&endpoints](const MessageHubInfo &hubInfo, const EndpointInfo &info) {
        endpoints.push_back(std::make_pair(hubInfo, info));
      });

  // ASSERT rather than EXPECT: the loop below indexes kEndpointInfos by the
  // collected count, which must not exceed kNumEndpoints.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].first.id, kHubId);
    EXPECT_STREQ(endpoints[i].first.name, kHubName);

    EXPECT_EQ(endpoints[i].second.id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].second.name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].second.version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].second.type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].second.requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1156
// Verifies forEachEndpointOfHub() returns false and never invokes the
// callback when asked about a hub ID that was not registered.
TEST_F(MessageRouterTest, ForEachEndpointOfHubInvalidHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub.has_value());

  // Only hub 1 is registered, so hub 2 is unknown.
  DynamicVector<EndpointInfo> endpoints;
  EXPECT_FALSE(router.forEachEndpointOfHub(
      /* messageHubId= */ 2, [&endpoints](const EndpointInfo &info) {
        endpoints.push_back(info);
        return false;
      }));
  // Compare against an unsigned literal to avoid a signed/unsigned comparison
  // inside EXPECT_EQ (size() is a size_t).
  EXPECT_EQ(endpoints.size(), 0u);
}
1173
1174 } // namespace
1175 } // namespace chre::message
1176