xref: /aosp_15_r20/external/pigweed/pw_rpc/call_test.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc/internal/call.h"
16 
17 #include <algorithm>
18 #include <array>
19 #include <cstdint>
20 #include <cstring>
21 #include <optional>
22 
23 #include "pw_rpc/internal/test_utils.h"
24 #include "pw_rpc/service.h"
25 #include "pw_rpc_private/fake_server_reader_writer.h"
26 #include "pw_rpc_private/test_method.h"
27 #include "pw_status/status_with_size.h"
28 #include "pw_unit_test/framework.h"
29 
30 namespace pw::rpc {
31 
32 class TestService : public Service {
33  public:
TestService(uint32_t id)34   constexpr TestService(uint32_t id) : Service(id, method) {}
35 
36   static constexpr internal::TestMethodUnion method = internal::TestMethod(8);
37 };
38 
39 namespace internal {
40 namespace {
41 
42 constexpr uint32_t kChannelId = 99;
43 constexpr uint32_t kServiceId = 16;
44 constexpr uint32_t kMethodId = 8;
45 constexpr uint32_t kCallId = 327;
46 constexpr Packet kPacket(
47     pwpb::PacketType::REQUEST, kChannelId, kServiceId, kMethodId, kCallId);
48 
49 using ::pw::rpc::internal::test::FakeServerReader;
50 using ::pw::rpc::internal::test::FakeServerReaderWriter;
51 using ::pw::rpc::internal::test::FakeServerWriter;
52 using ::std::byte;
53 using ::testing::Test;
54 
55 static_assert(sizeof(Call) ==
56                   // IntrusiveList::Item pointer
57                   sizeof(IntrusiveList<Call>::Item) +
58                       // Endpoint pointer
59                       sizeof(Endpoint*) +
60                       // call_id, channel_id, service_id, method_id
61                       4 * sizeof(uint32_t) +
62                       // Packed state and properties
63                       sizeof(void*) +
64                       // on_error and on_next callbacks
65                       2 * sizeof(Function<void(Status)>),
66               "Unexpected padding in Call!");
67 
68 static_assert(sizeof(CallProperties) == sizeof(uint8_t));
69 
// CallProperties must report back, at compile time, exactly the method type,
// call type and callback proto type it was constructed with.
TEST(CallProperties, ValuesMatch) {
  constexpr CallProperties bidi_client_raw(
      MethodType::kBidirectionalStreaming, kClientCall, kRawProto);
  static_assert(bidi_client_raw.method_type() ==
                MethodType::kBidirectionalStreaming);
  static_assert(bidi_client_raw.call_type() == kClientCall);
  static_assert(bidi_client_raw.callback_proto_type() == kRawProto);

  constexpr CallProperties client_streaming_server_struct(
      MethodType::kClientStreaming, kServerCall, kProtoStruct);
  static_assert(client_streaming_server_struct.method_type() ==
                MethodType::kClientStreaming);
  static_assert(client_streaming_server_struct.call_type() == kServerCall);
  static_assert(client_streaming_server_struct.callback_proto_type() ==
                kProtoStruct);

  constexpr CallProperties unary_client_struct(
      MethodType::kUnary, kClientCall, kProtoStruct);
  static_assert(unary_client_struct.method_type() == MethodType::kUnary);
  static_assert(unary_client_struct.call_type() == kClientCall);
  static_assert(unary_client_struct.callback_proto_type() == kProtoStruct);
}
89 
90 class ServerWriterTest : public Test {
91  public:
ServerWriterTest()92   ServerWriterTest() : context_(TestService::method.method()) {
93     rpc_lock().lock();
94     FakeServerWriter writer_temp(context_.get().ClaimLocked());
95     rpc_lock().unlock();
96     writer_ = std::move(writer_temp);
97   }
98 
99   ServerContextForTest<TestService, kChannelId, kServiceId, kCallId> context_;
100   FakeServerWriter writer_;
101 };
102 
// A writer built from a claimed context is active right away.
TEST_F(ServerWriterTest, ConstructWithContext_StartsOpen) {
  const bool is_active = writer_.active();
  EXPECT_TRUE(is_active);
}
106 
// Move-constructing a writer transfers the open call: the source becomes
// inactive and the destination is active.
TEST_F(ServerWriterTest, Move_ClosesOriginal) {
  FakeServerWriter destination(std::move(writer_));

// This check deliberately reads a moved-from object. Hide it from the clang
// static analyzer, which would otherwise report a use-after-move.
#ifndef __clang_analyzer__
  EXPECT_FALSE(writer_.active());
#endif
  EXPECT_TRUE(destination.active());
}
115 
// A default-constructed writer has no call attached, so it is inactive.
TEST_F(ServerWriterTest, DefaultConstruct_Closed) {
  FakeServerWriter unattached;
  EXPECT_FALSE(unattached.active());
}
120 
// Claiming the call registers it with the server: looking it up by kPacket
// yields the fixture's writer object itself.
TEST_F(ServerWriterTest, Construct_RegistersWithServer) {
  RpcLockGuard lock;
  IntrusiveList<Call>::iterator registered =
      context_.server().FindCall(kPacket);
  ASSERT_NE(registered, context_.server().calls_end());

  void* const registered_address = static_cast<void*>(&*registered);
  void* const writer_address = static_cast<void*>(&writer_);
  EXPECT_EQ(registered_address, writer_address);
}
127 
// Destroying a writer unregisters its call from the server.
TEST_F(ServerWriterTest, Destruct_RemovesFromServer) {
  {
    // A lock guard cannot cover this scope. Constructing the writer requires
    // the RPC lock to be held, but the writer's destructor acquires the lock
    // itself.
    rpc_lock().lock();
    FakeServerWriter short_lived(context_.get().ClaimLocked());
    rpc_lock().unlock();
  }  // short_lived is destroyed here, after the lock has been released.

  RpcLockGuard lock;
  EXPECT_EQ(context_.server().FindCall(kPacket), context_.server().calls_end());
}
141 
// A successful Finish() unregisters the call from the server.
TEST_F(ServerWriterTest, Finish_RemovesFromServer) {
  const Status finish_status = writer_.Finish();
  EXPECT_EQ(OkStatus(), finish_status);

  RpcLockGuard lock;
  EXPECT_EQ(context_.server().FindCall(kPacket), context_.server().calls_end());
}
147 
// Finish() emits exactly one RESPONSE packet with an empty payload and an OK
// status, addressed to the writer's channel, service and method.
TEST_F(ServerWriterTest, Finish_SendsResponse) {
  EXPECT_EQ(OkStatus(), writer_.Finish());

  ASSERT_EQ(context_.output().total_packets(), 1u);
  const Packet& response = context_.output().last_packet();
  EXPECT_EQ(response.type(), pwpb::PacketType::RESPONSE);
  EXPECT_EQ(response.channel_id(), context_.channel_id());
  EXPECT_EQ(response.service_id(), context_.service_id());
  EXPECT_EQ(response.method_id(), context_.get().method().id());
  EXPECT_TRUE(response.payload().empty());
  EXPECT_EQ(response.status(), OkStatus());
}
160 
// When the channel fails the send, Finish() reports the failure. It does not
// return the channel's status verbatim: any non-OK status is remapped to
// UNKNOWN.
TEST_F(ServerWriterTest, Finish_ReturnsStatusFromChannelSend) {
  context_.output().set_send_status(Status::Unauthenticated());

  const Status finish_status = writer_.Finish();
  EXPECT_EQ(Status::Unknown(), finish_status);
}
167 
// Finish() closes an active writer. Calling it again on the closed writer
// fails with FAILED_PRECONDITION.
TEST_F(ServerWriterTest, Finish) {
  ASSERT_TRUE(writer_.active());

  const Status first = writer_.Finish();
  EXPECT_EQ(OkStatus(), first);
  EXPECT_FALSE(writer_.active());

  const Status second = writer_.Finish();
  EXPECT_EQ(Status::FailedPrecondition(), second);
}
174 
// Write() on an open writer sends exactly one SERVER_STREAM packet whose
// payload is the written bytes.
TEST_F(ServerWriterTest, Open_SendsPacketWithPayload) {
  constexpr byte data[] = {byte{0xf0}, byte{0x0d}};
  ASSERT_EQ(OkStatus(), writer_.Write(data));

  // Sanity check: a server-stream packet carrying `data` is encodable.
  byte encoded[64];
  auto result = context_.server_stream(data).Encode(encoded);
  ASSERT_EQ(OkStatus(), result.status());

  // The fixture sends nothing before Write(), so exactly one packet must be
  // present. Checking this first keeps last_packet() from being read when
  // nothing was sent.
  ASSERT_EQ(context_.output().total_packets(), 1u);
  const Packet& packet = context_.output().last_packet();
  EXPECT_EQ(packet.type(), pwpb::PacketType::SERVER_STREAM);

  // The size check is fatal so memcmp never reads past a short payload.
  ConstByteSpan payload = packet.payload();
  ASSERT_EQ(sizeof(data), payload.size());
  EXPECT_EQ(0, std::memcmp(data, payload.data(), sizeof(data)));
}
187 
// Write() with a callback sends exactly one SERVER_STREAM packet whose payload
// is whatever the callback wrote into the provided buffer.
TEST_F(ServerWriterTest, Open_WriteCallback_SendsPacketWithPayload) {
  constexpr byte data[] = {byte{0xaa}, byte{0xbb}, byte{0xcc}, byte{0xdd}};

  ASSERT_EQ(OkStatus(), writer_.Write([&data](ByteSpan buffer) {
    // Never write past the end of the buffer the writer hands us.
    if (buffer.size() < sizeof(data)) {
      return StatusWithSize::ResourceExhausted();
    }
    std::memcpy(buffer.data(), data, sizeof(data));
    return StatusWithSize(sizeof(data));
  }));

  // Sanity check: a server-stream packet carrying `data` is encodable.
  byte encoded[64];
  auto result = context_.server_stream(data).Encode(encoded);
  ASSERT_EQ(OkStatus(), result.status());

  // The fixture sends nothing before Write(), so exactly one packet must be
  // present. Checking this first keeps last_packet() from being read when
  // nothing was sent.
  ASSERT_EQ(context_.output().total_packets(), 1u);
  const Packet& packet = context_.output().last_packet();
  EXPECT_EQ(packet.type(), pwpb::PacketType::SERVER_STREAM);

  // The size check is fatal so memcmp never reads past a short payload.
  ConstByteSpan payload = packet.payload();
  ASSERT_EQ(sizeof(data), payload.size());
  EXPECT_EQ(0, std::memcmp(data, payload.data(), sizeof(data)));
}
204 
// A non-OK status returned by the write callback is surfaced by Write().
TEST_F(ServerWriterTest, Open_WriteCallback_ErrorPropagates) {
  auto fail_with_data_loss = [](ByteSpan) {
    return StatusWithSize::DataLoss();
  };
  ASSERT_EQ(Status::DataLoss(), writer_.Write(fail_with_data_loss));
}
209 
// Passing a null write callback is rejected with INVALID_ARGUMENT.
TEST_F(ServerWriterTest, Open_WriteCallback_NullptrReturnsInvalidArgument) {
  const Status write_status = writer_.Write(nullptr);
  ASSERT_EQ(Status::InvalidArgument(), write_status);
}
213 
// Once a writer has finished, further Finish() calls are refused.
TEST_F(ServerWriterTest, Closed_IgnoresFinish) {
  const Status first_finish = writer_.Finish();
  EXPECT_EQ(OkStatus(), first_finish);

  const Status repeated_finish = writer_.Finish();
  EXPECT_EQ(Status::FailedPrecondition(), repeated_finish);
}
218 
// A default-constructed writer reports neither a client stream nor a pending
// client completion request.
TEST_F(ServerWriterTest, DefaultConstructor_NoClientStream) {
  FakeServerWriter unattached;
  // The guard is declared after the writer, so the lock is released before
  // the writer's destructor (which takes the lock itself) runs.
  RpcLockGuard lock;
  EXPECT_FALSE(unattached.as_server_call().has_client_stream());
  EXPECT_FALSE(unattached.as_server_call().client_requested_completion());
}
225 
// An open server writer streams only from the server: there is no client
// stream and no completion request yet.
TEST_F(ServerWriterTest, Open_NoClientStream) {
  RpcLockGuard lock;
  auto& server_call = writer_.as_server_call();
  EXPECT_FALSE(server_call.has_client_stream());
  EXPECT_TRUE(server_call.has_server_stream());
  EXPECT_FALSE(server_call.client_requested_completion());
}
232 
233 class ServerReaderTest : public Test {
234  public:
ServerReaderTest()235   ServerReaderTest() : context_(TestService::method.method()) {
236     rpc_lock().lock();
237     FakeServerReader reader_temp(context_.get().ClaimLocked());
238     rpc_lock().unlock();
239     reader_ = std::move(reader_temp);
240   }
241 
242   ServerContextForTest<TestService> context_;
243   FakeServerReader reader_;
244 };
245 
// A default-constructed reader is inactive and has no completion request.
TEST_F(ServerReaderTest, DefaultConstructor_StreamClosed) {
  FakeServerReader unattached;
  EXPECT_FALSE(unattached.as_server_call().active());

  // client_requested_completion() is queried with the RPC lock held.
  RpcLockGuard lock;
  EXPECT_FALSE(unattached.as_server_call().client_requested_completion());
}
252 
// An open server reader has a client stream, and the client has not yet asked
// to complete it.
TEST_F(ServerReaderTest, Open_ClientStreamStartsOpen) {
  RpcLockGuard lock;
  auto& server_call = reader_.as_server_call();
  EXPECT_TRUE(server_call.has_client_stream());
  EXPECT_FALSE(server_call.client_requested_completion());
}
258 
// CloseAndSendResponse() deactivates the call and marks the client side as
// having requested completion.
TEST_F(ServerReaderTest, Close_ClosesStream) {
  EXPECT_TRUE(reader_.as_server_call().active());
  {
    RpcLockGuard lock;
    EXPECT_FALSE(reader_.as_server_call().client_requested_completion());
  }

  const Status close_status =
      reader_.as_server_call().CloseAndSendResponse(OkStatus());
  EXPECT_EQ(OkStatus(), close_status);

  EXPECT_FALSE(reader_.as_server_call().active());
  RpcLockGuard lock;
  EXPECT_TRUE(reader_.as_server_call().client_requested_completion());
}
271 
// A client completion request is recorded on the call, but the reader itself
// stays active.
TEST_F(ServerReaderTest, RequestCompletion_OnlyMakesClientNotReady) {
  EXPECT_TRUE(reader_.active());

  rpc_lock().lock();
  EXPECT_FALSE(reader_.as_server_call().client_requested_completion());
  // There is deliberately no unlock() after this call. The lock is acquired
  // again below, so HandleClientRequestedCompletion() presumably releases it
  // before returning.
  reader_.as_server_call().HandleClientRequestedCompletion();

  EXPECT_TRUE(reader_.active());
  RpcLockGuard lock;
  EXPECT_TRUE(reader_.as_server_call().client_requested_completion());
}
282 
283 class ServerReaderWriterTest : public Test {
284  public:
ServerReaderWriterTest()285   ServerReaderWriterTest() : context_(TestService::method.method()) {
286     rpc_lock().lock();
287     FakeServerReaderWriter reader_writer_temp(context_.get().ClaimLocked());
288     rpc_lock().unlock();
289     reader_writer_ = std::move(reader_writer_temp);
290   }
291 
292   ServerContextForTest<TestService> context_;
293   FakeServerReaderWriter reader_writer_;
294 };
295 
// Move-assigning carries the client stream over to the destination without
// creating a completion request.
TEST_F(ServerReaderWriterTest, Move_MaintainsClientStream) {
  FakeServerReaderWriter destination;
  {
    RpcLockGuard lock;
    EXPECT_FALSE(destination.as_server_call().client_requested_completion());
  }

  destination = std::move(reader_writer_);

  RpcLockGuard lock;
  EXPECT_TRUE(destination.as_server_call().has_client_stream());
  EXPECT_FALSE(destination.as_server_call().client_requested_completion());
}
308 
// Callbacks registered on a reader/writer move along with it. After the move,
// events delivered to the destination invoke the original callbacks.
TEST_F(ServerReaderWriterTest, Move_MovesCallbacks) {
  int invocations = 0;
  reader_writer_.set_on_error([&invocations](Status) { ++invocations; });
  reader_writer_.set_on_next(
      [&invocations](ConstByteSpan) { ++invocations; });
  reader_writer_.set_on_completion_requested_if_enabled(
      [&invocations]() { ++invocations; });

  FakeServerReaderWriter destination(std::move(reader_writer_));

  // Each Handle*() call is entered with the RPC lock held. The lock is
  // acquired again before the next call with no unlock in between, so each
  // handler presumably releases it before returning.
  rpc_lock().lock();
  destination.as_server_call().HandlePayload({});
  rpc_lock().lock();
  destination.as_server_call().HandleClientRequestedCompletion();
  rpc_lock().lock();
  destination.as_server_call().HandleError(Status::Unknown());

  // on_next and on_error always fire. The completion-requested callback is
  // counted only when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled
  // (presumably 1 when enabled and 0 otherwise).
  EXPECT_EQ(invocations, 2 + PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
326 
// Moving out of a reader/writer resets the source's call ID and channel ID to
// zero.
TEST_F(ServerReaderWriterTest, Move_ClearsCallAndChannelId) {
  {
    RpcLockGuard lock;
    reader_writer_.set_id(999);
    EXPECT_NE(reader_writer_.channel_id_locked(), 0u);
  }

  FakeServerReaderWriter moved_to(std::move(reader_writer_));

  RpcLockGuard lock;
  EXPECT_EQ(reader_writer_.id(), 0u);
  EXPECT_EQ(reader_writer_.channel_id_locked(), 0u);
}
339 
// Assigning an empty (default) value detaches the reader/writer, zeroing its
// service and method IDs.
TEST_F(ServerReaderWriterTest, DefaultConstructorAssign_Reset) {
  reader_writer_ = {};

  // The ID accessors are read under the RPC lock.
  RpcLockGuard lock;
  EXPECT_EQ(reader_writer_.service_id(), 0u);
  EXPECT_EQ(reader_writer_.method_id(), 0u);
}
347 
// If the source of a move was closed and marked for cleanup, the move performs
// that cleanup and delivers the pending error to the source's on_error.
TEST_F(ServerReaderWriterTest, Move_SourceAwaitingCleanup_CleansUpCalls) {
  std::optional<Status> reported_error;
  reader_writer_.set_on_error([&reported_error](Status error) {
    // on_error must fire at most once.
    ASSERT_FALSE(reported_error.has_value());
    reported_error = error;
  });

  {
    RpcLockGuard lock;
    context_.server().CloseCallAndMarkForCleanup(
        reader_writer_.as_server_call(), Status::NotFound());
  }

  FakeServerReaderWriter destination(std::move(reader_writer_));

  EXPECT_EQ(Status::NotFound(), reported_error);
}
364 
// When both sides of a move-assignment are awaiting cleanup, the move cleans
// up both calls, and each reports its own pending error.
TEST_F(ServerReaderWriterTest, Move_BothAwaitingCleanup_CleansUpCalls) {
  // Claim a second call with ID 123, distinct from the fixture's call. It must
  // outlive the lock, so an explicit lock()/unlock() pair is used instead of a
  // scoped guard.
  rpc_lock().lock();
  FakeServerReaderWriter destination(context_.get(123).ClaimLocked());
  rpc_lock().unlock();

  std::optional<Status> destination_error;
  destination.set_on_error([&destination_error](Status error) {
    ASSERT_FALSE(destination_error.has_value());
    destination_error = error;
  });

  std::optional<Status> source_error;
  reader_writer_.set_on_error([&source_error](Status error) {
    ASSERT_FALSE(source_error.has_value());
    source_error = error;
  });

  // Simulate another thread closing both calls.
  {
    RpcLockGuard lock;
    context_.server().CloseCallAndMarkForCleanup(destination.as_server_call(),
                                                 Status::NotFound());
    context_.server().CloseCallAndMarkForCleanup(
        reader_writer_.as_server_call(), Status::Unauthenticated());
  }

  destination = std::move(reader_writer_);

  EXPECT_EQ(Status::NotFound(), destination_error);
  EXPECT_EQ(Status::Unauthenticated(), source_error);
}
396 
// Finishing a reader/writer resets its call ID and channel ID to zero.
TEST_F(ServerReaderWriterTest, Close_ClearsCallAndChannelId) {
  {
    RpcLockGuard lock;
    reader_writer_.set_id(999);
    EXPECT_NE(reader_writer_.channel_id_locked(), 0u);
  }

  const Status finish_status = reader_writer_.Finish();
  EXPECT_EQ(OkStatus(), finish_status);

  RpcLockGuard lock;
  EXPECT_EQ(reader_writer_.id(), 0u);
  EXPECT_EQ(reader_writer_.channel_id_locked(), 0u);
}
409 
410 }  // namespace
411 }  // namespace internal
412 }  // namespace pw::rpc
413