1 // Copyright 2022 The Pigweed Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not 4 // use this file except in compliance with the License. You may obtain a copy of 5 // the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12 // License for the specific language governing permissions and limitations under 13 // the License. 14 #pragma once 15 16 #include <cinttypes> 17 18 #include "pw_function/function.h" 19 #include "pw_rpc/channel.h" 20 #include "pw_rpc/client_server.h" 21 #include "pw_rpc/internal/client_server_testing.h" 22 #include "pw_span/span.h" 23 #include "pw_status/status.h" 24 #include "pw_sync/binary_semaphore.h" 25 #include "pw_sync/mutex.h" 26 #include "pw_thread/thread.h" 27 28 namespace pw::rpc { 29 namespace internal { 30 31 // Expands on a Forwarding Channel Output implementation to allow for 32 // observation of packets. 33 template <typename FakeChannelOutputImpl, 34 size_t kOutputSize, 35 size_t kMaxPackets, 36 size_t kPayloadsBufferSizeBytes> 37 class WatchableChannelOutput 38 : public ForwardingChannelOutput<FakeChannelOutputImpl, 39 kOutputSize, 40 kMaxPackets, 41 kPayloadsBufferSizeBytes> { 42 private: 43 using Base = ForwardingChannelOutput<FakeChannelOutputImpl, 44 kOutputSize, 45 kMaxPackets, 46 kPayloadsBufferSizeBytes>; 47 48 public: MaximumTransmissionUnit()49 size_t MaximumTransmissionUnit() PW_LOCKS_EXCLUDED(mutex_) override { 50 std::lock_guard lock(mutex_); 51 return Base::MaximumTransmissionUnit(); 52 } 53 Send(span<const std::byte> buffer)54 Status Send(span<const std::byte> buffer) PW_LOCKS_EXCLUDED(mutex_) override { 55 Status status; 56 mutex_.lock(); 57 status = Base::Send(buffer); 58 mutex_.unlock(); 59 output_semaphore_.release(); 60 return status; 61 } 62 63 // Returns true if should continue waiting for additional output WaitForOutput()64 bool WaitForOutput() PW_LOCKS_EXCLUDED(mutex_) { 65 output_semaphore_.acquire(); 66 std::lock_guard lock(mutex_); 67 return should_wait_; 68 } 69 StopWaitingForOutput()70 void StopWaitingForOutput() PW_LOCKS_EXCLUDED(mutex_) { 71 std::lock_guard lock(mutex_); 72 should_wait_ = false; 73 output_semaphore_.release(); 74 } 75 76 protected: 77 explicit WatchableChannelOutput( 78 TestPacketProcessor&& server_packet_processor = nullptr, 79 TestPacketProcessor&& client_packet_processor = nullptr) Base(std::move (server_packet_processor),std::move (client_packet_processor))80 : Base(std::move(server_packet_processor), 81 std::move(client_packet_processor)) {} 82 PacketCount()83 size_t PacketCount() const PW_EXCLUSIVE_LOCKS_REQUIRED(mutex_) override { 84 return Base::PacketCount(); 85 } 86 87 sync::Mutex mutex_; 88 89 private: EncodeNextUnsentPacket(std::array<std::byte,kPayloadsBufferSizeBytes> & packet_buffer)90 Result<ConstByteSpan> EncodeNextUnsentPacket( 91 std::array<std::byte, kPayloadsBufferSizeBytes>& packet_buffer) 92 PW_LOCKS_EXCLUDED(mutex_) override { 93 std::lock_guard lock(mutex_); 94 return Base::EncodeNextUnsentPacket(packet_buffer); 95 } 96 sync::BinarySemaphore output_semaphore_; 97 bool should_wait_ PW_GUARDED_BY(mutex_) = true; 98 }; 99 100 // Provides a testing context with a real client and server 101 template <typename WatchableChannelOutputImpl, 102 size_t kOutputSize = 128, 103 size_t kMaxPackets = 16, 104 size_t kPayloadsBufferSizeBytes = 128> 105 class ClientServerTestContextThreaded 106 : public ClientServerTestContext<WatchableChannelOutputImpl, 107 kOutputSize, 108 kMaxPackets, 109 kPayloadsBufferSizeBytes> { 110 private: 111 using Instance = ClientServerTestContextThreaded<WatchableChannelOutputImpl, 112 kOutputSize, 113 kMaxPackets, 114 kPayloadsBufferSizeBytes>; 115 using Base = ClientServerTestContext<WatchableChannelOutputImpl, 116 kOutputSize, 117 kMaxPackets, 118 kPayloadsBufferSizeBytes>; 119 120 public: ~ClientServerTestContextThreaded()121 ~ClientServerTestContextThreaded() { 122 Base::channel_output_.StopWaitingForOutput(); 123 thread_.join(); 124 } 125 126 protected: 127 explicit ClientServerTestContextThreaded( 128 const thread::Options& options, 129 TestPacketProcessor&& server_packet_processor = nullptr, 130 TestPacketProcessor&& client_packet_processor = nullptr) Base(std::move (server_packet_processor),std::move (client_packet_processor))131 : Base(std::move(server_packet_processor), 132 std::move(client_packet_processor)), 133 thread_(options, [this] { Run(); }) {} 134 135 private: 136 using Base::ForwardNewPackets; Run()137 void Run() { 138 auto& ctx = *static_cast<Instance*>(this); 139 while (ctx.channel_output_.WaitForOutput()) { 140 ctx.ForwardNewPackets(); 141 } 142 } 143 Thread thread_; 144 }; 145 146 } // namespace internal 147 } // namespace pw::rpc 148