xref: /aosp_15_r20/external/pigweed/pw_rpc/public/pw_rpc/integration_test_socket_client.h (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2021 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
16 #include <atomic>
17 #include <cstdint>
18 #include <optional>
19 #include <thread>
20 
21 #include "pw_hdlc/decoder.h"
22 #include "pw_hdlc/default_addresses.h"
23 #include "pw_hdlc/encoded_size.h"
24 #include "pw_hdlc/rpc_channel.h"
25 #include "pw_rpc/integration_testing.h"
26 #include "pw_span/span.h"
27 #include "pw_status/try.h"
28 #include "pw_stream/socket_stream.h"
29 
30 namespace pw::rpc::integration_test {
31 
32 // Wraps an RPC client with a socket stream and a channel configured to use it.
33 // Useful for integration tests that run across a socket.
34 template <size_t kMaxTransmissionUnit>
35 class SocketClientContext {
36  public:
SocketClientContext()37   constexpr SocketClientContext()
38       : rpc_dispatch_thread_handle_(std::nullopt),
39         channel_output_(stream_, hdlc::kDefaultRpcAddress, "socket"),
40         channel_output_with_manipulator_(channel_output_),
41         channel_(
42             Channel::Create<kChannelId>(&channel_output_with_manipulator_)),
43         client_(span(&channel_, 1)) {}
44 
client()45   Client& client() { return client_; }
46 
47   // Connects to the specified host:port and starts a background thread to read
48   // packets from the socket.
Start(const char * host,uint16_t port)49   Status Start(const char* host, uint16_t port) {
50     PW_TRY(stream_.Connect(host, port));
51     rpc_dispatch_thread_handle_.emplace(&SocketClientContext::ProcessPackets,
52                                         this);
53     return OkStatus();
54   }
55 
56   // Terminates the client, joining the RPC dispatch thread.
Terminate()57   void Terminate() {
58     PW_ASSERT(rpc_dispatch_thread_handle_.has_value());
59     should_terminate_.test_and_set();
60     // Close the stream to avoid blocking forever on a socket read.
61     stream_.Close();
62     rpc_dispatch_thread_handle_->join();
63   }
64 
65   // Configure options for the socket associated with the client.
SetSockOpt(int level,int optname,const void * optval,unsigned int optlen)66   int SetSockOpt(int level,
67                  int optname,
68                  const void* optval,
69                  unsigned int optlen) {
70     return stream_.SetSockOpt(level, optname, optval, optlen);
71   }
72 
SetEgressChannelManipulator(ChannelManipulator * new_channel_manipulator)73   void SetEgressChannelManipulator(
74       ChannelManipulator* new_channel_manipulator) {
75     channel_output_with_manipulator_.set_channel_manipulator(
76         new_channel_manipulator);
77   }
78 
SetIngressChannelManipulator(ChannelManipulator * new_channel_manipulator)79   void SetIngressChannelManipulator(
80       ChannelManipulator* new_channel_manipulator) {
81     if (new_channel_manipulator != nullptr) {
82       new_channel_manipulator->set_send_packet([&](ConstByteSpan payload) {
83         return client_.ProcessPacket(payload);
84       });
85     }
86     ingress_channel_manipulator_ = new_channel_manipulator;
87   }
88 
89   // Calls Start for localhost.
Start(uint16_t port)90   Status Start(uint16_t port) { return Start("localhost", port); }
91 
92  private:
93   void ProcessPackets();
94 
95   class ChannelOutputWithManipulator : public ChannelOutput {
96    public:
ChannelOutputWithManipulator(ChannelOutput & actual_output)97     ChannelOutputWithManipulator(ChannelOutput& actual_output)
98         : ChannelOutput(actual_output.name()),
99           actual_output_(actual_output),
100           channel_manipulator_(nullptr) {}
101 
set_channel_manipulator(ChannelManipulator * new_channel_manipulator)102     void set_channel_manipulator(ChannelManipulator* new_channel_manipulator) {
103       if (new_channel_manipulator != nullptr) {
104         new_channel_manipulator->set_send_packet(
105             ChannelManipulator::SendCallback(
106                 [&](ConstByteSpan payload)
107                     __attribute__((no_thread_safety_analysis)) {
108                       return actual_output_.Send(payload);
109                     }));
110       }
111       channel_manipulator_ = new_channel_manipulator;
112     }
113 
MaximumTransmissionUnit()114     size_t MaximumTransmissionUnit() override {
115       return actual_output_.MaximumTransmissionUnit();
116     }
Send(span<const std::byte> buffer)117     Status Send(span<const std::byte> buffer) override
118         PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock()) {
119       if (channel_manipulator_ != nullptr) {
120         return channel_manipulator_->ProcessAndSend(buffer);
121       }
122 
123       return actual_output_.Send(buffer);
124     }
125 
126    private:
127     ChannelOutput& actual_output_;
128     ChannelManipulator* channel_manipulator_;
129   };
130 
131   std::atomic_flag should_terminate_ = ATOMIC_FLAG_INIT;
132   std::optional<std::thread> rpc_dispatch_thread_handle_;
133   stream::SocketStream stream_;
134   hdlc::FixedMtuChannelOutput<kMaxTransmissionUnit> channel_output_;
135   ChannelOutputWithManipulator channel_output_with_manipulator_;
136   ChannelManipulator* ingress_channel_manipulator_;
137   Channel channel_;
138   Client client_;
139 };
140 
141 template <size_t kMaxTransmissionUnit>
ProcessPackets()142 void SocketClientContext<kMaxTransmissionUnit>::ProcessPackets() {
143   constexpr size_t kDecoderBufferSize =
144       hdlc::Decoder::RequiredBufferSizeForFrameSize(kMaxTransmissionUnit);
145   std::array<std::byte, kDecoderBufferSize> decode_buffer;
146   hdlc::Decoder decoder(decode_buffer);
147 
148   while (true) {
149     std::byte byte[1];
150     Result<ByteSpan> read = stream_.Read(byte);
151 
152     if (should_terminate_.test()) {
153       return;
154     }
155 
156     if (!read.ok() || read->empty()) {
157       continue;
158     }
159 
160     if (auto result = decoder.Process(*byte); result.ok()) {
161       hdlc::Frame& frame = result.value();
162       if (frame.address() == hdlc::kDefaultRpcAddress) {
163         if (ingress_channel_manipulator_ != nullptr) {
164           PW_ASSERT(
165               ingress_channel_manipulator_->ProcessAndSend(frame.data()).ok());
166         } else {
167           PW_ASSERT(client_.ProcessPacket(frame.data()).ok());
168         }
169       }
170     }
171   }
172 }
173 
174 }  // namespace pw::rpc::integration_test
175