1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <cstddef>
#include <deque>
#include <mutex>
#include <queue>
#include <vector>
17 
18 #include "pw_rpc_transport/egress_ingress.h"
19 #include "pw_rpc_transport/service_registry.h"
20 #include "pw_work_queue/test_thread.h"
21 #include "pw_work_queue/work_queue.h"
22 
23 namespace pw::rpc {
24 
25 // A transport that loops back all received frames to a given ingress.
26 class TestLoopbackTransport : public RpcFrameSender {
27  public:
TestLoopbackTransport(size_t mtu)28   explicit TestLoopbackTransport(size_t mtu) : mtu_(mtu) {
29     work_thread_ =
30         Thread(work_queue::test::WorkQueueThreadOptions(), work_queue_);
31   }
32 
~TestLoopbackTransport()33   ~TestLoopbackTransport() override {
34     work_queue_.RequestStop();
35 #if PW_THREAD_JOINING_ENABLED
36     work_thread_.join();
37 #else
38     work_thread_.detach();
39 #endif  // PW_THREAD_JOINING_ENABLED
40   }
41 
MaximumTransmissionUnit()42   size_t MaximumTransmissionUnit() const override { return mtu_; }
43 
Send(RpcFrame frame)44   Status Send(RpcFrame frame) override {
45     buffer_queue_.emplace();
46     std::vector<std::byte>& buffer = buffer_queue_.back();
47     std::copy(
48         frame.header.begin(), frame.header.end(), std::back_inserter(buffer));
49     std::copy(
50         frame.payload.begin(), frame.payload.end(), std::back_inserter(buffer));
51 
52     // Defer processing frame on ingress to avoid deadlocks.
53     return work_queue_.PushWork([this]() {
54       ingress_->ProcessIncomingData(buffer_queue_.front()).IgnoreError();
55       buffer_queue_.pop();
56     });
57   }
58 
SetIngress(RpcIngressHandler & ingress)59   void SetIngress(RpcIngressHandler& ingress) { ingress_ = &ingress; }
60 
61  private:
62   size_t mtu_;
63   std::queue<std::vector<std::byte>> buffer_queue_;
64   RpcIngressHandler* ingress_ = nullptr;
65   Thread work_thread_;
66   work_queue::WorkQueueWithBuffer<1> work_queue_;
67 };
68 
69 // An egress handler that passes the received RPC packet to the service
70 // registry.
71 class TestLocalEgress : public RpcEgressHandler {
72  public:
SendRpcPacket(ConstByteSpan packet)73   Status SendRpcPacket(ConstByteSpan packet) override {
74     if (!registry_) {
75       return Status::FailedPrecondition();
76     }
77     return registry_->ProcessRpcPacket(packet);
78   }
79 
SetRegistry(ServiceRegistry & registry)80   void SetRegistry(ServiceRegistry& registry) { registry_ = &registry; }
81 
82  private:
83   ServiceRegistry* registry_ = nullptr;
84 };
85 
86 class TestLoopbackServiceRegistry : public ServiceRegistry {
87  public:
88 #if PW_RPC_DYNAMIC_ALLOCATION
89   static constexpr int kInitTxChannelCount = 0;
90 #else
91   static constexpr int kInitTxChannelCount = 1;
92 #endif
93   static constexpr int kTestChannelId = 1;
94   static constexpr size_t kMtu = 512;
95   static constexpr size_t kMaxPacketSize = 256;
96 
TestLoopbackServiceRegistry()97   TestLoopbackServiceRegistry() : ServiceRegistry(tx_channels_) {
98     PW_ASSERT(
99         client_server().client().OpenChannel(kTestChannelId, egress_).ok());
100 #if PW_RPC_DYNAMIC_ALLOCATION
101     PW_ASSERT(
102         client_server().server().OpenChannel(kTestChannelId, egress_).ok());
103 #endif
104     transport_.SetIngress(ingress_);
105     local_egress_.SetRegistry(*this);
106   }
107 
channel_id()108   int channel_id() const { return kTestChannelId; }
109 
110  private:
111   TestLoopbackTransport transport_{kMtu};
112   TestLocalEgress local_egress_;
113   SimpleRpcEgress<kMaxPacketSize> egress_{"egress", transport_};
114   std::array<Channel, kInitTxChannelCount> tx_channels_;
115   std::array<ChannelEgress, 1> rx_channels_ = {
116       rpc::ChannelEgress{kTestChannelId, local_egress_}};
117   SimpleRpcIngress<kMaxPacketSize> ingress_{rx_channels_};
118 };
119 
120 }  // namespace pw::rpc
121