xref: /aosp_15_r20/external/pigweed/pw_rpc/server.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 // clang-format off
16 #include "pw_rpc/internal/log_config.h" // PW_LOG_* macros must be first.
17 
18 #include "pw_rpc/server.h"
19 // clang-format on
20 
21 #include <algorithm>
22 
23 #include "pw_log/log.h"
24 #include "pw_rpc/internal/endpoint.h"
25 #include "pw_rpc/internal/packet.h"
26 #include "pw_rpc/service_id.h"
27 
28 namespace pw::rpc {
29 namespace {
30 
31 using internal::Packet;
32 using internal::pwpb::PacketType;
33 
34 }  // namespace
35 
ProcessPacket(ConstByteSpan packet_data)36 Status Server::ProcessPacket(ConstByteSpan packet_data) {
37   PW_TRY_ASSIGN(Packet packet,
38                 Endpoint::ProcessPacket(packet_data, Packet::kServer));
39   return ProcessPacket(packet);
40 }
41 
ProcessPacket(internal::Packet packet)42 Status Server::ProcessPacket(internal::Packet packet) {
43   internal::rpc_lock().lock();
44 
45   static constexpr bool kLogAllIncomingPackets = false;
46   if constexpr (kLogAllIncomingPackets) {
47     PW_LOG_INFO("RPC server received packet type %u for %u:%08x/%08x",
48                 static_cast<unsigned>(packet.type()),
49                 static_cast<unsigned>(packet.channel_id()),
50                 static_cast<unsigned>(packet.service_id()),
51                 static_cast<unsigned>(packet.method_id()));
52   }
53 
54   internal::ChannelBase* channel = GetInternalChannel(packet.channel_id());
55   if (channel == nullptr) {
56     internal::rpc_lock().unlock();
57     PW_LOG_WARN("RPC server received packet for unknown channel %u",
58                 static_cast<unsigned>(packet.channel_id()));
59     return Status::Unavailable();
60   }
61 
62   const auto [service, method] = FindMethodLocked(packet);
63 
64   if (method == nullptr) {
65     // Don't send responses to errors to avoid infinite error cycles.
66     if (packet.type() != PacketType::CLIENT_ERROR) {
67       channel->Send(Packet::ServerError(packet, Status::NotFound()))
68           .IgnoreError();
69     }
70     internal::rpc_lock().unlock();
71     PW_LOG_DEBUG("Received packet on channel %u for unknown RPC %08x/%08x",
72                  static_cast<unsigned>(packet.channel_id()),
73                  static_cast<unsigned>(packet.service_id()),
74                  static_cast<unsigned>(packet.method_id()));
75     return OkStatus();  // OK since the packet was handled.
76   }
77 
78   // Handle request packets separately to avoid an unnecessary call lookup. The
79   // Call constructor looks up and cancels any duplicate calls.
80   if (packet.type() == PacketType::REQUEST) {
81     const internal::CallContext context(
82         *this, packet.channel_id(), *service, *method, packet.call_id());
83     method->Invoke(context, packet);
84     return OkStatus();
85   }
86 
87   IntrusiveList<internal::Call>::iterator call = FindCall(packet);
88 
89   switch (packet.type()) {
90     case PacketType::CLIENT_STREAM:
91       HandleClientStreamPacket(packet, *channel, call);
92       break;
93     case PacketType::CLIENT_ERROR:
94       if (call != calls_end()) {
95         PW_LOG_DEBUG("Server call %u for %u:%08x/%08x terminated with error %s",
96                      static_cast<unsigned>(packet.call_id()),
97                      static_cast<unsigned>(packet.channel_id()),
98                      static_cast<unsigned>(packet.service_id()),
99                      static_cast<unsigned>(packet.method_id()),
100                      packet.status().str());
101         call->HandleError(packet.status());
102       } else {
103         internal::rpc_lock().unlock();
104       }
105       break;
106     case PacketType::CLIENT_REQUEST_COMPLETION:
107       HandleCompletionRequest(packet, *channel, call);
108       break;
109     case PacketType::REQUEST:  // Handled above
110     case PacketType::RESPONSE:
111     case PacketType::SERVER_ERROR:
112     case PacketType::SERVER_STREAM:
113     default:
114       internal::rpc_lock().unlock();
115       PW_LOG_WARN("pw_rpc server unable to handle packet of type %u",
116                   unsigned(packet.type()));
117   }
118 
119   return OkStatus();  // OK since the packet was handled
120 }
121 
FindMethod(uint32_t service_id,uint32_t method_id)122 std::tuple<Service*, const internal::Method*> Server::FindMethod(
123     uint32_t service_id, uint32_t method_id) {
124   internal::RpcLockGuard lock;
125   return FindMethodLocked(service_id, method_id);
126 }
127 
FindMethodLocked(uint32_t service_id,uint32_t method_id)128 std::tuple<Service*, const internal::Method*> Server::FindMethodLocked(
129     uint32_t service_id, uint32_t method_id) {
130   auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
131     return internal::UnwrapServiceId(s.service_id()) == service_id;
132   });
133 
134   if (service == services_.end()) {
135     return {};
136   }
137 
138   return {&(*service), service->FindMethod(method_id)};
139 }
140 
HandleCompletionRequest(const internal::Packet & packet,internal::ChannelBase & channel,IntrusiveList<internal::Call>::iterator call) const141 void Server::HandleCompletionRequest(
142     const internal::Packet& packet,
143     internal::ChannelBase& channel,
144     IntrusiveList<internal::Call>::iterator call) const {
145   if (call == calls_end()) {
146     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
147         .IgnoreError();  // Errors are logged in Channel::Send.
148     internal::rpc_lock().unlock();
149     PW_LOG_DEBUG(
150         "Received a request completion packet for %u:%08x/%08x, which is not a"
151         "pending call",
152         static_cast<unsigned>(packet.channel_id()),
153         static_cast<unsigned>(packet.service_id()),
154         static_cast<unsigned>(packet.method_id()));
155     return;
156   }
157 
158   if (call->client_requested_completion()) {
159     internal::rpc_lock().unlock();
160     PW_LOG_DEBUG("Received multiple completion requests for %u:%08x/%08x",
161                  static_cast<unsigned>(packet.channel_id()),
162                  static_cast<unsigned>(packet.service_id()),
163                  static_cast<unsigned>(packet.method_id()));
164     return;
165   }
166 
167   static_cast<internal::ServerCall&>(*call).HandleClientRequestedCompletion();
168 }
169 
HandleClientStreamPacket(const internal::Packet & packet,internal::ChannelBase & channel,IntrusiveList<internal::Call>::iterator call) const170 void Server::HandleClientStreamPacket(
171     const internal::Packet& packet,
172     internal::ChannelBase& channel,
173     IntrusiveList<internal::Call>::iterator call) const {
174   if (call == calls_end()) {
175     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
176         .IgnoreError();  // Errors are logged in Channel::Send.
177     internal::rpc_lock().unlock();
178     PW_LOG_DEBUG(
179         "Received client stream packet for %u:%08x/%08x, which is not pending",
180         static_cast<unsigned>(packet.channel_id()),
181         static_cast<unsigned>(packet.service_id()),
182         static_cast<unsigned>(packet.method_id()));
183     return;
184   }
185 
186   if (!call->has_client_stream()) {
187     channel.Send(Packet::ServerError(packet, Status::InvalidArgument()))
188         .IgnoreError();  // Errors are logged in Channel::Send.
189     internal::rpc_lock().unlock();
190     PW_LOG_DEBUG(
191         "Received client stream packet for %u:%08x/%08x, which doesn't have a "
192         "client stream",
193         static_cast<unsigned>(packet.channel_id()),
194         static_cast<unsigned>(packet.service_id()),
195         static_cast<unsigned>(packet.method_id()));
196     return;
197   }
198 
199   if (call->client_requested_completion()) {
200     channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
201         .IgnoreError();  // Errors are logged in Channel::Send.
202     internal::rpc_lock().unlock();
203     PW_LOG_DEBUG(
204         "Received client stream packet for %u:%08x/%08x, but its client stream "
205         "is closed",
206         static_cast<unsigned>(packet.channel_id()),
207         static_cast<unsigned>(packet.service_id()),
208         static_cast<unsigned>(packet.method_id()));
209     return;
210   }
211 
212   call->HandlePayload(packet.payload());
213 }
214 
215 }  // namespace pw::rpc
216