xref: /aosp_15_r20/system/chre/util/tests/message_router_test.cc (revision 84e339476a462649f82315436d70fd732297a399)
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <pw_allocator/allocator.h>
18 #include <pw_allocator/capability.h>
19 #include <pw_allocator/unique_ptr.h>
20 #include <cstddef>
21 #include <cstdint>
22 #include <optional>
23 #include <utility>
24 
25 #include "chre/util/dynamic_vector.h"
26 #include "chre/util/system/message_common.h"
27 #include "chre/util/system/message_router.h"
28 #include "chre/util/system/message_router_callback_allocator.h"
29 #include "chre_api/chre.h"
30 #include "gtest/gtest.h"
31 
32 namespace chre::message {
33 namespace {
34 
//! Number of message hubs each router under test can hold.
constexpr size_t kMaxMessageHubs{3};
//! Number of sessions each router under test can hold.
constexpr size_t kMaxSessions{10};
//! Twice kMaxSessions.
constexpr size_t kMaxFreeCallbackRecords{kMaxSessions * 2};
//! Number of entries in kEndpointInfos.
constexpr size_t kNumEndpoints{3};
39 
40 const EndpointInfo kEndpointInfos[kNumEndpoints] = {
41     EndpointInfo(/* id= */ 1, /* name= */ "endpoint1", /* version= */ 1,
42                  EndpointType::NANOAPP, CHRE_MESSAGE_PERMISSION_NONE),
43     EndpointInfo(/* id= */ 2, /* name= */ "endpoint2", /* version= */ 10,
44                  EndpointType::HOST_NATIVE, CHRE_MESSAGE_PERMISSION_BLE),
45     EndpointInfo(/* id= */ 3, /* name= */ "endpoint3", /* version= */ 100,
46                  EndpointType::GENERIC, CHRE_MESSAGE_PERMISSION_AUDIO)};
47 
48 class TestAllocator : public pw::Allocator {
49  public:
50   static constexpr Capabilities kCapabilities = 0;
51 
TestAllocator()52   TestAllocator() : pw::Allocator(kCapabilities) {}
53 
DoAllocate(Layout layout)54   virtual void *DoAllocate(Layout layout) override {
55     if (layout.alignment() > alignof(std::max_align_t)) {
56       void *ptr;
57       return posix_memalign(&ptr, layout.alignment(), layout.size()) == 0
58                  ? ptr
59                  : nullptr;
60     } else {
61       return malloc(layout.size());
62     }
63   }
64 
DoDeallocate(void * ptr)65   virtual void DoDeallocate(void *ptr) override {
66     free(ptr);
67   }
68 };
69 
70 class MessageRouterTest : public ::testing::Test {
71  protected:
SetUp()72   void SetUp() override {}
73 
74   TestAllocator mAllocator;
75 };
76 
77 //! Base class for MessageHubCallbacks used in tests
78 class MessageHubCallbackBase : public MessageRouter::MessageHubCallback {
79  public:
forEachEndpoint(const pw::Function<bool (const EndpointInfo &)> & function)80   void forEachEndpoint(
81       const pw::Function<bool(const EndpointInfo &)> &function) override {
82     for (const EndpointInfo &endpointInfo : kEndpointInfos) {
83       if (function(endpointInfo)) {
84         return;
85       }
86     }
87   }
88 
getEndpointInfo(EndpointId endpointId)89   std::optional<EndpointInfo> getEndpointInfo(EndpointId endpointId) override {
90     for (const EndpointInfo &endpointInfo : kEndpointInfos) {
91       if (endpointInfo.id == endpointId) {
92         return endpointInfo;
93       }
94     }
95     return std::nullopt;
96   }
97 };
98 
99 //! MessageHubCallback that stores the data passed to onMessageReceived and
100 //! onSessionClosed
101 class MessageHubCallbackStoreData : public MessageHubCallbackBase {
102  public:
MessageHubCallbackStoreData(Message * message,Session * session)103   MessageHubCallbackStoreData(Message *message, Session *session)
104       : mMessage(message), mSession(session) {}
105 
onMessageReceived(pw::UniquePtr<std::byte[]> && data,size_t length,uint32_t messageType,uint32_t messagePermissions,const Session & session,bool sentBySessionInitiator)106   bool onMessageReceived(pw::UniquePtr<std::byte[]> &&data, size_t length,
107                          uint32_t messageType, uint32_t messagePermissions,
108                          const Session &session,
109                          bool sentBySessionInitiator) override {
110     if (mMessage != nullptr) {
111       mMessage->sender = sentBySessionInitiator ? session.initiator
112                                                 : session.peer;
113       mMessage->recipient =
114           sentBySessionInitiator ? session.peer : session.initiator;
115       mMessage->sessionId = session.sessionId;
116       mMessage->data = std::move(data);
117       mMessage->length = length;
118       mMessage->messageType = messageType;
119       mMessage->messagePermissions = messagePermissions;
120     }
121     return true;
122   }
123 
onSessionClosed(const Session & session)124   void onSessionClosed(const Session &session) override {
125     if (mSession != nullptr) {
126       *mSession = session;
127     }
128   }
129 
130  private:
131   Message *mMessage;
132   Session *mSession;
133 };
134 
135 //! MessageHubCallback that always fails to process messages
136 class MessageHubCallbackAlwaysFails : public MessageHubCallbackBase {
137  public:
MessageHubCallbackAlwaysFails(bool * wasMessageReceivedCalled,bool * wasSessionClosedCalled)138   MessageHubCallbackAlwaysFails(bool *wasMessageReceivedCalled,
139                                 bool *wasSessionClosedCalled)
140       : mWasMessageReceivedCalled(wasMessageReceivedCalled),
141         mWasSessionClosedCalled(wasSessionClosedCalled) {}
142 
onMessageReceived(pw::UniquePtr<std::byte[]> &&,size_t,uint32_t,uint32_t,const Session &,bool)143   bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
144                          size_t /* length */, uint32_t /* messageType */,
145                          uint32_t /* messagePermissions */,
146                          const Session & /* session */,
147                          bool /* sentBySessionInitiator */) override {
148     if (mWasMessageReceivedCalled != nullptr) {
149       *mWasMessageReceivedCalled = true;
150     }
151     return false;
152   }
153 
onSessionClosed(const Session &)154   void onSessionClosed(const Session & /* session */) override {
155     if (mWasSessionClosedCalled != nullptr) {
156       *mWasSessionClosedCalled = true;
157     }
158   }
159 
160  private:
161   bool *mWasMessageReceivedCalled;
162   bool *mWasSessionClosedCalled;
163 };
164 
165 //! MessageHubCallback that calls MessageHub APIs during callbacks
166 class MessageHubCallbackCallsMessageHubApisDuringCallback
167     : public MessageHubCallbackBase {
168  public:
onMessageReceived(pw::UniquePtr<std::byte[]> &&,size_t,uint32_t,uint32_t,const Session &,bool)169   bool onMessageReceived(pw::UniquePtr<std::byte[]> && /* data */,
170                          size_t /* length */, uint32_t /* messageType */,
171                          uint32_t /* messagePermissions */,
172                          const Session & /* session */,
173                          bool /* sentBySessionInitiator */) override {
174     if (mMessageHub != nullptr) {
175       // Call a function that locks the MessageRouter mutex
176       mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
177                                kEndpointInfos[1].id);
178     }
179     return true;
180   }
181 
onSessionClosed(const Session &)182   void onSessionClosed(const Session & /* session */) override {
183     if (mMessageHub != nullptr) {
184       // Call a function that locks the MessageRouter mutex
185       mMessageHub->openSession(kEndpointInfos[0].id, mMessageHub->getId(),
186                                kEndpointInfos[1].id);
187     }
188   }
189 
setMessageHub(MessageRouter::MessageHub * messageHub)190   void setMessageHub(MessageRouter::MessageHub *messageHub) {
191     mMessageHub = messageHub;
192   }
193 
194  private:
195   MessageRouter::MessageHub *mMessageHub = nullptr;
196 };
197 
// Verifies that a message hub cannot be registered under a name that is
// already in use, even when the requested ID is not in use.
TEST_F(MessageRouterTest, RegisterMessageHubNameIsUnique) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  EXPECT_TRUE(messageHub2.has_value());

  // Reuse the name "hub1" with an ID no other hub holds, so the name collision
  // is the only reason this registration can be rejected. Reusing ID 1 as well
  // would let the test pass even if only ID uniqueness were enforced.
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub1", /* id= */ 3, callback);
  EXPECT_FALSE(messageHub3.has_value());
}
214 
// Verifies that a message hub cannot be registered with an ID that another
// hub already holds, even when its name is new.
TEST_F(MessageRouterTest, RegisterMessageHubIdIsUnique) {
  using HubHandle = std::optional<MessageRouter::MessageHub>;

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData hubCallback(/* message= */ nullptr,
                                          /* session= */ nullptr);

  HubHandle firstHub =
      router.registerMessageHub("hub1", /* id= */ 1, hubCallback);
  EXPECT_TRUE(firstHub.has_value());

  HubHandle secondHub =
      router.registerMessageHub("hub2", /* id= */ 2, hubCallback);
  EXPECT_TRUE(secondHub.has_value());

  // "hub3" is an unused name, but ID 1 already belongs to "hub1".
  HubHandle duplicateIdHub =
      router.registerMessageHub("hub3", /* id= */ 1, hubCallback);
  EXPECT_FALSE(duplicateIdHub.has_value());
}
231 
// Registers three hubs and verifies that forEachMessageHub() reports all of
// them, in registration order, with the expected names and IDs.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // The hub handles are dereferenced below, so a failed registration must
  // abort the test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  // Returning false from the visitor keeps the iteration going.
  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });

  // Fatal check: the elements are indexed unconditionally below.
  ASSERT_EQ(messageHubs.size(), 3u);
  // Compare the C-string contents; EXPECT_EQ would only compare the pointers.
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub2");
  EXPECT_STREQ(messageHubs[2].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());
}
264 
// Registers three hubs, verifies the reported list, then destroys the second
// hub handle and verifies that only the remaining two hubs are reported.
TEST_F(MessageRouterTest, RegisterMessageHubGetListOfHubsWithUnregister) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // The hub handles are dereferenced below, so a failed registration must
  // abort the test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  // Returning false from the visitor keeps the iteration going.
  DynamicVector<MessageHubInfo> messageHubs;
  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });

  // Fatal check: the elements are indexed unconditionally below.
  ASSERT_EQ(messageHubs.size(), 3u);
  // Compare the C-string contents; EXPECT_EQ would only compare the pointers.
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub2");
  EXPECT_STREQ(messageHubs[2].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 2);
  EXPECT_EQ(messageHubs[2].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub2->getId());
  EXPECT_EQ(messageHubs[2].id, messageHub3->getId());

  // Clear messageHubs and reset messageHub2
  messageHubs.clear();
  messageHub2.reset();

  router.forEachMessageHub(
      [&messageHubs](const MessageHubInfo &messageHubInfo) {
        messageHubs.push_back(messageHubInfo);
        return false;
      });

  // Only hub1 and hub3 are expected to remain.
  ASSERT_EQ(messageHubs.size(), 2u);
  EXPECT_STREQ(messageHubs[0].name, "hub1");
  EXPECT_STREQ(messageHubs[1].name, "hub3");
  EXPECT_EQ(messageHubs[0].id, 1);
  EXPECT_EQ(messageHubs[1].id, 3);
  EXPECT_EQ(messageHubs[0].id, messageHub1->getId());
  EXPECT_EQ(messageHubs[1].id, messageHub3->getId());
}
314 
// Fills the router with kMaxMessageHubs hubs and verifies that one more
// registration, with an unused name and ID, is rejected.
TEST_F(MessageRouterTest, RegisterMessageHubTooManyFails) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  static_assert(kMaxMessageHubs == 3);
  constexpr const char *kNames[kMaxMessageHubs] = {"hub1", "hub2", "hub3"};

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // Keep every handle alive so the hubs stay registered until the final
  // registration attempt.
  MessageRouter::MessageHub messageHubs[kMaxMessageHubs];
  for (size_t i = 0; i < kMaxMessageHubs; ++i) {
    std::optional<MessageRouter::MessageHub> messageHub =
        router.registerMessageHub(kNames[i], /* id= */ i, callback);
    // Fatal check: the optional is dereferenced on the next line.
    ASSERT_TRUE(messageHub.has_value());
    messageHubs[i] = std::move(*messageHub);
  }

  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("shouldfail", /* id= */ kMaxMessageHubs * 2,
                                callback);
  EXPECT_FALSE(messageHub.has_value());
}
335 
// Verifies that the router returns the matching kEndpointInfos entry for
// every endpoint ID on each of three registered hubs.
TEST_F(MessageRouterTest, GetEndpointInfo) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;

  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  // The hub handles are dereferenced below, so a failed registration must
  // abort the test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub1 =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub1.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback);
  ASSERT_TRUE(messageHub2.has_value());
  std::optional<MessageRouter::MessageHub> messageHub3 =
      router.registerMessageHub("hub3", /* id= */ 3, callback);
  ASSERT_TRUE(messageHub3.has_value());

  for (size_t i = 0; i < kNumEndpoints; ++i) {
    EXPECT_EQ(
        router.getEndpointInfo(messageHub1->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub2->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
    EXPECT_EQ(
        router.getEndpointInfo(messageHub3->getId(), kEndpointInfos[i].id),
        kEndpointInfos[i]);
  }
}
363 
// Opens a session between endpoints on two different hubs, verifies that both
// hubs report the same session, then closes it and verifies that both hubs'
// onSessionClosed callbacks received it.
TEST_F(MessageRouterTest, RegisterSessionTwoDifferentMessageHubs) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // The hub handles are dereferenced below, so a failed registration must
  // abort the test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Close the session and verify it is closed on both message hubs. Before
  // the close, neither callback has stored the session.
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_TRUE(messageHub->closeSession(sessionId));
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
  EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
}
410 
// Opens a session between two hubs, destroys the peer hub's handle, and
// verifies that the initiating hub is told the session closed and can no
// longer look it up.
TEST_F(MessageRouterTest, UnregisterMessageHubCausesSessionClosed) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // The hub handles are dereferenced below, so a failed registration must
  // abort the test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:2
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);

  // Get session from messageHub and compare it with messageHub2
  std::optional<Session> sessionAfterRegistering =
      messageHub->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering.has_value());
  EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
  EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
            messageHub->getId());
  EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
            kEndpointInfos[0].id);
  EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
  EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
  std::optional<Session> sessionAfterRegistering2 =
      messageHub2->getSessionWithId(sessionId);
  ASSERT_TRUE(sessionAfterRegistering2.has_value());
  EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);

  // Destroy the second hub's handle and verify the session is closed on the
  // other hub
  EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
  messageHub2.reset();
  EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
  EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
}
454 
// Verifies that a hub cannot open a session whose peer is on that same hub,
// whether the two endpoint IDs are equal or different.
TEST_F(MessageRouterTest, RegisterSessionSameMessageHubInvalid) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // messageHub is dereferenced below, so a failed registration must abort the
  // test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  EXPECT_TRUE(messageHub2.has_value());

  // Open session on messageHub from endpoint 2 to endpoint 2 (same hub, same
  // endpoint)
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[1].id, messageHub->getId(), kEndpointInfos[1].id);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);

  // Open session on messageHub from endpoint 1 to endpoint 3 (same hub,
  // different endpoints)
  sessionId = messageHub->openSession(kEndpointInfos[0].id, messageHub->getId(),
                                      kEndpointInfos[2].id);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
481 
// Verifies that a session can be opened between two different hubs even when
// both sides use the same endpoint ID.
TEST_F(MessageRouterTest, RegisterSessionDifferentMessageHubsSameEndpoints) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  Session sessionFromCallback1;
  Session sessionFromCallback2;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       &sessionFromCallback1);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        &sessionFromCallback2);

  // The hub handles are dereferenced below, so a failed registration must
  // abort the test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub:1 to messageHub2:1
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[0].id);
  EXPECT_NE(sessionId, SESSION_ID_INVALID);
}
503 
// Verifies that opening a session to an endpoint ID the peer hub does not
// report is rejected.
TEST_F(MessageRouterTest,
       RegisterSessionTwoDifferentMessageHubsInvalidEndpoint) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  MessageHubCallbackStoreData callback2(/* message= */ nullptr,
                                        /* session= */ nullptr);

  // The hub handles are dereferenced below, so a failed registration must
  // abort the test instead of dereferencing an empty optional.
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  ASSERT_TRUE(messageHub.has_value());
  std::optional<MessageRouter::MessageHub> messageHub2 =
      router.registerMessageHub("hub2", /* id= */ 2, callback2);
  ASSERT_TRUE(messageHub2.has_value());

  // Open session from messageHub with other non-registered endpoint - not
  // valid. Endpoint ID 10 does not appear in kEndpointInfos.
  SessionId sessionId = messageHub->openSession(
      kEndpointInfos[0].id, messageHub2->getId(), /* toEndpointId= */ 10);
  EXPECT_EQ(sessionId, SESSION_ID_INVALID);
}
525 
TEST_F(MessageRouterTest,ThirdMessageHubTriesToFindOthersSession)526 TEST_F(MessageRouterTest, ThirdMessageHubTriesToFindOthersSession) {
527   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
528   Session sessionFromCallback1;
529   Session sessionFromCallback2;
530   Session sessionFromCallback3;
531   MessageHubCallbackStoreData callback(/* message= */ nullptr,
532                                        &sessionFromCallback1);
533   MessageHubCallbackStoreData callback2(/* message= */ nullptr,
534                                         &sessionFromCallback2);
535   MessageHubCallbackStoreData callback3(/* message= */ nullptr,
536                                         &sessionFromCallback3);
537 
538   std::optional<MessageRouter::MessageHub> messageHub =
539       router.registerMessageHub("hub1", /* id= */ 1, callback);
540   EXPECT_TRUE(messageHub.has_value());
541   std::optional<MessageRouter::MessageHub> messageHub2 =
542       router.registerMessageHub("hub2", /* id= */ 2, callback2);
543   EXPECT_TRUE(messageHub2.has_value());
544   std::optional<MessageRouter::MessageHub> messageHub3 =
545       router.registerMessageHub("hub3", /* id= */ 3, callback3);
546   EXPECT_TRUE(messageHub3.has_value());
547 
548   // Open session from messageHub:1 to messageHub2:2
549   SessionId sessionId = messageHub->openSession(
550       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
551   EXPECT_NE(sessionId, SESSION_ID_INVALID);
552 
553   // Get session from messageHub and compare it with messageHub2
554   std::optional<Session> sessionAfterRegistering =
555       messageHub->getSessionWithId(sessionId);
556   EXPECT_TRUE(sessionAfterRegistering.has_value());
557   EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
558   EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
559             messageHub->getId());
560   EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
561             kEndpointInfos[0].id);
562   EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
563   EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
564   std::optional<Session> sessionAfterRegistering2 =
565       messageHub2->getSessionWithId(sessionId);
566   EXPECT_TRUE(sessionAfterRegistering2.has_value());
567   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
568 
569   // Third message hub tries to find the session - not found
570   EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
571   // Third message hub tries to close the session - not found
572   EXPECT_FALSE(messageHub3->closeSession(sessionId));
573 
574   // Get session from messageHub and compare it with messageHub2 again
575   sessionAfterRegistering = messageHub->getSessionWithId(sessionId);
576   EXPECT_TRUE(sessionAfterRegistering.has_value());
577   EXPECT_EQ(sessionAfterRegistering->sessionId, sessionId);
578   EXPECT_EQ(sessionAfterRegistering->initiator.messageHubId,
579             messageHub->getId());
580   EXPECT_EQ(sessionAfterRegistering->initiator.endpointId,
581             kEndpointInfos[0].id);
582   EXPECT_EQ(sessionAfterRegistering->peer.messageHubId, messageHub2->getId());
583   EXPECT_EQ(sessionAfterRegistering->peer.endpointId, kEndpointInfos[1].id);
584   sessionAfterRegistering2 = messageHub2->getSessionWithId(sessionId);
585   EXPECT_TRUE(sessionAfterRegistering2.has_value());
586   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
587 
588   // Close the session and verify it is closed on both message hubs
589   EXPECT_NE(*sessionAfterRegistering, sessionFromCallback1);
590   EXPECT_NE(*sessionAfterRegistering, sessionFromCallback2);
591   EXPECT_TRUE(messageHub->closeSession(sessionId));
592   EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback1);
593   EXPECT_EQ(*sessionAfterRegistering, sessionFromCallback2);
594   EXPECT_NE(*sessionAfterRegistering, sessionFromCallback3);
595   EXPECT_FALSE(messageHub->getSessionWithId(sessionId).has_value());
596   EXPECT_FALSE(messageHub2->getSessionWithId(sessionId).has_value());
597 }
598 
TEST_F(MessageRouterTest,ThreeMessageHubsAndThreeSessions)599 TEST_F(MessageRouterTest, ThreeMessageHubsAndThreeSessions) {
600   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
601   MessageHubCallbackStoreData callback(/* message= */ nullptr,
602                                        /* session= */ nullptr);
603   MessageHubCallbackStoreData callback2(/* message= */ nullptr,
604                                         /* session= */ nullptr);
605   MessageHubCallbackStoreData callback3(/* message= */ nullptr,
606                                         /* session= */ nullptr);
607 
608   std::optional<MessageRouter::MessageHub> messageHub =
609       router.registerMessageHub("hub1", /* id= */ 1, callback);
610   EXPECT_TRUE(messageHub.has_value());
611   std::optional<MessageRouter::MessageHub> messageHub2 =
612       router.registerMessageHub("hub2", /* id= */ 2, callback2);
613   EXPECT_TRUE(messageHub2.has_value());
614   std::optional<MessageRouter::MessageHub> messageHub3 =
615       router.registerMessageHub("hub3", /* id= */ 3, callback3);
616   EXPECT_TRUE(messageHub3.has_value());
617 
618   // Open session from messageHub:1 to messageHub2:2
619   SessionId sessionId = messageHub->openSession(
620       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
621   EXPECT_NE(sessionId, SESSION_ID_INVALID);
622 
623   // Open session from messageHub2:2 to messageHub3:3
624   SessionId sessionId2 = messageHub2->openSession(
625       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
626   EXPECT_NE(sessionId, SESSION_ID_INVALID);
627 
628   // Open session from messageHub3:3 to messageHub1:1
629   SessionId sessionId3 = messageHub3->openSession(
630       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
631   EXPECT_NE(sessionId, SESSION_ID_INVALID);
632 
633   // Get sessions and compare
634   // Find session: MessageHub1:1 -> MessageHub2:2
635   std::optional<Session> sessionAfterRegistering =
636       messageHub->getSessionWithId(sessionId);
637   EXPECT_TRUE(sessionAfterRegistering.has_value());
638   std::optional<Session> sessionAfterRegistering2 =
639       messageHub2->getSessionWithId(sessionId);
640   EXPECT_TRUE(sessionAfterRegistering2.has_value());
641   EXPECT_FALSE(messageHub3->getSessionWithId(sessionId).has_value());
642   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
643 
644   // Find session: MessageHub2:2 -> MessageHub3:3
645   sessionAfterRegistering = messageHub2->getSessionWithId(sessionId2);
646   EXPECT_TRUE(sessionAfterRegistering.has_value());
647   sessionAfterRegistering2 = messageHub3->getSessionWithId(sessionId2);
648   EXPECT_TRUE(sessionAfterRegistering2.has_value());
649   EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
650   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
651 
652   // Find session: MessageHub3:3 -> MessageHub1:1
653   sessionAfterRegistering = messageHub3->getSessionWithId(sessionId3);
654   EXPECT_TRUE(sessionAfterRegistering.has_value());
655   sessionAfterRegistering2 = messageHub->getSessionWithId(sessionId3);
656   EXPECT_TRUE(sessionAfterRegistering2.has_value());
657   EXPECT_FALSE(messageHub2->getSessionWithId(sessionId3).has_value());
658   EXPECT_EQ(*sessionAfterRegistering, *sessionAfterRegistering2);
659 
660   // Close sessions from receivers and verify they are closed on all hubs
661   EXPECT_TRUE(messageHub2->closeSession(sessionId));
662   EXPECT_TRUE(messageHub3->closeSession(sessionId2));
663   EXPECT_TRUE(messageHub->closeSession(sessionId3));
664   for (SessionId id : {sessionId, sessionId2, sessionId3}) {
665     EXPECT_FALSE(messageHub->getSessionWithId(id).has_value());
666     EXPECT_FALSE(messageHub2->getSessionWithId(id).has_value());
667     EXPECT_FALSE(messageHub3->getSessionWithId(id).has_value());
668   }
669 }
670 
TEST_F(MessageRouterTest,SendMessageToSession)671 TEST_F(MessageRouterTest, SendMessageToSession) {
672   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
673   constexpr size_t kMessageSize = 5;
674   pw::UniquePtr<std::byte[]> messageData =
675       mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
676   for (size_t i = 0; i < 5; ++i) {
677     messageData[i] = static_cast<std::byte>(i + 1);
678   }
679 
680   Message messageFromCallback1;
681   Message messageFromCallback2;
682   Message messageFromCallback3;
683   Session sessionFromCallback1;
684   Session sessionFromCallback2;
685   Session sessionFromCallback3;
686   MessageHubCallbackStoreData callback(&messageFromCallback1,
687                                        &sessionFromCallback1);
688   MessageHubCallbackStoreData callback2(&messageFromCallback2,
689                                         &sessionFromCallback2);
690   MessageHubCallbackStoreData callback3(&messageFromCallback3,
691                                         &sessionFromCallback3);
692 
693   std::optional<MessageRouter::MessageHub> messageHub =
694       router.registerMessageHub("hub1", /* id= */ 1, callback);
695   EXPECT_TRUE(messageHub.has_value());
696   std::optional<MessageRouter::MessageHub> messageHub2 =
697       router.registerMessageHub("hub2", /* id= */ 2, callback2);
698   EXPECT_TRUE(messageHub2.has_value());
699   std::optional<MessageRouter::MessageHub> messageHub3 =
700       router.registerMessageHub("hub3", /* id= */ 3, callback3);
701   EXPECT_TRUE(messageHub3.has_value());
702 
703   // Open session from messageHub:1 to messageHub2:2
704   SessionId sessionId = messageHub->openSession(
705       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
706   EXPECT_NE(sessionId, SESSION_ID_INVALID);
707 
708   // Open session from messageHub2:2 to messageHub3:3
709   SessionId sessionId2 = messageHub2->openSession(
710       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
711   EXPECT_NE(sessionId, SESSION_ID_INVALID);
712 
713   // Open session from messageHub3:3 to messageHub1:1
714   SessionId sessionId3 = messageHub3->openSession(
715       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
716   EXPECT_NE(sessionId, SESSION_ID_INVALID);
717 
718   // Send message from messageHub:1 to messageHub2:2
719   ASSERT_TRUE(messageHub->sendMessage(std::move(messageData), kMessageSize,
720                                       /* messageType= */ 1,
721                                       /* messagePermissions= */ 0, sessionId));
722   EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
723   EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
724   EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
725   EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
726   EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
727   EXPECT_EQ(messageFromCallback2.messageType, 1);
728   EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
729   EXPECT_EQ(messageFromCallback2.length, kMessageSize);
730   for (size_t i = 0; i < kMessageSize; ++i) {
731     EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
732   }
733 
734   messageData = mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
735   for (size_t i = 0; i < 5; ++i) {
736     messageData[i] = static_cast<std::byte>(i + 1);
737   }
738 
739   // Send message from messageHub2:2 to messageHub:1
740   ASSERT_TRUE(messageHub2->sendMessage(std::move(messageData), kMessageSize,
741                                        /* messageType= */ 2,
742                                        /* messagePermissions= */ 3, sessionId));
743   EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
744   EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
745   EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
746   EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
747   EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
748   EXPECT_EQ(messageFromCallback1.messageType, 2);
749   EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
750   EXPECT_EQ(messageFromCallback1.length, kMessageSize);
751   for (size_t i = 0; i < kMessageSize; ++i) {
752     EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
753   }
754 }
755 
TEST_F(MessageRouterTest,SendMessageToSessionUsingPointerAndFreeCallback)756 TEST_F(MessageRouterTest, SendMessageToSessionUsingPointerAndFreeCallback) {
757   struct FreeCallbackContext {
758     bool *freeCallbackCalled;
759     std::byte *message;
760     size_t length;
761   };
762 
763   pw::Vector<
764       MessageRouterCallbackAllocator<FreeCallbackContext>::FreeCallbackRecord,
765       10>
766       freeCallbackRecords;
767   MessageRouterCallbackAllocator<FreeCallbackContext> allocator(
768       [](std::byte *message, size_t length,
769          const FreeCallbackContext &context) {
770         *context.freeCallbackCalled =
771             message == context.message && length == context.length;
772       },
773       freeCallbackRecords);
774 
775   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
776   constexpr size_t kMessageSize = 5;
777   std::byte messageData[kMessageSize];
778   for (size_t i = 0; i < 5; ++i) {
779     messageData[i] = static_cast<std::byte>(i + 1);
780   }
781 
782   Message messageFromCallback1;
783   Message messageFromCallback2;
784   Message messageFromCallback3;
785   Session sessionFromCallback1;
786   Session sessionFromCallback2;
787   Session sessionFromCallback3;
788   MessageHubCallbackStoreData callback(&messageFromCallback1,
789                                        &sessionFromCallback1);
790   MessageHubCallbackStoreData callback2(&messageFromCallback2,
791                                         &sessionFromCallback2);
792   MessageHubCallbackStoreData callback3(&messageFromCallback3,
793                                         &sessionFromCallback3);
794 
795   std::optional<MessageRouter::MessageHub> messageHub =
796       router.registerMessageHub("hub1", /* id= */ 1, callback);
797   EXPECT_TRUE(messageHub.has_value());
798   std::optional<MessageRouter::MessageHub> messageHub2 =
799       router.registerMessageHub("hub2", /* id= */ 2, callback2);
800   EXPECT_TRUE(messageHub2.has_value());
801   std::optional<MessageRouter::MessageHub> messageHub3 =
802       router.registerMessageHub("hub3", /* id= */ 3, callback3);
803   EXPECT_TRUE(messageHub3.has_value());
804 
805   // Open session from messageHub:1 to messageHub2:2
806   SessionId sessionId = messageHub->openSession(
807       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
808   EXPECT_NE(sessionId, SESSION_ID_INVALID);
809 
810   // Open session from messageHub2:2 to messageHub3:3
811   SessionId sessionId2 = messageHub2->openSession(
812       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
813   EXPECT_NE(sessionId, SESSION_ID_INVALID);
814 
815   // Open session from messageHub3:3 to messageHub1:1
816   SessionId sessionId3 = messageHub3->openSession(
817       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
818   EXPECT_NE(sessionId, SESSION_ID_INVALID);
819 
820   // Send message from messageHub:1 to messageHub2:2
821   bool freeCallbackCalled = false;
822   FreeCallbackContext freeCallbackContext = {
823       .freeCallbackCalled = &freeCallbackCalled,
824       .message = messageData,
825       .length = kMessageSize,
826   };
827   pw::UniquePtr<std::byte[]> data = allocator.MakeUniqueArrayWithCallback(
828       messageData, kMessageSize, std::move(freeCallbackContext));
829   ASSERT_NE(data.get(), nullptr);
830 
831   ASSERT_TRUE(messageHub->sendMessage(std::move(data), kMessageSize,
832                                       /* messageType= */ 1,
833                                       /* messagePermissions= */ 0, sessionId));
834   EXPECT_EQ(messageFromCallback2.sessionId, sessionId);
835   EXPECT_EQ(messageFromCallback2.sender.messageHubId, messageHub->getId());
836   EXPECT_EQ(messageFromCallback2.sender.endpointId, kEndpointInfos[0].id);
837   EXPECT_EQ(messageFromCallback2.recipient.messageHubId, messageHub2->getId());
838   EXPECT_EQ(messageFromCallback2.recipient.endpointId, kEndpointInfos[1].id);
839   EXPECT_EQ(messageFromCallback2.messageType, 1);
840   EXPECT_EQ(messageFromCallback2.messagePermissions, 0);
841   EXPECT_EQ(messageFromCallback2.length, kMessageSize);
842   for (size_t i = 0; i < kMessageSize; ++i) {
843     EXPECT_EQ(messageFromCallback2.data[i], static_cast<std::byte>(i + 1));
844   }
845 
846   // Check if free callback was called
847   EXPECT_FALSE(freeCallbackCalled);
848   EXPECT_EQ(messageFromCallback2.data.get(), messageData);
849   messageFromCallback2.data.Reset();
850   EXPECT_TRUE(freeCallbackCalled);
851 
852   // Send message from messageHub2:2 to messageHub:1
853   freeCallbackCalled = false;
854   FreeCallbackContext freeCallbackContext2 = {
855       .freeCallbackCalled = &freeCallbackCalled,
856       .message = messageData,
857       .length = kMessageSize,
858   };
859   data = allocator.MakeUniqueArrayWithCallback(messageData, kMessageSize,
860                                                std::move(freeCallbackContext2));
861   ASSERT_NE(data.get(), nullptr);
862 
863   ASSERT_TRUE(messageHub2->sendMessage(std::move(data), kMessageSize,
864                                        /* messageType= */ 2,
865                                        /* messagePermissions= */ 3, sessionId));
866   EXPECT_EQ(messageFromCallback1.sessionId, sessionId);
867   EXPECT_EQ(messageFromCallback1.sender.messageHubId, messageHub2->getId());
868   EXPECT_EQ(messageFromCallback1.sender.endpointId, kEndpointInfos[1].id);
869   EXPECT_EQ(messageFromCallback1.recipient.messageHubId, messageHub->getId());
870   EXPECT_EQ(messageFromCallback1.recipient.endpointId, kEndpointInfos[0].id);
871   EXPECT_EQ(messageFromCallback1.messageType, 2);
872   EXPECT_EQ(messageFromCallback1.messagePermissions, 3);
873   EXPECT_EQ(messageFromCallback1.length, kMessageSize);
874   for (size_t i = 0; i < kMessageSize; ++i) {
875     EXPECT_EQ(messageFromCallback1.data[i], static_cast<std::byte>(i + 1));
876   }
877 
878   // Check if free callback was called
879   EXPECT_FALSE(freeCallbackCalled);
880   EXPECT_EQ(messageFromCallback1.data.get(), messageData);
881   messageFromCallback1.data.Reset();
882   EXPECT_TRUE(freeCallbackCalled);
883 }
884 
TEST_F(MessageRouterTest,SendMessageToSessionInvalidHubAndSession)885 TEST_F(MessageRouterTest, SendMessageToSessionInvalidHubAndSession) {
886   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
887   constexpr size_t kMessageSize = 5;
888   pw::UniquePtr<std::byte[]> messageData =
889       mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
890   for (size_t i = 0; i < 5; ++i) {
891     messageData[i] = static_cast<std::byte>(i + 1);
892   }
893 
894   Message messageFromCallback1;
895   Message messageFromCallback2;
896   Message messageFromCallback3;
897   Session sessionFromCallback1;
898   Session sessionFromCallback2;
899   Session sessionFromCallback3;
900   MessageHubCallbackStoreData callback(&messageFromCallback1,
901                                        &sessionFromCallback1);
902   MessageHubCallbackStoreData callback2(&messageFromCallback2,
903                                         &sessionFromCallback2);
904   MessageHubCallbackStoreData callback3(&messageFromCallback3,
905                                         &sessionFromCallback3);
906 
907   std::optional<MessageRouter::MessageHub> messageHub =
908       router.registerMessageHub("hub1", /* id= */ 1, callback);
909   EXPECT_TRUE(messageHub.has_value());
910   std::optional<MessageRouter::MessageHub> messageHub2 =
911       router.registerMessageHub("hub2", /* id= */ 2, callback2);
912   EXPECT_TRUE(messageHub2.has_value());
913   std::optional<MessageRouter::MessageHub> messageHub3 =
914       router.registerMessageHub("hub3", /* id= */ 3, callback3);
915   EXPECT_TRUE(messageHub3.has_value());
916 
917   // Open session from messageHub:1 to messageHub2:2
918   SessionId sessionId = messageHub->openSession(
919       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
920   EXPECT_NE(sessionId, SESSION_ID_INVALID);
921 
922   // Open session from messageHub2:2 to messageHub3:3
923   SessionId sessionId2 = messageHub2->openSession(
924       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
925   EXPECT_NE(sessionId, SESSION_ID_INVALID);
926 
927   // Open session from messageHub3:3 to messageHub1:1
928   SessionId sessionId3 = messageHub3->openSession(
929       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
930   EXPECT_NE(sessionId, SESSION_ID_INVALID);
931 
932   // Send message from messageHub:1 to messageHub2:2
933   EXPECT_FALSE(messageHub->sendMessage(std::move(messageData), kMessageSize,
934                                        /* messageType= */ 1,
935                                        /* messagePermissions= */ 0,
936                                        sessionId2));
937   EXPECT_FALSE(messageHub2->sendMessage(std::move(messageData), kMessageSize,
938                                         /* messageType= */ 2,
939                                         /* messagePermissions= */ 3,
940                                         sessionId3));
941   EXPECT_FALSE(messageHub3->sendMessage(std::move(messageData), kMessageSize,
942                                         /* messageType= */ 2,
943                                         /* messagePermissions= */ 3,
944                                         sessionId));
945 }
946 
TEST_F(MessageRouterTest,SendMessageToSessionCallbackFailureClosesSession)947 TEST_F(MessageRouterTest, SendMessageToSessionCallbackFailureClosesSession) {
948   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
949   constexpr size_t kMessageSize = 5;
950   pw::UniquePtr<std::byte[]> messageData =
951       mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
952   for (size_t i = 0; i < 5; ++i) {
953     messageData[i] = static_cast<std::byte>(i + 1);
954   }
955 
956   bool wasMessageReceivedCalled1 = false;
957   bool wasMessageReceivedCalled2 = false;
958   bool wasMessageReceivedCalled3 = false;
959   MessageHubCallbackAlwaysFails callback1(
960       &wasMessageReceivedCalled1,
961       /* wasSessionClosedCalled= */ nullptr);
962   MessageHubCallbackAlwaysFails callback2(
963       &wasMessageReceivedCalled2,
964       /* wasSessionClosedCalled= */ nullptr);
965   MessageHubCallbackAlwaysFails callback3(
966       &wasMessageReceivedCalled3,
967       /* wasSessionClosedCalled= */ nullptr);
968 
969   std::optional<MessageRouter::MessageHub> messageHub =
970       router.registerMessageHub("hub1", /* id= */ 1, callback1);
971   EXPECT_TRUE(messageHub.has_value());
972   std::optional<MessageRouter::MessageHub> messageHub2 =
973       router.registerMessageHub("hub2", /* id= */ 2, callback2);
974   EXPECT_TRUE(messageHub2.has_value());
975   std::optional<MessageRouter::MessageHub> messageHub3 =
976       router.registerMessageHub("hub3", /* id= */ 3, callback3);
977   EXPECT_TRUE(messageHub3.has_value());
978 
979   // Open session from messageHub:1 to messageHub2:2
980   SessionId sessionId = messageHub->openSession(
981       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
982   EXPECT_NE(sessionId, SESSION_ID_INVALID);
983 
984   // Open session from messageHub2:2 to messageHub3:3
985   SessionId sessionId2 = messageHub2->openSession(
986       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
987   EXPECT_NE(sessionId, SESSION_ID_INVALID);
988 
989   // Open session from messageHub3:3 to messageHub1:1
990   SessionId sessionId3 = messageHub3->openSession(
991       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
992   EXPECT_NE(sessionId, SESSION_ID_INVALID);
993 
994   // Send message from messageHub2:2 to messageHub3:3
995   EXPECT_FALSE(wasMessageReceivedCalled1);
996   EXPECT_FALSE(wasMessageReceivedCalled2);
997   EXPECT_FALSE(wasMessageReceivedCalled3);
998   EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
999   EXPECT_TRUE(messageHub2->getSessionWithId(sessionId2).has_value());
1000   EXPECT_TRUE(messageHub3->getSessionWithId(sessionId2).has_value());
1001 
1002   EXPECT_FALSE(messageHub2->sendMessage(std::move(messageData), kMessageSize,
1003                                         /* messageType= */ 1,
1004                                         /* messagePermissions= */ 0,
1005                                         sessionId2));
1006   EXPECT_FALSE(wasMessageReceivedCalled1);
1007   EXPECT_FALSE(wasMessageReceivedCalled2);
1008   EXPECT_TRUE(wasMessageReceivedCalled3);
1009   EXPECT_FALSE(messageHub->getSessionWithId(sessionId2).has_value());
1010   EXPECT_FALSE(messageHub2->getSessionWithId(sessionId2).has_value());
1011   EXPECT_FALSE(messageHub3->getSessionWithId(sessionId2).has_value());
1012 
1013   // Try to send a message on the same session - should fail
1014   wasMessageReceivedCalled1 = false;
1015   wasMessageReceivedCalled2 = false;
1016   wasMessageReceivedCalled3 = false;
1017   messageData = mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
1018   for (size_t i = 0; i < 5; ++i) {
1019     messageData[i] = static_cast<std::byte>(i + 1);
1020   }
1021   EXPECT_FALSE(messageHub2->sendMessage(std::move(messageData), kMessageSize,
1022                                         /* messageType= */ 1,
1023                                         /* messagePermissions= */ 0,
1024                                         sessionId2));
1025   messageData = mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
1026   for (size_t i = 0; i < 5; ++i) {
1027     messageData[i] = static_cast<std::byte>(i + 1);
1028   }
1029   EXPECT_FALSE(messageHub3->sendMessage(std::move(messageData), kMessageSize,
1030                                         /* messageType= */ 1,
1031                                         /* messagePermissions= */ 0,
1032                                         sessionId2));
1033   EXPECT_FALSE(wasMessageReceivedCalled1);
1034   EXPECT_FALSE(wasMessageReceivedCalled2);
1035   EXPECT_FALSE(wasMessageReceivedCalled3);
1036 }
1037 
TEST_F(MessageRouterTest,MessageHubCallbackCanCallOtherMessageHubAPIs)1038 TEST_F(MessageRouterTest, MessageHubCallbackCanCallOtherMessageHubAPIs) {
1039   MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
1040   constexpr size_t kMessageSize = 5;
1041   pw::UniquePtr<std::byte[]> messageData =
1042       mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
1043   for (size_t i = 0; i < 5; ++i) {
1044     messageData[i] = static_cast<std::byte>(i + 1);
1045   }
1046 
1047   MessageHubCallbackCallsMessageHubApisDuringCallback callback;
1048   MessageHubCallbackCallsMessageHubApisDuringCallback callback2;
1049   MessageHubCallbackCallsMessageHubApisDuringCallback callback3;
1050 
1051   std::optional<MessageRouter::MessageHub> messageHub =
1052       router.registerMessageHub("hub1", /* id= */ 1, callback);
1053   EXPECT_TRUE(messageHub.has_value());
1054   callback.setMessageHub(&messageHub.value());
1055   std::optional<MessageRouter::MessageHub> messageHub2 =
1056       router.registerMessageHub("hub2", /* id= */ 2, callback2);
1057   EXPECT_TRUE(messageHub2.has_value());
1058   callback2.setMessageHub(&messageHub2.value());
1059   std::optional<MessageRouter::MessageHub> messageHub3 =
1060       router.registerMessageHub("hub3", /* id= */ 3, callback3);
1061   EXPECT_TRUE(messageHub3.has_value());
1062   callback3.setMessageHub(&messageHub3.value());
1063 
1064   // Open session from messageHub:1 to messageHub2:2
1065   SessionId sessionId = messageHub->openSession(
1066       kEndpointInfos[0].id, messageHub2->getId(), kEndpointInfos[1].id);
1067   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1068 
1069   // Open session from messageHub2:2 to messageHub3:3
1070   SessionId sessionId2 = messageHub2->openSession(
1071       kEndpointInfos[1].id, messageHub3->getId(), kEndpointInfos[2].id);
1072   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1073 
1074   // Open session from messageHub3:3 to messageHub1:1
1075   SessionId sessionId3 = messageHub3->openSession(
1076       kEndpointInfos[2].id, messageHub->getId(), kEndpointInfos[0].id);
1077   EXPECT_NE(sessionId, SESSION_ID_INVALID);
1078 
1079   // Send message from messageHub:1 to messageHub2:2
1080   EXPECT_TRUE(messageHub->sendMessage(std::move(messageData), kMessageSize,
1081                                       /* messageType= */ 1,
1082                                       /* messagePermissions= */ 0, sessionId));
1083 
1084   // Send message from messageHub2:2 to messageHub:1
1085   messageData = mAllocator.MakeUniqueArray<std::byte>(kMessageSize);
1086   for (size_t i = 0; i < 5; ++i) {
1087     messageData[i] = static_cast<std::byte>(i + 1);
1088   }
1089   EXPECT_TRUE(messageHub2->sendMessage(std::move(messageData), kMessageSize,
1090                                        /* messageType= */ 2,
1091                                        /* messagePermissions= */ 3, sessionId));
1092 
1093   // Close all sessions
1094   EXPECT_TRUE(messageHub->closeSession(sessionId));
1095   EXPECT_TRUE(messageHub2->closeSession(sessionId2));
1096   EXPECT_TRUE(messageHub3->closeSession(sessionId3));
1097 
1098   // If we finish the test, both callbacks should have been called
1099   // If the router holds the lock during the callback, this test will timeout
1100 }
1101 
// Enumerates the endpoints of the single registered hub and checks that they
// match kEndpointInfos field by field, in order.
TEST_F(MessageRouterTest, ForEachEndpointOfHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub("hub1", /* id= */ 1, callback);
  EXPECT_TRUE(messageHub.has_value());

  DynamicVector<EndpointInfo> endpoints;
  EXPECT_TRUE(router.forEachEndpointOfHub(
      /* messageHubId= */ 1, [&endpoints](const EndpointInfo &info) {
        endpoints.push_back(info);
        return false;
      }));

  // Fatal on mismatch: the loop below indexes kEndpointInfos with the same
  // index, so a larger-than-expected result must not be walked.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1126 
// Enumerates every (hub, endpoint) pair known to the router with a single hub
// registered, and checks both the reported hub info and that the endpoints
// match kEndpointInfos field by field, in order.
TEST_F(MessageRouterTest, ForEachEndpoint) {
  const char *kHubName = "hub1";
  constexpr MessageHubId kHubId = 1;

  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData callback(/* message= */ nullptr,
                                       /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> messageHub =
      router.registerMessageHub(kHubName, kHubId, callback);
  EXPECT_TRUE(messageHub.has_value());

  DynamicVector<std::pair<MessageHubInfo, EndpointInfo>> endpoints;
  router.forEachEndpoint(
      [&endpoints](const MessageHubInfo &hubInfo, const EndpointInfo &info) {
        endpoints.push_back(std::make_pair(hubInfo, info));
      });

  // Fatal on mismatch: the loop below indexes kEndpointInfos with the same
  // index, so a larger-than-expected result must not be walked.
  ASSERT_EQ(endpoints.size(), kNumEndpoints);
  for (size_t i = 0; i < endpoints.size(); ++i) {
    EXPECT_EQ(endpoints[i].first.id, kHubId);
    EXPECT_STREQ(endpoints[i].first.name, kHubName);

    EXPECT_EQ(endpoints[i].second.id, kEndpointInfos[i].id);
    EXPECT_STREQ(endpoints[i].second.name, kEndpointInfos[i].name);
    EXPECT_EQ(endpoints[i].second.version, kEndpointInfos[i].version);
    EXPECT_EQ(endpoints[i].second.type, kEndpointInfos[i].type);
    EXPECT_EQ(endpoints[i].second.requiredPermissions,
              kEndpointInfos[i].requiredPermissions);
  }
}
1156 
// With only hub 1 registered, asking for the endpoints of hub 2 is expected
// to return false and to leave the collecting vector empty.
TEST_F(MessageRouterTest, ForEachEndpointOfHubInvalidHub) {
  MessageRouterWithStorage<kMaxMessageHubs, kMaxSessions> router;
  MessageHubCallbackStoreData storeCallback(/* message= */ nullptr,
                                            /* session= */ nullptr);
  std::optional<MessageRouter::MessageHub> registeredHub =
      router.registerMessageHub("hub1", /* id= */ 1, storeCallback);
  EXPECT_TRUE(registeredHub.has_value());

  // Collect whatever the router reports for the unregistered hub ID.
  DynamicVector<EndpointInfo> collected;
  EXPECT_FALSE(router.forEachEndpointOfHub(
      /* messageHubId= */ 2, [&collected](const EndpointInfo &endpoint) {
        collected.push_back(endpoint);
        return false;
      }));
  EXPECT_EQ(collected.size(), 0);
}
1173 
1174 }  // namespace
1175 }  // namespace chre::message
1176