xref: /aosp_15_r20/external/openscreen/osp/public/request_response_handler.h (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
6 #define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
7 
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>
13 
14 #include "absl/types/optional.h"
15 #include "osp/public/message_demuxer.h"
16 #include "osp/public/network_service_manager.h"
17 #include "osp/public/protocol_connection.h"
18 #include "platform/base/error.h"
19 #include "platform/base/macros.h"
20 #include "util/osp_logging.h"
21 
22 namespace openscreen {
23 namespace osp {
24 
// Pointer to a decoder that reads |buffer_size| bytes starting at |buffer|
// into the object pointed to by the last argument and returns a signed byte
// count (callers in this header treat a negative count as a decode failure).
template <typename T>
using MessageDecodingFunction =
    typename std::add_pointer<ssize_t(const uint8_t*, size_t, T*)>::type;
27 
// Provides a uniform way of accessing important properties of a
// request/response message pair from a template: the request encode function,
// the response decode function, and the serializable request data member.
31 template <typename T>
32 struct DefaultRequestCoderTraits {
33  public:
34   using RequestMsgType = typename T::RequestMsgType;
35   static constexpr MessageEncodingFunction<RequestMsgType> kEncoder =
36       T::kEncoder;
37   static constexpr MessageDecodingFunction<typename T::ResponseMsgType>
38       kDecoder = T::kDecoder;
39 
serial_requestDefaultRequestCoderTraits40   static const RequestMsgType* serial_request(const T& data) {
41     return &data.request;
42   }
serial_requestDefaultRequestCoderTraits43   static RequestMsgType* serial_request(T& data) { return &data.request; }
44 };
45 
46 // Provides a wrapper for the common pattern of sending a request message and
47 // waiting for a response message with a matching |request_id| field.  It also
48 // handles the business of queueing messages to be sent until a protocol
49 // connection is available.
50 //
51 // Messages are written using WriteMessage.  This will queue messages if there
52 // is no protocol connection or write them immediately if there is.  When a
53 // matching response is received via the MessageDemuxer (taken from the global
54 // ProtocolConnectionClient), OnMatchedResponse is called on the provided
55 // Delegate object along with the original request that it matches.
56 template <typename RequestT,
57           typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>>
58 class RequestResponseHandler : public MessageDemuxer::MessageCallback {
59  public:
60   class Delegate {
61    public:
62 
63     virtual void OnMatchedResponse(RequestT* request,
64                                    typename RequestT::ResponseMsgType* response,
65                                    uint64_t endpoint_id) = 0;
66     virtual void OnError(RequestT* request, Error error) = 0;
67 
68    protected:
69     virtual ~Delegate() = default;
70   };
71 
RequestResponseHandler(Delegate * delegate)72   explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {}
~RequestResponseHandler()73   ~RequestResponseHandler() { Reset(); }
74 
Reset()75   void Reset() {
76     connection_ = nullptr;
77     for (auto& message : to_send_) {
78       delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
79     }
80     to_send_.clear();
81     for (auto& message : sent_) {
82       delegate_->OnError(&message.request, Error::Code::kRequestCancelled);
83     }
84     sent_.clear();
85     response_watch_ = MessageDemuxer::MessageWatch();
86   }
87 
88   // Write a message to the underlying protocol connection, or queue it until
89   // one is provided via SetConnection.  If |id| is provided, it can be used to
90   // cancel the message via CancelMessage.
91   template <typename RequestTRval>
92   typename std::enable_if<
93       !std::is_lvalue_reference<RequestTRval>::value &&
94           std::is_same<typename std::decay<RequestTRval>::type,
95                        RequestT>::value,
96       Error>::type
WriteMessage(absl::optional<uint64_t> id,RequestTRval && message)97   WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) {
98     auto* request_msg = RequestCoderTraits::serial_request(message);
99     if (connection_) {
100       request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
101       Error result =
102           connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
103       if (!result.ok()) {
104         return result;
105       }
106       sent_.emplace_back(RequestWithId{id, std::move(message)});
107       EnsureResponseWatch();
108     } else {
109       to_send_.emplace_back(RequestWithId{id, std::move(message)});
110     }
111     return Error::None();
112   }
113 
114   template <typename RequestTRval>
115   typename std::enable_if<
116       !std::is_lvalue_reference<RequestTRval>::value &&
117           std::is_same<typename std::decay<RequestTRval>::type,
118                        RequestT>::value,
119       Error>::type
WriteMessage(RequestTRval && message)120   WriteMessage(RequestTRval&& message) {
121     return WriteMessage(absl::nullopt, std::move(message));
122   }
123 
124   // Remove the message that was originally written with |id| from the send and
125   // sent queues so that we are no longer looking for a response.
CancelMessage(uint64_t id)126   void CancelMessage(uint64_t id) {
127     to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(),
128                                   [&id](const RequestWithId& msg) {
129                                     return id == msg.id;
130                                   }),
131                    to_send_.end());
132     sent_.erase(std::remove_if(
133                     sent_.begin(), sent_.end(),
134                     [&id](const RequestWithId& msg) { return id == msg.id; }),
135                 sent_.end());
136     if (sent_.empty()) {
137       response_watch_ = MessageDemuxer::MessageWatch();
138     }
139   }
140 
141   // Assign a ProtocolConnection to this handler for writing messages.
SetConnection(ProtocolConnection * connection)142   void SetConnection(ProtocolConnection* connection) {
143     connection_ = connection;
144     for (auto& message : to_send_) {
145       auto* request_msg = RequestCoderTraits::serial_request(message.request);
146       request_msg->request_id = GetNextRequestId(connection_->endpoint_id());
147       Error result =
148           connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder);
149       if (result.ok()) {
150         sent_.emplace_back(std::move(message));
151       } else {
152         delegate_->OnError(&message.request, result);
153       }
154     }
155     if (!to_send_.empty()) {
156       EnsureResponseWatch();
157     }
158     to_send_.clear();
159   }
160 
161   // MessageDemuxer::MessageCallback overrides.
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)162   ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id,
163                                   uint64_t connection_id,
164                                   msgs::Type message_type,
165                                   const uint8_t* buffer,
166                                   size_t buffer_size,
167                                   Clock::time_point now) override {
168     if (message_type != RequestT::kResponseType) {
169       return 0;
170     }
171     typename RequestT::ResponseMsgType response;
172     ssize_t result =
173         RequestCoderTraits::kDecoder(buffer, buffer_size, &response);
174     if (result < 0) {
175       return 0;
176     }
177     auto it = std::find_if(
178         sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) {
179           return RequestCoderTraits::serial_request(msg.request)->request_id ==
180                  response.request_id;
181         });
182     if (it != sent_.end()) {
183       delegate_->OnMatchedResponse(&it->request, &response,
184                                    connection_->endpoint_id());
185       sent_.erase(it);
186       if (sent_.empty()) {
187         response_watch_ = MessageDemuxer::MessageWatch();
188       }
189     } else {
190       OSP_LOG_WARN << "got response for unknown request id: "
191                    << response.request_id;
192     }
193     return result;
194   }
195 
196  private:
197   struct RequestWithId {
198     absl::optional<uint64_t> id;
199     RequestT request;
200   };
201 
EnsureResponseWatch()202   void EnsureResponseWatch() {
203     if (!response_watch_) {
204       response_watch_ = NetworkServiceManager::Get()
205                             ->GetProtocolConnectionClient()
206                             ->message_demuxer()
207                             ->WatchMessageType(connection_->endpoint_id(),
208                                                RequestT::kResponseType, this);
209     }
210   }
211 
GetNextRequestId(uint64_t endpoint_id)212   uint64_t GetNextRequestId(uint64_t endpoint_id) {
213     return NetworkServiceManager::Get()
214         ->GetProtocolConnectionClient()
215         ->endpoint_request_ids()
216         ->GetNextRequestId(endpoint_id);
217   }
218 
219   ProtocolConnection* connection_ = nullptr;
220   Delegate* const delegate_;
221   std::vector<RequestWithId> to_send_;
222   std::vector<RequestWithId> sent_;
223   MessageDemuxer::MessageWatch response_watch_;
224 
225   OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler);
226 };
227 
228 }  // namespace osp
229 }  // namespace openscreen
230 
231 #endif  // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_
232