1 // Copyright 2022 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include <array>
#include <atomic>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <utility>

#include "pw_function/function.h"
#include "pw_rpc/pwpb/client_server_testing_threaded.h"
#include "pw_rpc_test_protos/test.rpc.pwpb.h"
#include "pw_sync/binary_semaphore.h"
#include "pw_thread/non_portable_test_thread_options.h"
#include "pw_unit_test/framework.h"
24
25 namespace pw::rpc {
26 namespace {
27
// Short local aliases for the generated pwpb message namespaces, so the service
// and tests below can write e.g. `TestRequest::Message` instead of the fully
// qualified `::pw::rpc::test::pwpb::TestRequest::Message`.
namespace TestRequest = ::pw::rpc::test::pwpb::TestRequest;
namespace TestResponse = ::pw::rpc::test::pwpb::TestResponse;
namespace TestStreamResponse = ::pw::rpc::test::pwpb::TestStreamResponse;
31
32 } // namespace
33
34 namespace test {
35
36 using GeneratedService = ::pw::rpc::test::pw_rpc::pwpb::TestService;
37
38 class TestService final : public GeneratedService::Service<TestService> {
39 public:
TestUnaryRpc(const TestRequest::Message & request,TestResponse::Message & response)40 Status TestUnaryRpc(const TestRequest::Message& request,
41 TestResponse::Message& response) {
42 response.value = request.integer + 1;
43 return static_cast<Status::Code>(request.status_code);
44 }
45
TestAnotherUnaryRpc(const TestRequest::Message & request,TestResponse::Message & response)46 Status TestAnotherUnaryRpc(const TestRequest::Message& request,
47 TestResponse::Message& response) {
48 response.value = 42;
49 response.repeated_field.SetEncoder(
50 [](TestResponse::StreamEncoder& encoder) {
51 constexpr std::array<uint32_t, 3> kValues = {7, 8, 9};
52 return encoder.WriteRepeatedField(kValues);
53 });
54 return static_cast<Status::Code>(request.status_code);
55 }
56
TestServerStreamRpc(const TestRequest::Message &,ServerWriter<TestStreamResponse::Message> &)57 static void TestServerStreamRpc(const TestRequest::Message&,
58 ServerWriter<TestStreamResponse::Message>&) {}
59
TestClientStreamRpc(ServerReader<TestRequest::Message,TestStreamResponse::Message> &)60 void TestClientStreamRpc(
61 ServerReader<TestRequest::Message, TestStreamResponse::Message>&) {}
62
TestBidirectionalStreamRpc(ServerReaderWriter<TestRequest::Message,TestStreamResponse::Message> &)63 void TestBidirectionalStreamRpc(
64 ServerReaderWriter<TestRequest::Message, TestStreamResponse::Message>&) {}
65 };
66
67 } // namespace test
68
69 namespace {
70
71 class RpcCaller {
72 public:
73 template <auto kMethod = test::GeneratedService::TestUnaryRpc>
BlockOnResponse(uint32_t i,Client & client,uint32_t channel_id)74 Status BlockOnResponse(uint32_t i, Client& client, uint32_t channel_id) {
75 TestRequest::Message request{.integer = i,
76 .status_code = OkStatus().code()};
77 response_status_ = OkStatus();
78 auto call = kMethod(
79 client,
80 channel_id,
81 request,
82 [this](const TestResponse::Message&, Status status) {
83 response_status_ = status;
84 semaphore_.release();
85 },
86 [this](Status status) {
87 response_status_ = status;
88 semaphore_.release();
89 });
90
91 semaphore_.acquire();
92 return response_status_;
93 }
94
95 private:
96 Status response_status_ = OkStatus();
97 pw::sync::BinarySemaphore semaphore_;
98 };
99
TEST(PwpbClientServerTestContextThreaded, ReceivesUnaryRpcResponseThreaded) {
  // TODO: b/290860904 - Replace TestOptionsThread0 with TestThreadContext.
  PwpbClientServerTestContextThreaded<> ctx(thread::test::TestOptionsThread0());
  test::TestService service;
  ctx.server().RegisterService(service);

  // Issue one unary call and block until the client observes its completion.
  constexpr auto kValue = 1;
  RpcCaller caller;
  const Status call_status =
      caller.BlockOnResponse(kValue, ctx.client(), ctx.channel().id());
  EXPECT_EQ(call_status, OkStatus());

  // The context records the request that was sent and the reply that came
  // back; TestService answers with the request integer plus one.
  const auto sent = ctx.request<test::GeneratedService::TestUnaryRpc>(0);
  const auto received = ctx.response<test::GeneratedService::TestUnaryRpc>(0);

  EXPECT_EQ(kValue, sent.integer);
  EXPECT_EQ(kValue + 1, received.value);
}
119
TEST(PwpbClientServerTestContextThreaded, ReceivesMultipleResponsesThreaded) {
  PwpbClientServerTestContextThreaded<> ctx(thread::test::TestOptionsThread0());
  test::TestService service;
  ctx.server().RegisterService(service);

  constexpr auto kFirst = 1;
  constexpr auto kSecond = 2;

  // Two back-to-back blocking calls. The second starts only after the first
  // has finished, so recorded packet index 0 belongs to the first call and
  // index 1 to the second.
  RpcCaller caller;
  const Status first_status =
      caller.BlockOnResponse(kFirst, ctx.client(), ctx.channel().id());
  EXPECT_EQ(first_status, OkStatus());
  const Status second_status =
      caller.BlockOnResponse(kSecond, ctx.client(), ctx.channel().id());
  EXPECT_EQ(second_status, OkStatus());

  const auto first_request =
      ctx.request<test::GeneratedService::TestUnaryRpc>(0);
  const auto second_request =
      ctx.request<test::GeneratedService::TestUnaryRpc>(1);
  const auto first_response =
      ctx.response<test::GeneratedService::TestUnaryRpc>(0);
  const auto second_response =
      ctx.response<test::GeneratedService::TestUnaryRpc>(1);

  // Each reply carries its own request's integer plus one.
  EXPECT_EQ(kFirst, first_request.integer);
  EXPECT_EQ(kSecond, second_request.integer);
  EXPECT_EQ(kFirst + 1, first_response.value);
  EXPECT_EQ(kSecond + 1, second_response.value);
}
147
TEST(PwpbClientServerTestContextThreaded,
     ReceivesMultipleResponsesThreadedWithPacketProcessor) {
  // A packet counter paired with the mutex that guards it. The processors may
  // run off the test thread, so every access goes through the lock.
  using ProtectedInt = std::pair<int, pw::sync::Mutex>;

  // Builds a packet processor that counts each packet it sees, then hands the
  // packet to the ClientServer for normal processing. The lock is scoped to
  // the increment only (RAII), so ProcessPacket() runs without holding it.
  auto make_counting_processor = [](ProtectedInt& counter) {
    return [&counter](ClientServer& client_server,
                      pw::ConstByteSpan packet) -> pw::Status {
      {
        std::lock_guard lock(counter.second);
        ++counter.first;
      }
      return client_server.ProcessPacket(packet);
    };
  };

  // Reads a counter's value under its lock.
  auto read_counter = [](ProtectedInt& counter) {
    std::lock_guard lock(counter.second);
    return counter.first;
  };

  // Declared before `ctx` so the counters outlive the processors that
  // reference them.
  ProtectedInt server_counter{};
  ProtectedInt client_counter{};

  PwpbClientServerTestContextThreaded<> ctx(
      thread::test::TestOptionsThread0(),
      make_counting_processor(server_counter),
      make_counting_processor(client_counter));
  test::TestService service;
  ctx.server().RegisterService(service);

  RpcCaller caller;
  constexpr auto value1 = 1;
  constexpr auto value2 = 2;
  EXPECT_EQ(caller.BlockOnResponse(value1, ctx.client(), ctx.channel().id()),
            OkStatus());
  EXPECT_EQ(caller.BlockOnResponse(value2, ctx.client(), ctx.channel().id()),
            OkStatus());

  const auto request1 = ctx.request<test::GeneratedService::TestUnaryRpc>(0);
  const auto request2 = ctx.request<test::GeneratedService::TestUnaryRpc>(1);
  const auto response1 = ctx.response<test::GeneratedService::TestUnaryRpc>(0);
  const auto response2 = ctx.response<test::GeneratedService::TestUnaryRpc>(1);

  EXPECT_EQ(value1, request1.integer);
  EXPECT_EQ(value2, request2.integer);
  EXPECT_EQ(value1 + 1, response1.value);
  EXPECT_EQ(value2 + 1, response2.value);

  // Two unary calls: the server processes two request packets and the client
  // processes two response packets.
  EXPECT_EQ(read_counter(server_counter), 2);
  EXPECT_EQ(read_counter(client_counter), 2);
}
205
TEST(PwpbClientServerTestContextThreaded, ResponseWithCallbacks) {
  PwpbClientServerTestContextThreaded<> ctx(thread::test::TestOptionsThread0());
  test::TestService service;
  ctx.server().RegisterService(service);

  RpcCaller caller;
  // DataLoss is expected on the initial response, since pwpb provides no way
  // to populate the response's callback field here. Callbacks are set up on
  // the recorded response packet below.
  EXPECT_EQ(caller.BlockOnResponse<test::GeneratedService::TestAnotherUnaryRpc>(
                0, ctx.client(), ctx.channel().id()),
            Status::DataLoss());

  // To decode a response object that requires callbacks to be set, pass it to
  // the response() method as a parameter.
  pw::Vector<uint32_t, 4> values{};

  TestResponse::Message response{};
  response.repeated_field.SetDecoder(
      [&values](TestResponse::StreamDecoder& decoder) {
        return decoder.ReadRepeatedField(values);
      });
  ctx.response<test::GeneratedService::TestAnotherUnaryRpc>(0, response);

  EXPECT_EQ(42, response.value);

  // ASSERT rather than EXPECT: if fewer than three values were decoded, the
  // indexed reads below would run past the end of `values`.
  ASSERT_EQ(3u, values.size());
  EXPECT_EQ(7u, values[0]);
  EXPECT_EQ(8u, values[1]);
  EXPECT_EQ(9u, values[2]);
}
236
237 } // namespace
238 } // namespace pw::rpc
239