xref: /aosp_15_r20/external/pigweed/pw_rpc/public/pw_rpc/internal/test_utils.h (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 // Internal-only testing utilities. public/pw_rpc/test_method_context.h provides
16 // improved public-facing utilities for testing RPC services.
17 #pragma once
18 
19 #include <array>
20 #include <cstddef>
21 #include <cstdint>
22 
23 #include "pw_assert/assert.h"
24 #include "pw_rpc/channel.h"
25 #include "pw_rpc/client.h"
26 #include "pw_rpc/internal/method.h"
27 #include "pw_rpc/internal/packet.h"
28 #include "pw_rpc/raw/fake_channel_output.h"
29 #include "pw_rpc/server.h"
30 #include "pw_span/span.h"
31 #include "pw_unit_test/framework.h"
32 
33 namespace pw::rpc::internal {
34 
35 // Version of the Server with extra methods exposed for testing.
36 class TestServer : public Server {
37  public:
38   using Server::calls_end;
39   using Server::CloseCallAndMarkForCleanup;
40   using Server::FindCall;
41 };
42 
43 template <typename Service,
44           uint32_t kChannelId = 99,
45           uint32_t kServiceId = 16,
46           uint32_t kDefaultCallId = 437>
47 class ServerContextForTest {
48  public:
channel_id()49   static constexpr uint32_t channel_id() { return kChannelId; }
service_id()50   static constexpr uint32_t service_id() { return kServiceId; }
51 
ServerContextForTest(const internal::Method & method)52   ServerContextForTest(const internal::Method& method)
53       : channel_(Channel::Create<kChannelId>(&output_)),
54         server_(span(&channel_, 1)),
55         service_(kServiceId),
56         context_(server_, channel_.id(), service_, method, kDefaultCallId) {
57     server_.RegisterService(service_);
58   }
59 
60   // Create packets for this context's channel, service, and method.
request(span<const std::byte> payload)61   internal::Packet request(span<const std::byte> payload) const {
62     return internal::Packet(internal::pwpb::PacketType::REQUEST,
63                             kChannelId,
64                             kServiceId,
65                             context_.method().id(),
66                             kDefaultCallId,
67                             payload);
68   }
69 
70   internal::Packet response(Status status,
71                             span<const std::byte> payload = {}) const {
72     return internal::Packet(internal::pwpb::PacketType::RESPONSE,
73                             kChannelId,
74                             kServiceId,
75                             context_.method().id(),
76                             kDefaultCallId,
77                             payload,
78                             status);
79   }
80 
server_stream(span<const std::byte> payload)81   internal::Packet server_stream(span<const std::byte> payload) const {
82     return internal::Packet(internal::pwpb::PacketType::SERVER_STREAM,
83                             kChannelId,
84                             kServiceId,
85                             context_.method().id(),
86                             kDefaultCallId,
87                             payload);
88   }
89 
client_stream(span<const std::byte> payload)90   internal::Packet client_stream(span<const std::byte> payload) const {
91     return internal::Packet(internal::pwpb::PacketType::CLIENT_STREAM,
92                             kChannelId,
93                             kServiceId,
94                             context_.method().id(),
95                             kDefaultCallId,
96                             payload);
97   }
98 
99   CallContext get(uint32_t id = kDefaultCallId) const {
100     return CallContext(context_.server(),
101                        context_.channel_id(),
102                        context_.service(),
103                        context_.method(),
104                        id);
105   }
106 
output()107   internal::test::FakeChannelOutput& output() { return output_; }
server()108   TestServer& server() { return static_cast<TestServer&>(server_); }
service()109   Service& service() { return service_; }
110 
111  private:
112   RawFakeChannelOutput<5> output_;
113   rpc::Channel channel_;
114   rpc::Server server_;
115   Service service_;
116 
117   const internal::CallContext context_;
118 };
119 
120 template <size_t kInputBufferSize = 128,
121           uint32_t kChannelId = 99,
122           uint32_t kServiceId = 16,
123           uint32_t kMethodId = 111>
124 class ClientContextForTest {
125  public:
channel_id()126   static constexpr uint32_t channel_id() { return kChannelId; }
service_id()127   static constexpr uint32_t service_id() { return kServiceId; }
method_id()128   static constexpr uint32_t method_id() { return kMethodId; }
129 
ClientContextForTest()130   ClientContextForTest()
131       : channel_(Channel::Create<kChannelId>(&output_)),
132         client_(span(&channel_, 1)) {}
133 
output()134   const internal::test::FakeChannelOutput& output() const { return output_; }
channel()135   Channel& channel() { return static_cast<Channel&>(channel_); }
client()136   Client& client() { return client_; }
137 
138   // Sends a packet to be processed by the client. Returns the client's
139   // ProcessPacket status.
140   Status SendPacket(internal::pwpb::PacketType type,
141                     Status status = OkStatus(),
142                     span<const std::byte> payload = {}) {
143     uint32_t call_id =
144         output().total_packets() > 0 ? output().last_packet().call_id() : 0;
145 
146     internal::Packet packet(
147         type, kChannelId, kServiceId, kMethodId, call_id, payload, status);
148     std::byte buffer[kInputBufferSize];
149     Result result = packet.Encode(buffer);
150     EXPECT_EQ(result.status(), OkStatus());
151     return client_.ProcessPacket(result.value_or(ConstByteSpan()));
152   }
153 
154   Status SendResponse(Status status, span<const std::byte> payload = {}) {
155     return SendPacket(internal::pwpb::PacketType::RESPONSE, status, payload);
156   }
157 
SendServerStream(span<const std::byte> payload)158   Status SendServerStream(span<const std::byte> payload) {
159     return SendPacket(
160         internal::pwpb::PacketType::SERVER_STREAM, OkStatus(), payload);
161   }
162 
163  private:
164   RawFakeChannelOutput<5> output_;
165   rpc::Channel channel_;
166   Client client_;
167 };
168 
169 }  // namespace pw::rpc::internal
170