1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_rpc/client.h"
16
17 #include <optional>
18
19 #include "pw_rpc/internal/client_call.h"
20 #include "pw_rpc/internal/packet.h"
21 #include "pw_rpc/raw/client_reader_writer.h"
22 #include "pw_rpc/raw/client_testing.h"
23 #include "pw_unit_test/framework.h"
24
25 namespace pw::rpc {
26
27 void UnaryMethod();
28 void BidirectionalStreamMethod();
29
30 template <>
31 struct internal::MethodInfo<UnaryMethod> {
32 static constexpr uint32_t kServiceId = 100;
33 static constexpr uint32_t kMethodId = 200;
34 static constexpr MethodType kType = MethodType::kUnary;
35 };
36
37 template <>
38 struct internal::MethodInfo<BidirectionalStreamMethod> {
39 static constexpr uint32_t kServiceId = 100;
40 static constexpr uint32_t kMethodId = 300;
41 static constexpr MethodType kType = MethodType::kBidirectionalStreaming;
42 };
43
44 namespace {
45
46 // Captures payload from on_next and statuses from on_error and on_completed.
47 // Payloads are assumed to be null-terminated strings.
48 template <typename CallType>
49 struct CallContext {
OnNextpw::rpc::__anon783869c00111::CallContext50 auto OnNext() {
51 return [this](ConstByteSpan string) {
52 payload = reinterpret_cast<const char*>(string.data());
53 };
54 }
55
UnaryOnCompletedpw::rpc::__anon783869c00111::CallContext56 auto UnaryOnCompleted() {
57 return [this](ConstByteSpan string, Status status) {
58 payload = reinterpret_cast<const char*>(string.data());
59 completed = status;
60 };
61 }
62
StreamOnCompletedpw::rpc::__anon783869c00111::CallContext63 auto StreamOnCompleted() {
64 return [this](Status status) { completed = status; };
65 }
66
OnErrorpw::rpc::__anon783869c00111::CallContext67 auto OnError() {
68 return [this](Status status) { error = status; };
69 }
70
71 CallType call;
72
73 const char* payload;
74 std::optional<Status> completed;
75 std::optional<Status> error;
76 };
77
78 template <auto kMethod, typename Context>
StartUnaryCall(Context & context,std::optional<uint32_t> channel_id=std::nullopt)79 CallContext<RawUnaryReceiver> StartUnaryCall(
80 Context& context, std::optional<uint32_t> channel_id = std::nullopt)
81 PW_LOCKS_EXCLUDED(internal::rpc_lock()) {
82 CallContext<RawUnaryReceiver> call_context;
83 call_context.call =
84 internal::UnaryResponseClientCall::Start<RawUnaryReceiver>(
85 context.client(),
86 channel_id.value_or(context.channel().id()),
87 internal::MethodInfo<kMethod>::kServiceId,
88 internal::MethodInfo<kMethod>::kMethodId,
89 call_context.UnaryOnCompleted(),
90 call_context.OnError(),
91 {});
92 return call_context;
93 }
94
95 template <auto kMethod, typename Context>
StartStreamCall(Context & context,std::optional<uint32_t> channel_id=std::nullopt)96 CallContext<RawClientReaderWriter> StartStreamCall(
97 Context& context, std::optional<uint32_t> channel_id = std::nullopt)
98 PW_LOCKS_EXCLUDED(internal::rpc_lock()) {
99 CallContext<RawClientReaderWriter> call_context;
100 call_context.call =
101 internal::StreamResponseClientCall::Start<RawClientReaderWriter>(
102 context.client(),
103 channel_id.value_or(context.channel().id()),
104 internal::MethodInfo<kMethod>::kServiceId,
105 internal::MethodInfo<kMethod>::kMethodId,
106 call_context.OnNext(),
107 call_context.StreamOnCompleted(),
108 call_context.OnError(),
109 {});
110 return call_context;
111 }
112
// Verifies that a unary response delivers its payload and status to
// on_completed, and that the call is no longer active afterwards.
TEST(Client, ProcessPacket_InvokesUnaryCallbacks) {
  RawClientTestContext context;
  CallContext call_context = StartUnaryCall<UnaryMethod>(context);

  // Nothing may have completed before the server responds. has_value() also
  // catches an early completion with a non-OK status, which a comparison
  // against OkStatus() would miss.
  ASSERT_FALSE(call_context.completed.has_value());

  context.server().SendResponse<UnaryMethod>(as_bytes(span("you nary?!?")),
                                             OkStatus());

  ASSERT_NE(call_context.payload, nullptr);
  EXPECT_STREQ(call_context.payload, "you nary?!?");
  EXPECT_EQ(call_context.completed, OkStatus());
  EXPECT_FALSE(call_context.call.active());
}
127
// Verifies that a unary response is handled cleanly after the call's
// on_completed callback has been cleared: the call finishes, and the cleared
// callback is not invoked.
TEST(Client, ProcessPacket_NoCallbackSet) {
  RawClientTestContext context;
  CallContext call_context = StartUnaryCall<UnaryMethod>(context);
  call_context.call.set_on_completed(nullptr);

  // Nothing may have completed before the server responds.
  ASSERT_FALSE(call_context.completed.has_value());

  context.server().SendResponse<UnaryMethod>(as_bytes(span("you nary?!?")),
                                             OkStatus());
  EXPECT_FALSE(call_context.call.active());

  // The recording callback was replaced, so no completion status was stored.
  EXPECT_FALSE(call_context.completed.has_value());
}
139
// Verifies that a server-stream payload reaches on_next, and that the final
// response status reaches on_completed, for a bidirectional streaming call.
TEST(Client, ProcessPacket_InvokesStreamCallbacks) {
  RawClientTestContext test_context;
  auto stream = StartStreamCall<BidirectionalStreamMethod>(test_context);

  // First a streamed message...
  test_context.server().SendServerStream<BidirectionalStreamMethod>(
      as_bytes(span("<=>")));
  ASSERT_NE(stream.payload, nullptr);
  EXPECT_STREQ(stream.payload, "<=>");

  // ...then the terminating response.
  test_context.server().SendResponse<BidirectionalStreamMethod>(
      Status::NotFound());
  EXPECT_EQ(stream.completed, Status::NotFound());
}
154
// Verifies that a response packet addressed to the unassigned channel ID is
// rejected with DATA_LOSS.
TEST(Client, ProcessPacket_UnassignedChannelId_ReturnsDataLoss) {
  RawClientTestContext context;
  // Keep a call open for the method that the packet below targets.
  auto call_ctx = StartStreamCall<BidirectionalStreamMethod>(context);

  using Info = internal::MethodInfo<BidirectionalStreamMethod>;
  constexpr uint32_t kArbitraryCallId = 24602;

  internal::Packet packet(internal::pwpb::PacketType::kResponse,
                          Channel::kUnassignedChannelId,
                          Info::kServiceId,
                          Info::kMethodId,
                          kArbitraryCallId);
  std::byte encoded[64];
  Result<span<const std::byte>> result = packet.Encode(encoded);
  ASSERT_TRUE(result.ok());

  EXPECT_EQ(context.client().ProcessPacket(*result), Status::DataLoss());
}
173
// Verifies that a server error sent for an open call is delivered to that
// call's on_error callback.
TEST(Client, ProcessPacket_InvokesErrorCallback) {
  RawClientTestContext test_context;
  auto stream = StartStreamCall<BidirectionalStreamMethod>(test_context);

  constexpr Status kSentError = Status::Aborted();
  test_context.server().SendServerError<BidirectionalStreamMethod>(kSentError);

  EXPECT_EQ(stream.error, kSentError);
}
183
// Verifies that a server-stream packet for a call the client never started
// makes the client send back exactly one FAILED_PRECONDITION error.
TEST(Client, ProcessPacket_SendsClientErrorOnUnregisteredServerStream) {
  RawClientTestContext test_context;

  // No call was started, so this stream message has no owner.
  test_context.server().SendServerStream<BidirectionalStreamMethod>({});

  StatusView sent_errors =
      test_context.output().errors<BidirectionalStreamMethod>();
  ASSERT_EQ(sent_errors.size(), 1u);
  EXPECT_EQ(sent_errors.front(), Status::FailedPrecondition());
}
192
// Verifies that server errors and responses for a call the client never
// started are dropped without the client sending any packet in reply.
TEST(Client, ProcessPacket_NonServerStreamOnUnregisteredCall_SendsNothing) {
  RawClientTestContext test_context;

  // An unsolicited server error produces no outgoing packet...
  test_context.server().SendServerError<UnaryMethod>(Status::NotFound());
  EXPECT_EQ(test_context.output().total_packets(), 0u);

  // ...and neither does an unsolicited response.
  test_context.server().SendResponse<UnaryMethod>({}, Status::Unavailable());
  EXPECT_EQ(test_context.output().total_packets(), 0u);
}
201
// Verifies that bytes which do not decode as a packet are rejected with
// DATA_LOSS.
TEST(Client, ProcessPacket_ReturnsDataLossOnBadPacket) {
  RawClientTestContext test_context;

  constexpr std::byte kGarbage[] = {
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};

  EXPECT_EQ(test_context.client().ProcessPacket(kGarbage), Status::DataLoss());
}
209
// Verifies that a REQUEST packet (the kind a server, not a client, receives)
// is rejected by the client with INVALID_ARGUMENT.
TEST(Client, ProcessPacket_ReturnsInvalidArgumentOnServerPacket) {
  RawClientTestContext test_context;

  // Channel, service, method and call IDs are arbitrary here.
  internal::Packet request(internal::pwpb::PacketType::REQUEST, 1, 2, 3, 4);
  std::byte buffer[64];
  Result<span<const std::byte>> encoded = request.Encode(buffer);
  ASSERT_TRUE(encoded.ok());

  EXPECT_EQ(test_context.client().ProcessPacket(*encoded),
            Status::InvalidArgument());
}
221
GetChannel(internal::Endpoint & endpoint,uint32_t id)222 const internal::ChannelBase* GetChannel(internal::Endpoint& endpoint,
223 uint32_t id) {
224 internal::RpcLockGuard lock;
225 return endpoint.GetInternalChannel(id);
226 }
227
// Verifies that closing a channel with no calls removes it and sends no
// packets.
TEST(Client, CloseChannel_NoCalls) {
  RawClientTestContext ctx;
  const auto channel_id = ctx.kDefaultChannelId;

  ASSERT_NE(nullptr, GetChannel(ctx.client(), channel_id));
  EXPECT_EQ(OkStatus(), ctx.client().CloseChannel(channel_id));

  EXPECT_EQ(nullptr, GetChannel(ctx.client(), channel_id));
  EXPECT_EQ(ctx.output().total_packets(), 0u);
}
235
// Verifies that closing a channel ID that was never opened fails with
// NOT_FOUND.
TEST(Client, CloseChannel_UnknownChannel) {
  RawClientTestContext ctx;
  constexpr uint32_t kUnknownChannelId = 13579;

  ASSERT_EQ(nullptr, GetChannel(ctx.client(), kUnknownChannelId));
  EXPECT_EQ(Status::NotFound(), ctx.client().CloseChannel(kUnknownChannelId));
}
241
// Verifies that closing a channel aborts its active call: the call's on_error
// callback receives ABORTED and the endpoint no longer counts the call.
TEST(Client, CloseChannel_CallsErrorCallback) {
  RawClientTestContext ctx;
  CallContext call_ctx = StartUnaryCall<UnaryMethod>(ctx);
  auto& endpoint = static_cast<internal::Endpoint&>(ctx.client());

  // Nothing may have completed yet; has_value() also rules out an early
  // non-OK completion.
  ASSERT_FALSE(call_ctx.completed.has_value());
  ASSERT_EQ(1u, endpoint.active_call_count());

  // StartUnaryCall used ctx.channel(); close that channel rather than
  // hard-coding its ID.
  EXPECT_EQ(OkStatus(), ctx.client().CloseChannel(ctx.channel().id()));

  EXPECT_EQ(0u, endpoint.active_call_count());
  EXPECT_EQ(call_ctx.error, Status::Aborted());  // set by the on_error callback
}
256
// Verifies that an on_error callback invoked by CloseChannel can reassign the
// call object it belongs to by starting a new call on the channel that is
// being closed. Afterwards that call must be inactive and no calls may remain
// registered.
TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForCallOnClosedChannel) {
  // Grouping the test context and the call context lets the on_error lambda
  // reach both through its single `&context` capture.
  struct {
    RawClientTestContext<> ctx;
    CallContext<RawUnaryReceiver> call_ctx;
  } context;

  // NOTE(review): this assigns from StartUnaryCall's returned object, so the
  // callbacks it registered capture that object's `this`, not
  // `context.call_ctx`. on_error is replaced below. on_completed keeps the
  // stale capture, which this test appears never to trigger.
  context.call_ctx = StartUnaryCall<UnaryMethod>(context.ctx);
  context.call_ctx.call.set_on_error([&context](Status error) {
    // Re-entrantly replace the call currently being aborted with a new call on
    // channel 1, the channel that CloseChannel(1) is closing.
    context.call_ctx = StartUnaryCall<UnaryMethod>(context.ctx, 1);
    context.call_ctx.error = error;
  });

  EXPECT_EQ(OkStatus(), context.ctx.client().CloseChannel(1));
  EXPECT_EQ(context.call_ctx.error, Status::Aborted());

  // The call started on the closing channel must not end up active.
  EXPECT_FALSE(context.call_ctx.call.active());
  EXPECT_EQ(0u,
            static_cast<internal::Endpoint&>(context.ctx.client())
                .active_call_count());
}
277
// Verifies that an on_error callback invoked by CloseChannel can reassign the
// call object it belongs to by starting a new call on a different, still-open
// channel. That new call must remain active after the close completes.
TEST(Client, CloseChannel_ErrorCallbackReusesCallObjectForActiveCall) {
  // Minimal client test context with two channels (IDs 1 and 2) that share one
  // fake output, so a call can be moved from a closing channel to an open one.
  class ContextWithTwoChannels {
   public:
    ContextWithTwoChannels()
        : channels_{Channel::Create<1>(&channel_output_),
                    Channel::Create<2>(&channel_output_)},
          client_(channels_),
          packet_buffer_{},
          fake_server_(channel_output_, client_, 1, packet_buffer_) {}

    // Accessors used by StartUnaryCall and by the test body.
    Channel& channel() { return channels_[0]; }
    Client& client() { return client_; }
    CallContext<RawUnaryReceiver>& call_ctx() { return call_context_; }
    RawUnaryReceiver& call() { return call_context_.call; }

    // Starts a unary call on `channel_id`, replacing the stored call context.
    void StartCall(uint32_t channel_id) {
      call_context_ = StartUnaryCall<UnaryMethod>(*this, channel_id);
    }

   private:
    // Declaration order matters: members below are constructed from the ones
    // declared before them.
    RawFakeChannelOutput<10, 256> channel_output_;
    Channel channels_[2];
    Client client_;
    std::byte packet_buffer_[64];
    FakeServer fake_server_;

    CallContext<RawUnaryReceiver> call_context_;
  } context;

  context.StartCall(1);
  context.call().set_on_error([&context](Status error) {
    // While channel 1 is being closed, restart the call on channel 2.
    context.StartCall(2);
    context.call_ctx().error = error;
  });

  EXPECT_EQ(OkStatus(), context.client().CloseChannel(1));
  EXPECT_EQ(context.call_ctx().error, Status::Aborted());

  EXPECT_TRUE(context.call().active());
  auto& endpoint = static_cast<internal::Endpoint&>(context.client());
  EXPECT_EQ(1u, endpoint.active_call_count());
}
321
// Verifies that after a channel is closed, a channel with a new ID can be
// opened in the freed slot.
TEST(Client, OpenChannel_UnusedSlot) {
  RawClientTestContext ctx;
  constexpr uint32_t kNewChannelId = 9;

  // Free a slot by closing channel 1.
  ASSERT_EQ(OkStatus(), ctx.client().CloseChannel(1));
  ASSERT_EQ(nullptr, GetChannel(ctx.client(), kNewChannelId));

  EXPECT_EQ(OkStatus(), ctx.client().OpenChannel(kNewChannelId, ctx.output()));
  EXPECT_NE(nullptr, GetChannel(ctx.client(), kNewChannelId));
}
331
// Verifies that opening a channel whose ID is already open fails with
// ALREADY_EXISTS.
TEST(Client, OpenChannel_AlreadyExists) {
  RawClientTestContext ctx;
  constexpr uint32_t kExistingChannelId = 1;

  ASSERT_NE(nullptr, GetChannel(ctx.client(), kExistingChannelId));
  EXPECT_EQ(Status::AlreadyExists(),
            ctx.client().OpenChannel(kExistingChannelId, ctx.output()));
}
337
// Verifies that opening an additional channel on a fresh context succeeds only
// when PW_RPC_DYNAMIC_ALLOCATION is enabled. Otherwise it is expected to fail
// with RESOURCE_EXHAUSTED.
TEST(Client, OpenChannel_AdditionalSlot) {
  RawClientTestContext ctx;
  constexpr uint32_t kExtraChannelId = 19823;

#if PW_RPC_DYNAMIC_ALLOCATION
  constexpr Status kExpected = OkStatus();
#else
  constexpr Status kExpected = Status::ResourceExhausted();
#endif  // PW_RPC_DYNAMIC_ALLOCATION

  EXPECT_EQ(kExpected, ctx.client().OpenChannel(kExtraChannelId, ctx.output()));
}
345
346 } // namespace
347 } // namespace pw::rpc
348