1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_grpc/pw_rpc_handler.h"
16
17 #include <cinttypes>
18
19 namespace pw::grpc {
20
21 using pw::rpc::internal::pwpb::PacketType;
22
OnClose(StreamId id)23 void PwRpcHandler::OnClose(StreamId id) { ResetStream(id); }
24
OnNewConnection()25 void PwRpcHandler::OnNewConnection() { ResetAllStreams(); }
26
OnNew(StreamId id,InlineString<kMaxMethodNameSize> full_method_name)27 Status PwRpcHandler::OnNew(StreamId id,
28 InlineString<kMaxMethodNameSize> full_method_name) {
29 // Parse out service and method from `/grpc.examples.echo.Echo/UnaryEcho`
30 // formatted name.
31 std::string_view view = std::string_view(full_method_name);
32 auto split_pos = view.find_last_of('/');
33 if (view.empty() || view[0] != '/' || split_pos == std::string_view::npos) {
34 PW_LOG_WARN("Can't determine service/method name id=%" PRIu32 " name=%s",
35 id,
36 full_method_name.c_str());
37 return Status::NotFound();
38 }
39
40 // Look up method in the server.
41 std::string_view service_name = view.substr(1, split_pos - 1);
42 std::string_view method_name = view.substr(split_pos + 1);
43 uint32_t service_id = rpc::internal::Hash(service_name);
44 uint32_t method_id = rpc::internal::Hash(method_name);
45 const auto [service, method] = server_.FindMethod(service_id, method_id);
46 if (service == nullptr || method == nullptr) {
47 PW_LOG_WARN("Unknown method '%s'", full_method_name.c_str());
48 return Status::NotFound();
49 }
50
51 return CreateStream(id, service_id, method_id, method->type());
52 }
53
OnMessage(StreamId id,ByteSpan message)54 Status PwRpcHandler::OnMessage(StreamId id, ByteSpan message) {
55 auto stream = LookupStream(id);
56 if (!stream.ok()) {
57 PW_LOG_INFO("Handler.OnMessage id=%" PRIu32 " size=%zu: unknown stream",
58 id,
59 message.size());
60 return Status::NotFound();
61 }
62
63 switch (stream->method_type) {
64 case pw::rpc::MethodType::kUnary:
65 case pw::rpc::MethodType::kServerStreaming: {
66 auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
67 channel_id_,
68 stream->service_id,
69 stream->method_id,
70 id,
71 message,
72 pw::OkStatus());
73 PW_TRY(server_.ProcessPacket(packet));
74 break;
75 }
76 case pw::rpc::MethodType::kClientStreaming:
77 case pw::rpc::MethodType::kBidirectionalStreaming: {
78 if (!stream->sent_request) {
79 auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
80 channel_id_,
81 stream->service_id,
82 stream->method_id,
83 id,
84 {},
85 pw::OkStatus());
86 PW_TRY(server_.ProcessPacket(packet));
87 MarkSentRequest(id);
88 }
89
90 auto packet = pw::rpc::internal::Packet(PacketType::kClientStream,
91 channel_id_,
92 stream->service_id,
93 stream->method_id,
94 id,
95 message,
96 pw::OkStatus());
97 PW_TRY(server_.ProcessPacket(packet));
98 break;
99 }
100 default:
101 PW_LOG_WARN("Unexpected method type");
102 return Status::Internal();
103 }
104
105 return OkStatus();
106 }
107
OnHalfClose(StreamId id)108 void PwRpcHandler::OnHalfClose(StreamId id) {
109 auto stream = LookupStream(id);
110 if (!stream.ok()) {
111 PW_LOG_INFO("OnHalfClose unknown stream");
112 return;
113 }
114
115 if (stream->method_type == pw::rpc::MethodType::kClientStreaming ||
116 stream->method_type == pw::rpc::MethodType::kBidirectionalStreaming) {
117 auto packet =
118 pw::rpc::internal::Packet(PacketType::kClientRequestCompletion,
119 channel_id_,
120 stream->service_id,
121 stream->method_id,
122 id,
123 {},
124 pw::OkStatus());
125 ResetStream(id);
126
127 server_.ProcessPacket(packet).IgnoreError();
128 }
129 }
130
OnCancel(StreamId id)131 void PwRpcHandler::OnCancel(StreamId id) {
132 auto stream = LookupStream(id);
133 if (!stream.ok()) {
134 PW_LOG_INFO("OnCancel unknown stream");
135 return;
136 }
137
138 auto packet = pw::rpc::internal::Packet(PacketType::kClientError,
139 channel_id_,
140 stream->service_id,
141 stream->method_id,
142 id,
143 {},
144 pw::Status::Cancelled());
145 ResetStream(id);
146
147 server_.ProcessPacket(packet).IgnoreError();
148 }
149
LookupStream(StreamId id)150 Result<PwRpcHandler::Stream> PwRpcHandler::LookupStream(StreamId id) {
151 auto streams_locked = streams_.acquire();
152 for (size_t i = 0; i < streams_locked->size(); ++i) {
153 auto& stream = (*streams_locked)[i];
154 if (stream.id == id) {
155 return stream;
156 }
157 }
158 return Status::NotFound();
159 }
160
ResetAllStreams()161 void PwRpcHandler::ResetAllStreams() {
162 auto streams_locked = streams_.acquire();
163 for (size_t i = 0; i < streams_locked->size(); ++i) {
164 auto& stream = (*streams_locked)[i];
165 stream.id = 0;
166 }
167 }
168
ResetStream(StreamId id)169 void PwRpcHandler::ResetStream(StreamId id) {
170 auto streams_locked = streams_.acquire();
171 for (size_t i = 0; i < streams_locked->size(); ++i) {
172 auto& stream = (*streams_locked)[i];
173 if (stream.id == id) {
174 stream.id = 0;
175 break;
176 }
177 }
178 }
179
MarkSentRequest(StreamId id)180 void PwRpcHandler::MarkSentRequest(StreamId id) {
181 auto streams_locked = streams_.acquire();
182 for (size_t i = 0; i < streams_locked->size(); ++i) {
183 auto& stream = (*streams_locked)[i];
184 if (stream.id == id) {
185 stream.sent_request = true;
186 break;
187 }
188 }
189 }
190
CreateStream(StreamId id,uint32_t service_id,uint32_t method_id,pw::rpc::MethodType method_type)191 Status PwRpcHandler::CreateStream(StreamId id,
192 uint32_t service_id,
193 uint32_t method_id,
194 pw::rpc::MethodType method_type) {
195 auto streams_locked = streams_.acquire();
196
197 for (size_t i = 0; i < streams_locked->size(); ++i) {
198 auto& stream = (*streams_locked)[i];
199 if (!stream.id) {
200 stream.id = id;
201 stream.service_id = service_id;
202 stream.method_id = method_id;
203 stream.method_type = method_type;
204 stream.sent_request = false;
205 return OkStatus();
206 }
207 }
208 return Status::ResourceExhausted();
209 }
210
211 } // namespace pw::grpc
212