1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 // clang-format off
16 #include "pw_rpc/internal/log_config.h" // PW_LOG_* macros must be first.
17
18 #include "pw_rpc/server.h"
19 // clang-format on
20
21 #include <algorithm>
22
23 #include "pw_log/log.h"
24 #include "pw_rpc/internal/endpoint.h"
25 #include "pw_rpc/internal/packet.h"
26 #include "pw_rpc/service_id.h"
27
28 namespace pw::rpc {
29 namespace {
30
31 using internal::Packet;
32 using internal::pwpb::PacketType;
33
34 } // namespace
35
ProcessPacket(ConstByteSpan packet_data)36 Status Server::ProcessPacket(ConstByteSpan packet_data) {
37 PW_TRY_ASSIGN(Packet packet,
38 Endpoint::ProcessPacket(packet_data, Packet::kServer));
39 return ProcessPacket(packet);
40 }
41
// Dispatches a decoded packet to the matching method or existing call.
//
// Locks internal::rpc_lock() on entry. This function unlocks it itself only on
// the paths that end here:
//   - unknown channel,
//   - unknown method,
//   - a CLIENT_ERROR with no matching call,
//   - an unsupported packet type.
// On the remaining paths control passes to a callee while the lock is still
// held.
// NOTE(review): method->Invoke() and call->HandleError() are presumably
// responsible for releasing rpc_lock() on those paths, as are the success paths
// of HandleClientStreamPacket / HandleCompletionRequest. Confirm against their
// definitions before changing the locking here.
//
// Returns Status::Unavailable() if packet.channel_id() does not name a known
// channel. Every other path returns OkStatus(), because the packet counts as
// handled even when an error reply is sent or the packet is dropped.
Status Server::ProcessPacket(internal::Packet packet) {
  internal::rpc_lock().lock();

  // Compile-time switch for tracing every incoming packet; off by default.
  static constexpr bool kLogAllIncomingPackets = false;
  if constexpr (kLogAllIncomingPackets) {
    PW_LOG_INFO("RPC server received packet type %u for %u:%08x/%08x",
                static_cast<unsigned>(packet.type()),
                static_cast<unsigned>(packet.channel_id()),
                static_cast<unsigned>(packet.service_id()),
                static_cast<unsigned>(packet.method_id()));
  }

  // Packets for an unknown channel are rejected without sending any reply.
  internal::ChannelBase* channel = GetInternalChannel(packet.channel_id());
  if (channel == nullptr) {
    internal::rpc_lock().unlock();
    PW_LOG_WARN("RPC server received packet for unknown channel %u",
                static_cast<unsigned>(packet.channel_id()));
    return Status::Unavailable();
  }

  const auto [service, method] = FindMethodLocked(packet);

  if (method == nullptr) {
    // Don't send responses to errors to avoid infinite error cycles.
    if (packet.type() != PacketType::CLIENT_ERROR) {
      channel->Send(Packet::ServerError(packet, Status::NotFound()))
          .IgnoreError();
    }
    internal::rpc_lock().unlock();
    PW_LOG_DEBUG("Received packet on channel %u for unknown RPC %08x/%08x",
                 static_cast<unsigned>(packet.channel_id()),
                 static_cast<unsigned>(packet.service_id()),
                 static_cast<unsigned>(packet.method_id()));
    return OkStatus();  // OK since the packet was handled.
  }

  // Handle request packets separately to avoid an unnecessary call lookup. The
  // Call constructor looks up and cancels any duplicate calls.
  // NOTE(review): the lock is still held when Invoke() is called here.
  if (packet.type() == PacketType::REQUEST) {
    const internal::CallContext context(
        *this, packet.channel_id(), *service, *method, packet.call_id());
    method->Invoke(context, packet);
    return OkStatus();
  }

  // Existing call this packet refers to, or calls_end() if there is none.
  IntrusiveList<internal::Call>::iterator call = FindCall(packet);

  switch (packet.type()) {
    case PacketType::CLIENT_STREAM:
      HandleClientStreamPacket(packet, *channel, call);
      break;
    case PacketType::CLIENT_ERROR:
      if (call != calls_end()) {
        // Logged while the lock is still held, before the call is notified.
        PW_LOG_DEBUG("Server call %u for %u:%08x/%08x terminated with error %s",
                     static_cast<unsigned>(packet.call_id()),
                     static_cast<unsigned>(packet.channel_id()),
                     static_cast<unsigned>(packet.service_id()),
                     static_cast<unsigned>(packet.method_id()),
                     packet.status().str());
        call->HandleError(packet.status());
      } else {
        // No matching call to notify; just release the lock.
        internal::rpc_lock().unlock();
      }
      break;
    case PacketType::CLIENT_REQUEST_COMPLETION:
      HandleCompletionRequest(packet, *channel, call);
      break;
    case PacketType::REQUEST:  // Handled above
    case PacketType::RESPONSE:
    case PacketType::SERVER_ERROR:
    case PacketType::SERVER_STREAM:
    default:
      // Any other packet type is logged and dropped.
      internal::rpc_lock().unlock();
      PW_LOG_WARN("pw_rpc server unable to handle packet of type %u",
                  unsigned(packet.type()));
  }

  return OkStatus();  // OK since the packet was handled
}
121
FindMethod(uint32_t service_id,uint32_t method_id)122 std::tuple<Service*, const internal::Method*> Server::FindMethod(
123 uint32_t service_id, uint32_t method_id) {
124 internal::RpcLockGuard lock;
125 return FindMethodLocked(service_id, method_id);
126 }
127
FindMethodLocked(uint32_t service_id,uint32_t method_id)128 std::tuple<Service*, const internal::Method*> Server::FindMethodLocked(
129 uint32_t service_id, uint32_t method_id) {
130 auto service = std::find_if(services_.begin(), services_.end(), [&](auto& s) {
131 return internal::UnwrapServiceId(s.service_id()) == service_id;
132 });
133
134 if (service == services_.end()) {
135 return {};
136 }
137
138 return {&(*service), service->FindMethod(method_id)};
139 }
140
HandleCompletionRequest(const internal::Packet & packet,internal::ChannelBase & channel,IntrusiveList<internal::Call>::iterator call) const141 void Server::HandleCompletionRequest(
142 const internal::Packet& packet,
143 internal::ChannelBase& channel,
144 IntrusiveList<internal::Call>::iterator call) const {
145 if (call == calls_end()) {
146 channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
147 .IgnoreError(); // Errors are logged in Channel::Send.
148 internal::rpc_lock().unlock();
149 PW_LOG_DEBUG(
150 "Received a request completion packet for %u:%08x/%08x, which is not a"
151 "pending call",
152 static_cast<unsigned>(packet.channel_id()),
153 static_cast<unsigned>(packet.service_id()),
154 static_cast<unsigned>(packet.method_id()));
155 return;
156 }
157
158 if (call->client_requested_completion()) {
159 internal::rpc_lock().unlock();
160 PW_LOG_DEBUG("Received multiple completion requests for %u:%08x/%08x",
161 static_cast<unsigned>(packet.channel_id()),
162 static_cast<unsigned>(packet.service_id()),
163 static_cast<unsigned>(packet.method_id()));
164 return;
165 }
166
167 static_cast<internal::ServerCall&>(*call).HandleClientRequestedCompletion();
168 }
169
// Delivers the payload of a CLIENT_STREAM packet to `call`.
//
// In this file it is called from ProcessPacket() with internal::rpc_lock()
// held. Each rejection path sends a server error on `channel` while the lock is
// held, then releases the lock, and only then logs. The packet is rejected with:
//   - FAILED_PRECONDITION if no call matches (`call == calls_end()`),
//   - INVALID_ARGUMENT if the call has no client stream,
//   - FAILED_PRECONDITION if the client already requested completion.
// NOTE(review): on the accepted path call->HandlePayload() is reached without
// unlocking here, so it presumably releases rpc_lock(). Confirm this against
// its definition.
void Server::HandleClientStreamPacket(
    const internal::Packet& packet,
    internal::ChannelBase& channel,
    IntrusiveList<internal::Call>::iterator call) const {
  // No call with this ID is in progress.
  if (call == calls_end()) {
    channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
        .IgnoreError();  // Errors are logged in Channel::Send.
    internal::rpc_lock().unlock();
    PW_LOG_DEBUG(
        "Received client stream packet for %u:%08x/%08x, which is not pending",
        static_cast<unsigned>(packet.channel_id()),
        static_cast<unsigned>(packet.service_id()),
        static_cast<unsigned>(packet.method_id()));
    return;
  }

  // The call's method does not accept client stream packets.
  if (!call->has_client_stream()) {
    channel.Send(Packet::ServerError(packet, Status::InvalidArgument()))
        .IgnoreError();  // Errors are logged in Channel::Send.
    internal::rpc_lock().unlock();
    PW_LOG_DEBUG(
        "Received client stream packet for %u:%08x/%08x, which doesn't have a "
        "client stream",
        static_cast<unsigned>(packet.channel_id()),
        static_cast<unsigned>(packet.service_id()),
        static_cast<unsigned>(packet.method_id()));
    return;
  }

  // The client already requested completion, so more stream data is invalid.
  if (call->client_requested_completion()) {
    channel.Send(Packet::ServerError(packet, Status::FailedPrecondition()))
        .IgnoreError();  // Errors are logged in Channel::Send.
    internal::rpc_lock().unlock();
    PW_LOG_DEBUG(
        "Received client stream packet for %u:%08x/%08x, but its client stream "
        "is closed",
        static_cast<unsigned>(packet.channel_id()),
        static_cast<unsigned>(packet.service_id()),
        static_cast<unsigned>(packet.method_id()));
    return;
  }

  call->HandlePayload(packet.payload());
}
214
215 } // namespace pw::rpc
216