xref: /aosp_15_r20/external/pigweed/pw_rpc/raw/client_test.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc/client.h"
16 
17 #include <optional>
18 
19 #include "pw_rpc/internal/client_call.h"
20 #include "pw_rpc/internal/packet.h"
21 #include "pw_rpc/raw/client_reader_writer.h"
22 #include "pw_rpc/raw/client_testing.h"
23 #include "pw_unit_test/framework.h"
24 
25 namespace pw::rpc {
26 
// Placeholder functions that are only declared, never defined in this file.
// They serve as the non-type template arguments that identify the test methods
// in the internal::MethodInfo specializations below.
void UnaryMethod();
void BidirectionalStreamMethod();
29 
30 template <>
31 struct internal::MethodInfo<UnaryMethod> {
32   static constexpr uint32_t kServiceId = 100;
33   static constexpr uint32_t kMethodId = 200;
34   static constexpr MethodType kType = MethodType::kUnary;
35 };
36 
37 template <>
38 struct internal::MethodInfo<BidirectionalStreamMethod> {
39   static constexpr uint32_t kServiceId = 100;
40   static constexpr uint32_t kMethodId = 300;
41   static constexpr MethodType kType = MethodType::kBidirectionalStreaming;
42 };
43 
44 namespace {
45 
46 // Captures payload from on_next and statuses from on_error and on_completed.
47 // Payloads are assumed to be null-terminated strings.
48 template <typename CallType>
49 struct CallContext {
OnNextpw::rpc::__anon783869c00111::CallContext50   auto OnNext() {
51     return [this](ConstByteSpan string) {
52       payload = reinterpret_cast<const char*>(string.data());
53     };
54   }
55 
UnaryOnCompletedpw::rpc::__anon783869c00111::CallContext56   auto UnaryOnCompleted() {
57     return [this](ConstByteSpan string, Status status) {
58       payload = reinterpret_cast<const char*>(string.data());
59       completed = status;
60     };
61   }
62 
StreamOnCompletedpw::rpc::__anon783869c00111::CallContext63   auto StreamOnCompleted() {
64     return [this](Status status) { completed = status; };
65   }
66 
OnErrorpw::rpc::__anon783869c00111::CallContext67   auto OnError() {
68     return [this](Status status) { error = status; };
69   }
70 
71   CallType call;
72 
73   const char* payload;
74   std::optional<Status> completed;
75   std::optional<Status> error;
76 };
77 
78 template <auto kMethod, typename Context>
StartUnaryCall(Context & context,std::optional<uint32_t> channel_id=std::nullopt)79 CallContext<RawUnaryReceiver> StartUnaryCall(
80     Context& context, std::optional<uint32_t> channel_id = std::nullopt)
81     PW_LOCKS_EXCLUDED(internal::rpc_lock()) {
82   CallContext<RawUnaryReceiver> call_context;
83   call_context.call =
84       internal::UnaryResponseClientCall::Start<RawUnaryReceiver>(
85           context.client(),
86           channel_id.value_or(context.channel().id()),
87           internal::MethodInfo<kMethod>::kServiceId,
88           internal::MethodInfo<kMethod>::kMethodId,
89           call_context.UnaryOnCompleted(),
90           call_context.OnError(),
91           {});
92   return call_context;
93 }
94 
95 template <auto kMethod, typename Context>
StartStreamCall(Context & context,std::optional<uint32_t> channel_id=std::nullopt)96 CallContext<RawClientReaderWriter> StartStreamCall(
97     Context& context, std::optional<uint32_t> channel_id = std::nullopt)
98     PW_LOCKS_EXCLUDED(internal::rpc_lock()) {
99   CallContext<RawClientReaderWriter> call_context;
100   call_context.call =
101       internal::StreamResponseClientCall::Start<RawClientReaderWriter>(
102           context.client(),
103           channel_id.value_or(context.channel().id()),
104           internal::MethodInfo<kMethod>::kServiceId,
105           internal::MethodInfo<kMethod>::kMethodId,
106           call_context.OnNext(),
107           call_context.StreamOnCompleted(),
108           call_context.OnError(),
109           {});
110   return call_context;
111 }
112 
// A response to an active unary call must reach on_completed with the payload
// and status, and must leave the call inactive.
TEST(Client, ProcessPacket_InvokesUnaryCallbacks) {
  constexpr char kPayload[] = "you nary?!?";

  RawClientTestContext context;
  CallContext unary = StartUnaryCall<UnaryMethod>(context);

  // No response has been sent yet, so no OK completion may be recorded.
  ASSERT_NE(unary.completed, OkStatus());

  // The span covers the whole array, including the null terminator.
  context.server().SendResponse<UnaryMethod>(as_bytes(span(kPayload)),
                                             OkStatus());

  ASSERT_NE(unary.payload, nullptr);
  EXPECT_STREQ(unary.payload, kPayload);
  EXPECT_EQ(unary.completed, OkStatus());
  EXPECT_FALSE(unary.call.active());
}
127 
// A response must still finish the call when on_completed has been cleared.
TEST(Client, ProcessPacket_NoCallbackSet) {
  constexpr char kPayload[] = "you nary?!?";

  RawClientTestContext context;
  CallContext unary = StartUnaryCall<UnaryMethod>(context);
  unary.call.set_on_completed(nullptr);  // Remove the completion callback.

  ASSERT_NE(unary.completed, OkStatus());

  context.server().SendResponse<UnaryMethod>(as_bytes(span(kPayload)),
                                             OkStatus());
  EXPECT_FALSE(unary.call.active());
}
139 
// A server stream packet must reach on_next, and the final response must reach
// on_completed with its status.
TEST(Client, ProcessPacket_InvokesStreamCallbacks) {
  constexpr char kStreamPayload[] = "<=>";

  RawClientTestContext context;
  auto stream = StartStreamCall<BidirectionalStreamMethod>(context);

  context.server().SendServerStream<BidirectionalStreamMethod>(
      as_bytes(span(kStreamPayload)));

  ASSERT_NE(stream.payload, nullptr);
  EXPECT_STREQ(stream.payload, kStreamPayload);

  context.server().SendResponse<BidirectionalStreamMethod>(Status::NotFound());

  EXPECT_EQ(stream.completed, Status::NotFound());
}
154 
// A packet carrying Channel::kUnassignedChannelId must be rejected with
// DATA_LOSS, even though a call for the same service and method is active.
TEST(Client, ProcessPacket_UnassignedChannelId_ReturnsDataLoss) {
  using StreamInfo = internal::MethodInfo<BidirectionalStreamMethod>;

  RawClientTestContext context;
  // Held only so that a call exists while the packet is processed.
  auto active_call = StartStreamCall<BidirectionalStreamMethod>(context);

  uint32_t call_id = 24602;  // Arbitrary value.
  internal::Packet packet(internal::pwpb::PacketType::kResponse,
                          Channel::kUnassignedChannelId,
                          StreamInfo::kServiceId,
                          StreamInfo::kMethodId,
                          call_id);

  std::byte buffer[64];
  Result<span<const std::byte>> encoded = packet.Encode(buffer);
  ASSERT_TRUE(encoded.ok());

  EXPECT_EQ(context.client().ProcessPacket(*encoded), Status::DataLoss());
}
173 
// A server error for an active call must reach on_error with the sent status.
TEST(Client, ProcessPacket_InvokesErrorCallback) {
  RawClientTestContext context;
  auto stream = StartStreamCall<BidirectionalStreamMethod>(context);

  const Status sent_error = Status::Aborted();
  context.server().SendServerError<BidirectionalStreamMethod>(sent_error);

  EXPECT_EQ(stream.error, sent_error);
}
183 
// A server stream packet for a call that was never started must make the
// client send exactly one FAILED_PRECONDITION error.
TEST(Client, ProcessPacket_SendsClientErrorOnUnregisteredServerStream) {
  RawClientTestContext context;

  // No call is started first; the stream packet arrives unannounced.
  context.server().SendServerStream<BidirectionalStreamMethod>({});

  StatusView sent_errors = context.output().errors<BidirectionalStreamMethod>();
  ASSERT_EQ(sent_errors.size(), 1u);
  EXPECT_EQ(sent_errors.front(), Status::FailedPrecondition());
}
192 
// Server errors and responses for a call that was never started must not make
// the client send any packet.
TEST(Client, ProcessPacket_NonServerStreamOnUnregisteredCall_SendsNothing) {
  RawClientTestContext context;

  // Case 1: a server error packet.
  context.server().SendServerError<UnaryMethod>(Status::NotFound());
  EXPECT_EQ(context.output().total_packets(), 0u);

  // Case 2: a response packet.
  context.server().SendResponse<UnaryMethod>({}, Status::Unavailable());
  EXPECT_EQ(context.output().total_packets(), 0u);
}
201 
// The three-byte buffer 0xab 0xcd 0xef must be rejected with DATA_LOSS.
TEST(Client, ProcessPacket_ReturnsDataLossOnBadPacket) {
  constexpr std::byte kGarbage[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};

  RawClientTestContext context;
  EXPECT_EQ(context.client().ProcessPacket(kGarbage), Status::DataLoss());
}
209 
// A REQUEST packet handed to the client must be rejected with
// INVALID_ARGUMENT.
TEST(Client, ProcessPacket_ReturnsInvalidArgumentOnServerPacket) {
  RawClientTestContext context;

  // Uses the k-style enumerator, matching the PacketType::kResponse usage
  // elsewhere in this file. The numeric arguments are arbitrary: channel 1,
  // service 2, method 3, call 4.
  std::byte encoded[64];
  Result<span<const std::byte>> result =
      internal::Packet(internal::pwpb::PacketType::kRequest, 1, 2, 3, 4)
          .Encode(encoded);
  ASSERT_TRUE(result.ok());

  EXPECT_EQ(context.client().ProcessPacket(*result), Status::InvalidArgument());
}
221 
GetChannel(internal::Endpoint & endpoint,uint32_t id)222 const internal::ChannelBase* GetChannel(internal::Endpoint& endpoint,
223                                         uint32_t id) {
224   internal::RpcLockGuard lock;
225   return endpoint.GetInternalChannel(id);
226 }
227 
// Closing the default channel with no active calls must succeed, remove the
// channel, and send no packets.
TEST(Client, CloseChannel_NoCalls) {
  RawClientTestContext ctx;
  const auto channel_id = ctx.kDefaultChannelId;

  // Precondition: the channel exists before it is closed.
  ASSERT_NE(nullptr, GetChannel(ctx.client(), channel_id));

  EXPECT_EQ(OkStatus(), ctx.client().CloseChannel(channel_id));

  EXPECT_EQ(nullptr, GetChannel(ctx.client(), channel_id));
  EXPECT_EQ(ctx.output().total_packets(), 0u);
}
235 
// Closing a channel ID the client does not have must return NOT_FOUND.
TEST(Client, CloseChannel_UnknownChannel) {
  constexpr uint32_t kUnknownChannelId = 13579;

  RawClientTestContext ctx;
  ASSERT_EQ(nullptr, GetChannel(ctx.client(), kUnknownChannelId));
  EXPECT_EQ(Status::NotFound(), ctx.client().CloseChannel(kUnknownChannelId));
}
241 
// Closing a channel with an active call must end the call and report ABORTED
// through its on_error callback.
TEST(Client, CloseChannel_CallsErrorCallback) {
  RawClientTestContext ctx;
  CallContext unary = StartUnaryCall<UnaryMethod>(ctx);

  // active_call_count() is read through the Endpoint view of the client.
  internal::Endpoint& endpoint = static_cast<internal::Endpoint&>(ctx.client());

  ASSERT_NE(unary.completed, OkStatus());
  ASSERT_EQ(1u, endpoint.active_call_count());

  EXPECT_EQ(OkStatus(), ctx.client().CloseChannel(1));

  EXPECT_EQ(0u, endpoint.active_call_count());
  ASSERT_EQ(unary.error, Status::Aborted());  // set by the on_error callback
}
256 
// The on_error callback triggered by CloseChannel reuses the same call object
// to start a new call on the channel that is being closed. The test expects
// that new call to be inactive afterwards and no active calls to remain.
TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForCallOnClosedChannel) {
  // Groups the test context and the call so one reference capture reaches
  // both. Member order is kept: the call context is declared after the test
  // context it uses.
  struct Fixture {
    RawClientTestContext<> ctx;
    CallContext<RawUnaryReceiver> call_ctx;
  };
  Fixture fixture;

  fixture.call_ctx = StartUnaryCall<UnaryMethod>(fixture.ctx);

  // Replace on_error with a callback that restarts the call on channel 1 (the
  // channel closed below) and then records the error it was given.
  auto restart_on_error = [&fixture](Status error) {
    fixture.call_ctx = StartUnaryCall<UnaryMethod>(fixture.ctx, 1);
    fixture.call_ctx.error = error;
  };
  fixture.call_ctx.call.set_on_error(restart_on_error);

  EXPECT_EQ(OkStatus(), fixture.ctx.client().CloseChannel(1));
  EXPECT_EQ(fixture.call_ctx.error, Status::Aborted());

  EXPECT_FALSE(fixture.call_ctx.call.active());
  EXPECT_EQ(0u,
            static_cast<internal::Endpoint&>(fixture.ctx.client())
                .active_call_count());
}
277 
// The on_error callback triggered by closing channel 1 reuses the same call
// object to start a new call on channel 2. The test expects that new call to
// stay active, leaving exactly one active call.
TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForActiveCall) {
  // Minimal test context with two channels (IDs 1 and 2) sharing one output.
  // It provides the client() / channel() accessors StartUnaryCall requires.
  class TwoChannelFixture {
   public:
    TwoChannelFixture()
        : channels_{Channel::Create<1>(&channel_output_),
                    Channel::Create<2>(&channel_output_)},
          client_(channels_),
          packet_buffer_{},
          fake_server_(channel_output_, client_, 1, packet_buffer_) {}

    // The first channel (created with ID 1).
    Channel& channel() { return channels_[0]; }
    Client& client() { return client_; }
    CallContext<RawUnaryReceiver>& call_ctx() { return call_context_; }
    RawUnaryReceiver& call() { return call_context_.call; }

    // Starts a call on `channel_id`, replacing the stored call context.
    void StartCall(uint32_t channel_id) {
      call_context_ = StartUnaryCall<UnaryMethod>(*this, channel_id);
    }

   private:
    // Declaration order is significant: members are initialized in this order
    // and later members are constructed from earlier ones.
    RawFakeChannelOutput<10, 256> channel_output_;
    Channel channels_[2];
    Client client_;
    std::byte packet_buffer_[64];
    FakeServer fake_server_;

    CallContext<RawUnaryReceiver> call_context_;
  };
  TwoChannelFixture fixture;

  fixture.StartCall(1);

  // On error, restart the call on channel 2 and record the error received.
  auto restart_on_other_channel = [&fixture](Status error) {
    fixture.StartCall(2);
    fixture.call_ctx().error = error;
  };
  fixture.call().set_on_error(restart_on_other_channel);

  EXPECT_EQ(OkStatus(), fixture.client().CloseChannel(1));
  EXPECT_EQ(fixture.call_ctx().error, Status::Aborted());

  EXPECT_TRUE(fixture.call().active());
  EXPECT_EQ(
      1u,
      static_cast<internal::Endpoint&>(fixture.client()).active_call_count());
}
321 
// After channel 1 is closed, opening a channel with a new ID (9) must succeed
// and make that channel available.
TEST(Client, OpenChannel_UnusedSlot) {
  constexpr uint32_t kNewChannelId = 9;

  RawClientTestContext ctx;
  ASSERT_EQ(OkStatus(), ctx.client().CloseChannel(1));
  ASSERT_EQ(nullptr, GetChannel(ctx.client(), kNewChannelId));

  EXPECT_EQ(OkStatus(), ctx.client().OpenChannel(kNewChannelId, ctx.output()));

  EXPECT_NE(nullptr, GetChannel(ctx.client(), kNewChannelId));
}
331 
// Opening a channel whose ID is already present must return ALREADY_EXISTS.
TEST(Client, OpenChannel_AlreadyExists) {
  constexpr uint32_t kExistingChannelId = 1;

  RawClientTestContext ctx;
  ASSERT_NE(nullptr, GetChannel(ctx.client(), kExistingChannelId));
  EXPECT_EQ(Status::AlreadyExists(),
            ctx.client().OpenChannel(kExistingChannelId, ctx.output()));
}
337 
// Opening an extra channel without closing one first is expected to succeed
// only when PW_RPC_DYNAMIC_ALLOCATION is non-zero; otherwise
// RESOURCE_EXHAUSTED is expected.
TEST(Client, OpenChannel_AdditionalSlot) {
  RawClientTestContext ctx;

  constexpr Status kExpected =
      PW_RPC_DYNAMIC_ALLOCATION != 0 ? OkStatus() : Status::ResourceExhausted();
  EXPECT_EQ(kExpected, ctx.client().OpenChannel(19823, ctx.output()));
}
345 
346 }  // namespace
347 }  // namespace pw::rpc
348