1 // Copyright 2019 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ 6 #define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ 7 8 #include <cstddef> 9 #include <cstdint> 10 #include <type_traits> 11 #include <utility> 12 #include <vector> 13 14 #include "absl/types/optional.h" 15 #include "osp/public/message_demuxer.h" 16 #include "osp/public/network_service_manager.h" 17 #include "osp/public/protocol_connection.h" 18 #include "platform/base/error.h" 19 #include "platform/base/macros.h" 20 #include "util/osp_logging.h" 21 22 namespace openscreen { 23 namespace osp { 24 25 template <typename T> 26 using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*); 27 28 // Provides a uniform way of accessing import properties of a request/response 29 // message pair from a template: request encode function, response decode 30 // function, request serializable data member. 31 template <typename T> 32 struct DefaultRequestCoderTraits { 33 public: 34 using RequestMsgType = typename T::RequestMsgType; 35 static constexpr MessageEncodingFunction<RequestMsgType> kEncoder = 36 T::kEncoder; 37 static constexpr MessageDecodingFunction<typename T::ResponseMsgType> 38 kDecoder = T::kDecoder; 39 serial_requestDefaultRequestCoderTraits40 static const RequestMsgType* serial_request(const T& data) { 41 return &data.request; 42 } serial_requestDefaultRequestCoderTraits43 static RequestMsgType* serial_request(T& data) { return &data.request; } 44 }; 45 46 // Provides a wrapper for the common pattern of sending a request message and 47 // waiting for a response message with a matching |request_id| field. It also 48 // handles the business of queueing messages to be sent until a protocol 49 // connection is available. 50 // 51 // Messages are written using WriteMessage. This will queue messages if there 52 // is no protocol connection or write them immediately if there is. When a 53 // matching response is received via the MessageDemuxer (taken from the global 54 // ProtocolConnectionClient), OnMatchedResponse is called on the provided 55 // Delegate object along with the original request that it matches. 56 template <typename RequestT, 57 typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>> 58 class RequestResponseHandler : public MessageDemuxer::MessageCallback { 59 public: 60 class Delegate { 61 public: 62 63 virtual void OnMatchedResponse(RequestT* request, 64 typename RequestT::ResponseMsgType* response, 65 uint64_t endpoint_id) = 0; 66 virtual void OnError(RequestT* request, Error error) = 0; 67 68 protected: 69 virtual ~Delegate() = default; 70 }; 71 RequestResponseHandler(Delegate * delegate)72 explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} ~RequestResponseHandler()73 ~RequestResponseHandler() { Reset(); } 74 Reset()75 void Reset() { 76 connection_ = nullptr; 77 for (auto& message : to_send_) { 78 delegate_->OnError(&message.request, Error::Code::kRequestCancelled); 79 } 80 to_send_.clear(); 81 for (auto& message : sent_) { 82 delegate_->OnError(&message.request, Error::Code::kRequestCancelled); 83 } 84 sent_.clear(); 85 response_watch_ = MessageDemuxer::MessageWatch(); 86 } 87 88 // Write a message to the underlying protocol connection, or queue it until 89 // one is provided via SetConnection. If |id| is provided, it can be used to 90 // cancel the message via CancelMessage. 91 template <typename RequestTRval> 92 typename std::enable_if< 93 !std::is_lvalue_reference<RequestTRval>::value && 94 std::is_same<typename std::decay<RequestTRval>::type, 95 RequestT>::value, 96 Error>::type WriteMessage(absl::optional<uint64_t> id,RequestTRval && message)97 WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) { 98 auto* request_msg = RequestCoderTraits::serial_request(message); 99 if (connection_) { 100 request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); 101 Error result = 102 connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); 103 if (!result.ok()) { 104 return result; 105 } 106 sent_.emplace_back(RequestWithId{id, std::move(message)}); 107 EnsureResponseWatch(); 108 } else { 109 to_send_.emplace_back(RequestWithId{id, std::move(message)}); 110 } 111 return Error::None(); 112 } 113 114 template <typename RequestTRval> 115 typename std::enable_if< 116 !std::is_lvalue_reference<RequestTRval>::value && 117 std::is_same<typename std::decay<RequestTRval>::type, 118 RequestT>::value, 119 Error>::type WriteMessage(RequestTRval && message)120 WriteMessage(RequestTRval&& message) { 121 return WriteMessage(absl::nullopt, std::move(message)); 122 } 123 124 // Remove the message that was originally written with |id| from the send and 125 // sent queues so that we are no longer looking for a response. CancelMessage(uint64_t id)126 void CancelMessage(uint64_t id) { 127 to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), 128 [&id](const RequestWithId& msg) { 129 return id == msg.id; 130 }), 131 to_send_.end()); 132 sent_.erase(std::remove_if( 133 sent_.begin(), sent_.end(), 134 [&id](const RequestWithId& msg) { return id == msg.id; }), 135 sent_.end()); 136 if (sent_.empty()) { 137 response_watch_ = MessageDemuxer::MessageWatch(); 138 } 139 } 140 141 // Assign a ProtocolConnection to this handler for writing messages. SetConnection(ProtocolConnection * connection)142 void SetConnection(ProtocolConnection* connection) { 143 connection_ = connection; 144 for (auto& message : to_send_) { 145 auto* request_msg = RequestCoderTraits::serial_request(message.request); 146 request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); 147 Error result = 148 connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); 149 if (result.ok()) { 150 sent_.emplace_back(std::move(message)); 151 } else { 152 delegate_->OnError(&message.request, result); 153 } 154 } 155 if (!to_send_.empty()) { 156 EnsureResponseWatch(); 157 } 158 to_send_.clear(); 159 } 160 161 // MessageDemuxer::MessageCallback overrides. OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)162 ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id, 163 uint64_t connection_id, 164 msgs::Type message_type, 165 const uint8_t* buffer, 166 size_t buffer_size, 167 Clock::time_point now) override { 168 if (message_type != RequestT::kResponseType) { 169 return 0; 170 } 171 typename RequestT::ResponseMsgType response; 172 ssize_t result = 173 RequestCoderTraits::kDecoder(buffer, buffer_size, &response); 174 if (result < 0) { 175 return 0; 176 } 177 auto it = std::find_if( 178 sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) { 179 return RequestCoderTraits::serial_request(msg.request)->request_id == 180 response.request_id; 181 }); 182 if (it != sent_.end()) { 183 delegate_->OnMatchedResponse(&it->request, &response, 184 connection_->endpoint_id()); 185 sent_.erase(it); 186 if (sent_.empty()) { 187 response_watch_ = MessageDemuxer::MessageWatch(); 188 } 189 } else { 190 OSP_LOG_WARN << "got response for unknown request id: " 191 << response.request_id; 192 } 193 return result; 194 } 195 196 private: 197 struct RequestWithId { 198 absl::optional<uint64_t> id; 199 RequestT request; 200 }; 201 EnsureResponseWatch()202 void EnsureResponseWatch() { 203 if (!response_watch_) { 204 response_watch_ = NetworkServiceManager::Get() 205 ->GetProtocolConnectionClient() 206 ->message_demuxer() 207 ->WatchMessageType(connection_->endpoint_id(), 208 RequestT::kResponseType, this); 209 } 210 } 211 GetNextRequestId(uint64_t endpoint_id)212 uint64_t GetNextRequestId(uint64_t endpoint_id) { 213 return NetworkServiceManager::Get() 214 ->GetProtocolConnectionClient() 215 ->endpoint_request_ids() 216 ->GetNextRequestId(endpoint_id); 217 } 218 219 ProtocolConnection* connection_ = nullptr; 220 Delegate* const delegate_; 221 std::vector<RequestWithId> to_send_; 222 std::vector<RequestWithId> sent_; 223 MessageDemuxer::MessageWatch response_watch_; 224 225 OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler); 226 }; 227 228 } // namespace osp 229 } // namespace openscreen 230 231 #endif // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ 232