xref: /aosp_15_r20/external/pigweed/pw_rpc/public/pw_rpc/internal/client_server_testing_threaded.h (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2022 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
16 #include <cinttypes>
17 
18 #include "pw_function/function.h"
19 #include "pw_rpc/channel.h"
20 #include "pw_rpc/client_server.h"
21 #include "pw_rpc/internal/client_server_testing.h"
22 #include "pw_span/span.h"
23 #include "pw_status/status.h"
24 #include "pw_sync/binary_semaphore.h"
25 #include "pw_sync/mutex.h"
26 #include "pw_thread/thread.h"
27 
28 namespace pw::rpc {
29 namespace internal {
30 
31 // Expands on a Forwarding Channel Output implementation to allow for
32 // observation of packets.
33 template <typename FakeChannelOutputImpl,
34           size_t kOutputSize,
35           size_t kMaxPackets,
36           size_t kPayloadsBufferSizeBytes>
37 class WatchableChannelOutput
38     : public ForwardingChannelOutput<FakeChannelOutputImpl,
39                                      kOutputSize,
40                                      kMaxPackets,
41                                      kPayloadsBufferSizeBytes> {
42  private:
43   using Base = ForwardingChannelOutput<FakeChannelOutputImpl,
44                                        kOutputSize,
45                                        kMaxPackets,
46                                        kPayloadsBufferSizeBytes>;
47 
48  public:
MaximumTransmissionUnit()49   size_t MaximumTransmissionUnit() PW_LOCKS_EXCLUDED(mutex_) override {
50     std::lock_guard lock(mutex_);
51     return Base::MaximumTransmissionUnit();
52   }
53 
Send(span<const std::byte> buffer)54   Status Send(span<const std::byte> buffer) PW_LOCKS_EXCLUDED(mutex_) override {
55     Status status;
56     mutex_.lock();
57     status = Base::Send(buffer);
58     mutex_.unlock();
59     output_semaphore_.release();
60     return status;
61   }
62 
63   // Returns true if should continue waiting for additional output
WaitForOutput()64   bool WaitForOutput() PW_LOCKS_EXCLUDED(mutex_) {
65     output_semaphore_.acquire();
66     std::lock_guard lock(mutex_);
67     return should_wait_;
68   }
69 
StopWaitingForOutput()70   void StopWaitingForOutput() PW_LOCKS_EXCLUDED(mutex_) {
71     std::lock_guard lock(mutex_);
72     should_wait_ = false;
73     output_semaphore_.release();
74   }
75 
76  protected:
77   explicit WatchableChannelOutput(
78       TestPacketProcessor&& server_packet_processor = nullptr,
79       TestPacketProcessor&& client_packet_processor = nullptr)
Base(std::move (server_packet_processor),std::move (client_packet_processor))80       : Base(std::move(server_packet_processor),
81              std::move(client_packet_processor)) {}
82 
PacketCount()83   size_t PacketCount() const PW_EXCLUSIVE_LOCKS_REQUIRED(mutex_) override {
84     return Base::PacketCount();
85   }
86 
87   sync::Mutex mutex_;
88 
89  private:
EncodeNextUnsentPacket(std::array<std::byte,kPayloadsBufferSizeBytes> & packet_buffer)90   Result<ConstByteSpan> EncodeNextUnsentPacket(
91       std::array<std::byte, kPayloadsBufferSizeBytes>& packet_buffer)
92       PW_LOCKS_EXCLUDED(mutex_) override {
93     std::lock_guard lock(mutex_);
94     return Base::EncodeNextUnsentPacket(packet_buffer);
95   }
96   sync::BinarySemaphore output_semaphore_;
97   bool should_wait_ PW_GUARDED_BY(mutex_) = true;
98 };
99 
100 // Provides a testing context with a real client and server
101 template <typename WatchableChannelOutputImpl,
102           size_t kOutputSize = 128,
103           size_t kMaxPackets = 16,
104           size_t kPayloadsBufferSizeBytes = 128>
105 class ClientServerTestContextThreaded
106     : public ClientServerTestContext<WatchableChannelOutputImpl,
107                                      kOutputSize,
108                                      kMaxPackets,
109                                      kPayloadsBufferSizeBytes> {
110  private:
111   using Instance = ClientServerTestContextThreaded<WatchableChannelOutputImpl,
112                                                    kOutputSize,
113                                                    kMaxPackets,
114                                                    kPayloadsBufferSizeBytes>;
115   using Base = ClientServerTestContext<WatchableChannelOutputImpl,
116                                        kOutputSize,
117                                        kMaxPackets,
118                                        kPayloadsBufferSizeBytes>;
119 
120  public:
~ClientServerTestContextThreaded()121   ~ClientServerTestContextThreaded() {
122     Base::channel_output_.StopWaitingForOutput();
123     thread_.join();
124   }
125 
126  protected:
127   explicit ClientServerTestContextThreaded(
128       const thread::Options& options,
129       TestPacketProcessor&& server_packet_processor = nullptr,
130       TestPacketProcessor&& client_packet_processor = nullptr)
Base(std::move (server_packet_processor),std::move (client_packet_processor))131       : Base(std::move(server_packet_processor),
132              std::move(client_packet_processor)),
133         thread_(options, [this] { Run(); }) {}
134 
135  private:
136   using Base::ForwardNewPackets;
Run()137   void Run() {
138     auto& ctx = *static_cast<Instance*>(this);
139     while (ctx.channel_output_.WaitForOutput()) {
140       ctx.ForwardNewPackets();
141     }
142   }
143   Thread thread_;
144 };
145 
146 }  // namespace internal
147 }  // namespace pw::rpc
148