// xref: /aosp_15_r20/external/pigweed/pw_grpc/pw_rpc_handler.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
// Copyright 2024 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
14 
15 #include "pw_grpc/pw_rpc_handler.h"
16 
17 #include <cinttypes>
18 
19 namespace pw::grpc {
20 
21 using pw::rpc::internal::pwpb::PacketType;
22 
OnClose(StreamId id)23 void PwRpcHandler::OnClose(StreamId id) { ResetStream(id); }
24 
OnNewConnection()25 void PwRpcHandler::OnNewConnection() { ResetAllStreams(); }
26 
OnNew(StreamId id,InlineString<kMaxMethodNameSize> full_method_name)27 Status PwRpcHandler::OnNew(StreamId id,
28                            InlineString<kMaxMethodNameSize> full_method_name) {
29   // Parse out service and method from `/grpc.examples.echo.Echo/UnaryEcho`
30   // formatted name.
31   std::string_view view = std::string_view(full_method_name);
32   auto split_pos = view.find_last_of('/');
33   if (view.empty() || view[0] != '/' || split_pos == std::string_view::npos) {
34     PW_LOG_WARN("Can't determine service/method name id=%" PRIu32 " name=%s",
35                 id,
36                 full_method_name.c_str());
37     return Status::NotFound();
38   }
39 
40   // Look up method in the server.
41   std::string_view service_name = view.substr(1, split_pos - 1);
42   std::string_view method_name = view.substr(split_pos + 1);
43   uint32_t service_id = rpc::internal::Hash(service_name);
44   uint32_t method_id = rpc::internal::Hash(method_name);
45   const auto [service, method] = server_.FindMethod(service_id, method_id);
46   if (service == nullptr || method == nullptr) {
47     PW_LOG_WARN("Unknown method '%s'", full_method_name.c_str());
48     return Status::NotFound();
49   }
50 
51   return CreateStream(id, service_id, method_id, method->type());
52 }
53 
OnMessage(StreamId id,ByteSpan message)54 Status PwRpcHandler::OnMessage(StreamId id, ByteSpan message) {
55   auto stream = LookupStream(id);
56   if (!stream.ok()) {
57     PW_LOG_INFO("Handler.OnMessage id=%" PRIu32 " size=%zu: unknown stream",
58                 id,
59                 message.size());
60     return Status::NotFound();
61   }
62 
63   switch (stream->method_type) {
64     case pw::rpc::MethodType::kUnary:
65     case pw::rpc::MethodType::kServerStreaming: {
66       auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
67                                               channel_id_,
68                                               stream->service_id,
69                                               stream->method_id,
70                                               id,
71                                               message,
72                                               pw::OkStatus());
73       PW_TRY(server_.ProcessPacket(packet));
74       break;
75     }
76     case pw::rpc::MethodType::kClientStreaming:
77     case pw::rpc::MethodType::kBidirectionalStreaming: {
78       if (!stream->sent_request) {
79         auto packet = pw::rpc::internal::Packet(PacketType::kRequest,
80                                                 channel_id_,
81                                                 stream->service_id,
82                                                 stream->method_id,
83                                                 id,
84                                                 {},
85                                                 pw::OkStatus());
86         PW_TRY(server_.ProcessPacket(packet));
87         MarkSentRequest(id);
88       }
89 
90       auto packet = pw::rpc::internal::Packet(PacketType::kClientStream,
91                                               channel_id_,
92                                               stream->service_id,
93                                               stream->method_id,
94                                               id,
95                                               message,
96                                               pw::OkStatus());
97       PW_TRY(server_.ProcessPacket(packet));
98       break;
99     }
100     default:
101       PW_LOG_WARN("Unexpected method type");
102       return Status::Internal();
103   }
104 
105   return OkStatus();
106 }
107 
OnHalfClose(StreamId id)108 void PwRpcHandler::OnHalfClose(StreamId id) {
109   auto stream = LookupStream(id);
110   if (!stream.ok()) {
111     PW_LOG_INFO("OnHalfClose unknown stream");
112     return;
113   }
114 
115   if (stream->method_type == pw::rpc::MethodType::kClientStreaming ||
116       stream->method_type == pw::rpc::MethodType::kBidirectionalStreaming) {
117     auto packet =
118         pw::rpc::internal::Packet(PacketType::kClientRequestCompletion,
119                                   channel_id_,
120                                   stream->service_id,
121                                   stream->method_id,
122                                   id,
123                                   {},
124                                   pw::OkStatus());
125     ResetStream(id);
126 
127     server_.ProcessPacket(packet).IgnoreError();
128   }
129 }
130 
OnCancel(StreamId id)131 void PwRpcHandler::OnCancel(StreamId id) {
132   auto stream = LookupStream(id);
133   if (!stream.ok()) {
134     PW_LOG_INFO("OnCancel unknown stream");
135     return;
136   }
137 
138   auto packet = pw::rpc::internal::Packet(PacketType::kClientError,
139                                           channel_id_,
140                                           stream->service_id,
141                                           stream->method_id,
142                                           id,
143                                           {},
144                                           pw::Status::Cancelled());
145   ResetStream(id);
146 
147   server_.ProcessPacket(packet).IgnoreError();
148 }
149 
LookupStream(StreamId id)150 Result<PwRpcHandler::Stream> PwRpcHandler::LookupStream(StreamId id) {
151   auto streams_locked = streams_.acquire();
152   for (size_t i = 0; i < streams_locked->size(); ++i) {
153     auto& stream = (*streams_locked)[i];
154     if (stream.id == id) {
155       return stream;
156     }
157   }
158   return Status::NotFound();
159 }
160 
ResetAllStreams()161 void PwRpcHandler::ResetAllStreams() {
162   auto streams_locked = streams_.acquire();
163   for (size_t i = 0; i < streams_locked->size(); ++i) {
164     auto& stream = (*streams_locked)[i];
165     stream.id = 0;
166   }
167 }
168 
ResetStream(StreamId id)169 void PwRpcHandler::ResetStream(StreamId id) {
170   auto streams_locked = streams_.acquire();
171   for (size_t i = 0; i < streams_locked->size(); ++i) {
172     auto& stream = (*streams_locked)[i];
173     if (stream.id == id) {
174       stream.id = 0;
175       break;
176     }
177   }
178 }
179 
MarkSentRequest(StreamId id)180 void PwRpcHandler::MarkSentRequest(StreamId id) {
181   auto streams_locked = streams_.acquire();
182   for (size_t i = 0; i < streams_locked->size(); ++i) {
183     auto& stream = (*streams_locked)[i];
184     if (stream.id == id) {
185       stream.sent_request = true;
186       break;
187     }
188   }
189 }
190 
CreateStream(StreamId id,uint32_t service_id,uint32_t method_id,pw::rpc::MethodType method_type)191 Status PwRpcHandler::CreateStream(StreamId id,
192                                   uint32_t service_id,
193                                   uint32_t method_id,
194                                   pw::rpc::MethodType method_type) {
195   auto streams_locked = streams_.acquire();
196 
197   for (size_t i = 0; i < streams_locked->size(); ++i) {
198     auto& stream = (*streams_locked)[i];
199     if (!stream.id) {
200       stream.id = id;
201       stream.service_id = service_id;
202       stream.method_id = method_id;
203       stream.method_type = method_type;
204       stream.sent_request = false;
205       return OkStatus();
206     }
207   }
208   return Status::ResourceExhausted();
209 }
210 
211 }  // namespace pw::grpc
212