xref: /aosp_15_r20/external/openscreen/osp/impl/presentation/presentation_receiver.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "osp/public/presentation/presentation_receiver.h"
6 
7 #include <algorithm>
8 #include <memory>
9 
10 #include "osp/impl/presentation/presentation_common.h"
11 #include "osp/msgs/osp_messages.h"
12 #include "osp/public/message_demuxer.h"
13 #include "osp/public/network_service_manager.h"
14 #include "osp/public/protocol_connection_server.h"
15 #include "platform/api/time.h"
16 #include "util/osp_logging.h"
17 #include "util/trace_logging.h"
18 
19 namespace openscreen {
20 namespace osp {
21 namespace {
22 
GetEventCloseReason(Connection::CloseReason reason)23 msgs::PresentationConnectionCloseEvent_reason GetEventCloseReason(
24     Connection::CloseReason reason) {
25   switch (reason) {
26     case Connection::CloseReason::kDiscarded:
27       return msgs::PresentationConnectionCloseEvent_reason::
28           kConnectionObjectDiscarded;
29 
30     case Connection::CloseReason::kError:
31       return msgs::PresentationConnectionCloseEvent_reason::
32           kUnrecoverableErrorWhileSendingOrReceivingMessage;
33 
34     case Connection::CloseReason::kClosed:  // fallthrough
35     default:
36       return msgs::PresentationConnectionCloseEvent_reason::kCloseMethodCalled;
37   }
38 }
39 
GetEventTerminationReason(TerminationReason reason)40 msgs::PresentationTerminationEvent_reason GetEventTerminationReason(
41     TerminationReason reason) {
42   switch (reason) {
43     case TerminationReason::kReceiverUserTerminated:
44       return msgs::PresentationTerminationEvent_reason::
45           kUserTerminatedViaReceiver;
46     case TerminationReason::kReceiverShuttingDown:
47       return msgs::PresentationTerminationEvent_reason::kReceiverPoweringDown;
48     case TerminationReason::kReceiverPresentationUnloaded:
49       return msgs::PresentationTerminationEvent_reason::
50           kReceiverAttemptedToNavigate;
51     case TerminationReason::kReceiverPresentationReplaced:
52       return msgs::PresentationTerminationEvent_reason::
53           kReceiverReplacedPresentation;
54     case TerminationReason::kReceiverIdleTooLong:
55       return msgs::PresentationTerminationEvent_reason::kReceiverIdleTooLong;
56     case TerminationReason::kReceiverError:
57       return msgs::PresentationTerminationEvent_reason::kReceiverCrashed;
58     case TerminationReason::kReceiverTerminateCalled:
59       return msgs::PresentationTerminationEvent_reason::
60           kReceiverCalledTerminate;
61     default:
62       return msgs::PresentationTerminationEvent_reason::kUnknown;
63   }
64 }
65 
WritePresentationInitiationResponse(const msgs::PresentationStartResponse & response,ProtocolConnection * connection)66 Error WritePresentationInitiationResponse(
67     const msgs::PresentationStartResponse& response,
68     ProtocolConnection* connection) {
69   return connection->WriteMessage(response,
70                                   msgs::EncodePresentationStartResponse);
71 }
72 
WritePresentationConnectionOpenResponse(const msgs::PresentationConnectionOpenResponse & response,ProtocolConnection * connection)73 Error WritePresentationConnectionOpenResponse(
74     const msgs::PresentationConnectionOpenResponse& response,
75     ProtocolConnection* connection) {
76   return connection->WriteMessage(
77       response, msgs::EncodePresentationConnectionOpenResponse);
78 }
79 
WritePresentationTerminationEvent(const msgs::PresentationTerminationEvent & event,ProtocolConnection * connection)80 Error WritePresentationTerminationEvent(
81     const msgs::PresentationTerminationEvent& event,
82     ProtocolConnection* connection) {
83   return connection->WriteMessage(event,
84                                   msgs::EncodePresentationTerminationEvent);
85 }
86 
WritePresentationTerminationResponse(const msgs::PresentationTerminationResponse & response,ProtocolConnection * connection)87 Error WritePresentationTerminationResponse(
88     const msgs::PresentationTerminationResponse& response,
89     ProtocolConnection* connection) {
90   return connection->WriteMessage(response,
91                                   msgs::EncodePresentationTerminationResponse);
92 }
93 
WritePresentationUrlAvailabilityResponse(const msgs::PresentationUrlAvailabilityResponse & response,ProtocolConnection * connection)94 Error WritePresentationUrlAvailabilityResponse(
95     const msgs::PresentationUrlAvailabilityResponse& response,
96     ProtocolConnection* connection) {
97   return connection->WriteMessage(
98       response, msgs::EncodePresentationUrlAvailabilityResponse);
99 }
100 
101 }  // namespace
102 
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)103 ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id,
104                                           uint64_t connection_id,
105                                           msgs::Type message_type,
106                                           const uint8_t* buffer,
107                                           size_t buffer_size,
108                                           Clock::time_point now) {
109   TRACE_SCOPED(TraceCategory::kPresentation, "Receiver::OnStreamMessage");
110   switch (message_type) {
111     case msgs::Type::kPresentationUrlAvailabilityRequest: {
112       TRACE_SCOPED(TraceCategory::kPresentation,
113                    "kPresentationUrlAvailabilityRequest");
114       OSP_VLOG << "got presentation-url-availability-request";
115       msgs::PresentationUrlAvailabilityRequest request;
116       ssize_t decode_result = msgs::DecodePresentationUrlAvailabilityRequest(
117           buffer, buffer_size, &request);
118       if (decode_result < 0) {
119         OSP_LOG_WARN << "Presentation-url-availability-request parse error: "
120                      << decode_result;
121         TRACE_SET_RESULT(Error::Code::kParseError);
122         return Error::Code::kParseError;
123       }
124 
125       msgs::PresentationUrlAvailabilityResponse response;
126       response.request_id = request.request_id;
127 
128       response.url_availabilities = delegate_->OnUrlAvailabilityRequest(
129           request.watch_id, request.watch_duration, std::move(request.urls));
130       msgs::CborEncodeBuffer buffer;
131 
132       WritePresentationUrlAvailabilityResponse(
133           response, GetProtocolConnection(endpoint_id).get());
134       return decode_result;
135     }
136 
137     case msgs::Type::kPresentationStartRequest: {
138       TRACE_SCOPED(TraceCategory::kPresentation, "kPresentationStartRequest");
139       OSP_VLOG << "got presentation-start-request";
140       msgs::PresentationStartRequest request;
141       const ssize_t result =
142           msgs::DecodePresentationStartRequest(buffer, buffer_size, &request);
143       if (result < 0) {
144         OSP_LOG_WARN << "Presentation-initiation-request parse error: "
145                      << result;
146         TRACE_SET_RESULT(Error::Code::kParseError);
147         return Error::Code::kParseError;
148       }
149 
150       OSP_LOG_INFO << "Got an initiation request for: " << request.url;
151 
152       PresentationID presentation_id(std::move(request.presentation_id));
153       if (!presentation_id) {
154         msgs::PresentationStartResponse response;
155         response.request_id = request.request_id;
156         response.result =
157             msgs::PresentationStartResponse_result::kInvalidPresentationId;
158         Error write_error = WritePresentationInitiationResponse(
159             response, GetProtocolConnection(endpoint_id).get());
160 
161         if (!write_error.ok()) {
162           TRACE_SET_RESULT(write_error);
163           return write_error;
164         }
165 
166         return result;
167       }
168 
169       auto& response_list = queued_responses_[presentation_id];
170       QueuedResponse queued_response{
171           /* .type = */ QueuedResponse::Type::kInitiation,
172           /* .request_id = */ request.request_id,
173           /* .connection_id = */ this->GetNextConnectionId(),
174           /* .endpoint_id = */ endpoint_id};
175       response_list.push_back(std::move(queued_response));
176 
177       const bool starting = delegate_->StartPresentation(
178           Connection::PresentationInfo{presentation_id, request.url},
179           endpoint_id, request.headers);
180 
181       if (starting)
182         return result;
183 
184       queued_responses_.erase(presentation_id);
185       msgs::PresentationStartResponse response;
186       response.request_id = request.request_id;
187       response.result = msgs::PresentationStartResponse_result::kUnknownError;
188       Error write_error = WritePresentationInitiationResponse(
189           response, GetProtocolConnection(endpoint_id).get());
190       if (!write_error.ok()) {
191         TRACE_SET_RESULT(write_error);
192         return write_error;
193       }
194 
195       return result;
196     }
197 
198     case msgs::Type::kPresentationConnectionOpenRequest: {
199       TRACE_SCOPED(TraceCategory::kPresentation,
200                    "kPresentationConnectionOpenRequest");
201       OSP_VLOG << "Got a presentation-connection-open-request";
202       msgs::PresentationConnectionOpenRequest request;
203       const ssize_t result = msgs::DecodePresentationConnectionOpenRequest(
204           buffer, buffer_size, &request);
205       if (result < 0) {
206         OSP_LOG_WARN << "Presentation-connection-open-request parse error: "
207                      << result;
208         TRACE_SET_RESULT(Error::Code::kParseError);
209         return Error::Code::kParseError;
210       }
211 
212       PresentationID presentation_id(std::move(request.presentation_id));
213 
214       // TODO(jophba): add logic to queue presentation connection open
215       // (and terminate connection)
216       // requests to check against when a presentation starts, in case
217       // we get a request right before the beginning of the presentation.
218       if (!presentation_id || started_presentations_.find(presentation_id) ==
219                                   started_presentations_.end()) {
220         msgs::PresentationConnectionOpenResponse response;
221         response.request_id = request.request_id;
222         response.result = msgs::PresentationConnectionOpenResponse_result::
223             kInvalidPresentationId;
224         Error write_error = WritePresentationConnectionOpenResponse(
225             response, GetProtocolConnection(endpoint_id).get());
226         if (!write_error.ok()) {
227           TRACE_SET_RESULT(write_error);
228           return write_error;
229         }
230 
231         return result;
232       }
233 
234       // TODO(btolsch): We would also check that connection_id isn't already
235       // requested/in use but since the spec has already shifted to a
236       // receiver-chosen connection ID, we'll ignore that until we change our
237       // CDDL messages.
238       std::vector<QueuedResponse>& responses =
239           queued_responses_[presentation_id];
240       responses.emplace_back(
241           QueuedResponse{QueuedResponse::Type::kConnection, request.request_id,
242                          this->GetNextConnectionId(), endpoint_id});
243       bool connecting = delegate_->ConnectToPresentation(
244           request.request_id, presentation_id, endpoint_id);
245       if (connecting)
246         return result;
247 
248       responses.pop_back();
249       if (responses.empty())
250         queued_responses_.erase(presentation_id);
251 
252       msgs::PresentationConnectionOpenResponse response;
253       response.request_id = request.request_id;
254       response.result =
255           msgs::PresentationConnectionOpenResponse_result::kUnknownError;
256       Error write_error = WritePresentationConnectionOpenResponse(
257           response, GetProtocolConnection(endpoint_id).get());
258       if (!write_error.ok()) {
259         TRACE_SET_RESULT(write_error);
260         return write_error;
261       }
262 
263       return result;
264     }
265 
266     case msgs::Type::kPresentationTerminationRequest: {
267       TRACE_SCOPED(TraceCategory::kPresentation,
268                    "kPresentationTerminationRequest");
269       OSP_VLOG << "got presentation-termination-request";
270       msgs::PresentationTerminationRequest request;
271       const ssize_t result = msgs::DecodePresentationTerminationRequest(
272           buffer, buffer_size, &request);
273       if (result < 0) {
274         OSP_LOG_WARN << "Presentation-termination-request parse error: "
275                      << result;
276         TRACE_SET_RESULT(Error::Code::kParseError);
277         return Error::Code::kParseError;
278       }
279 
280       PresentationID presentation_id(std::move(request.presentation_id));
281       OSP_LOG_INFO << "Got termination request for: " << presentation_id;
282 
283       auto presentation_entry = started_presentations_.find(presentation_id);
284       if (presentation_id &&
285           presentation_entry != started_presentations_.end()) {
286         TerminationReason reason =
287             (request.reason == msgs::PresentationTerminationRequest_reason::
288                                    kUserTerminatedViaController)
289                 ? TerminationReason::kControllerTerminateCalled
290                 : TerminationReason::kControllerUserTerminated;
291         presentation_entry->second.terminate_request_id = request.request_id;
292         delegate_->TerminatePresentation(presentation_id, reason);
293 
294         msgs::PresentationTerminationResponse response;
295         response.request_id = request.request_id;
296         response.result = msgs::PresentationTerminationResponse_result::
297             kInvalidPresentationId;
298         Error write_error = WritePresentationTerminationResponse(
299             response, GetProtocolConnection(endpoint_id).get());
300         if (!write_error.ok()) {
301           TRACE_SET_RESULT(write_error);
302           return write_error;
303         }
304         return result;
305       }
306 
307       TerminationReason reason =
308           (request.reason == msgs::PresentationTerminationRequest_reason::
309                                  kControllerCalledTerminate)
310               ? TerminationReason::kControllerTerminateCalled
311               : TerminationReason::kControllerUserTerminated;
312       presentation_entry->second.terminate_request_id = request.request_id;
313       delegate_->TerminatePresentation(presentation_id, reason);
314 
315       return result;
316     }
317 
318     default:
319       TRACE_SET_RESULT(Error::Code::kUnknownMessageType);
320       return Error::Code::kUnknownMessageType;
321   }
322 }
323 
324 // TODO(crbug.com/openscreen/31): Remove singletons in the embedder API and
325 // protocol implementation layers and in presentation_connection, as well as
326 // unit tests. static
Get()327 Receiver* Receiver::Get() {
328   static Receiver& receiver = *new Receiver();
329   return &receiver;
330 }
331 
Init()332 void Receiver::Init() {
333   if (!connection_manager_) {
334     connection_manager_ =
335         std::make_unique<ConnectionManager>(GetServerDemuxer());
336   }
337 }
338 
Deinit()339 void Receiver::Deinit() {
340   connection_manager_.reset();
341 }
342 
SetReceiverDelegate(ReceiverDelegate * delegate)343 void Receiver::SetReceiverDelegate(ReceiverDelegate* delegate) {
344   OSP_DCHECK(!delegate_ || !delegate);
345   delegate_ = delegate;
346 
347   MessageDemuxer* demuxer = GetServerDemuxer();
348   if (delegate_) {
349     availability_watch_ = demuxer->SetDefaultMessageTypeWatch(
350         msgs::Type::kPresentationUrlAvailabilityRequest, this);
351     initiation_watch_ = demuxer->SetDefaultMessageTypeWatch(
352         msgs::Type::kPresentationStartRequest, this);
353     connection_watch_ = demuxer->SetDefaultMessageTypeWatch(
354         msgs::Type::kPresentationConnectionOpenRequest, this);
355     return;
356   }
357 
358   StopWatching(&availability_watch_);
359   StopWatching(&initiation_watch_);
360   StopWatching(&connection_watch_);
361 
362   std::vector<std::string> presentations_to_remove(
363       started_presentations_.size());
364   for (auto& it : started_presentations_) {
365     presentations_to_remove.push_back(it.first);
366   }
367 
368   for (auto& presentation_id : presentations_to_remove) {
369     OnPresentationTerminated(presentation_id,
370                              TerminationReason::kReceiverShuttingDown);
371   }
372 }
373 
OnPresentationStarted(const std::string & presentation_id,Connection * connection,ResponseResult result)374 Error Receiver::OnPresentationStarted(const std::string& presentation_id,
375                                       Connection* connection,
376                                       ResponseResult result) {
377   auto queued_responses_entry = queued_responses_.find(presentation_id);
378   if (queued_responses_entry == queued_responses_.end())
379     return Error::Code::kNoStartedPresentation;
380 
381   auto& responses = queued_responses_entry->second;
382   if ((responses.size() != 1) ||
383       (responses.front().type != QueuedResponse::Type::kInitiation)) {
384     return Error::Code::kPresentationAlreadyStarted;
385   }
386 
387   QueuedResponse& initiation_response = responses.front();
388   msgs::PresentationStartResponse response;
389   response.request_id = initiation_response.request_id;
390   auto protocol_connection =
391       GetProtocolConnection(initiation_response.endpoint_id);
392   auto* raw_protocol_connection_ptr = protocol_connection.get();
393 
394   OSP_VLOG << "presentation started with protocol_connection id: "
395            << protocol_connection->id();
396   if (result != ResponseResult::kSuccess) {
397     response.result = msgs::PresentationStartResponse_result::kUnknownError;
398 
399     queued_responses_.erase(queued_responses_entry);
400     return WritePresentationInitiationResponse(response,
401                                                raw_protocol_connection_ptr);
402   }
403 
404   response.result = msgs::PresentationStartResponse_result::kSuccess;
405   response.connection_id = connection->connection_id();
406 
407   Presentation& presentation = started_presentations_[presentation_id];
408   presentation.endpoint_id = initiation_response.endpoint_id;
409   connection->OnConnected(initiation_response.connection_id,
410                           initiation_response.endpoint_id,
411                           std::move(protocol_connection));
412   presentation.connections.push_back(connection);
413   connection_manager_->AddConnection(connection);
414 
415   presentation.terminate_watch = GetServerDemuxer()->WatchMessageType(
416       initiation_response.endpoint_id,
417       msgs::Type::kPresentationTerminationRequest, this);
418 
419   queued_responses_.erase(queued_responses_entry);
420   return WritePresentationInitiationResponse(response,
421                                              raw_protocol_connection_ptr);
422 }
423 
OnConnectionCreated(uint64_t request_id,Connection * connection,ResponseResult result)424 Error Receiver::OnConnectionCreated(uint64_t request_id,
425                                     Connection* connection,
426                                     ResponseResult result) {
427   const auto presentation_id = connection->presentation_info().id;
428 
429   ErrorOr<QueuedResponseIterator> connection_response =
430       GetQueuedResponse(presentation_id, request_id);
431   if (connection_response.is_error()) {
432     return connection_response.error();
433   }
434   connection->OnConnected(
435       connection_response.value()->connection_id,
436       connection_response.value()->endpoint_id,
437       NetworkServiceManager::Get()
438           ->GetProtocolConnectionServer()
439           ->CreateProtocolConnection(connection_response.value()->endpoint_id));
440 
441   started_presentations_[presentation_id].connections.push_back(connection);
442   connection_manager_->AddConnection(connection);
443 
444   msgs::PresentationConnectionOpenResponse response;
445   response.request_id = request_id;
446   response.result = msgs::PresentationConnectionOpenResponse_result::kSuccess;
447   response.connection_id = connection->connection_id();
448 
449   auto protocol_connection =
450       GetProtocolConnection(connection_response.value()->endpoint_id);
451 
452   WritePresentationConnectionOpenResponse(response, protocol_connection.get());
453 
454   DeleteQueuedResponse(presentation_id, connection_response.value());
455   return Error::None();
456 }
457 
CloseConnection(Connection * connection,Connection::CloseReason reason)458 Error Receiver::CloseConnection(Connection* connection,
459                                 Connection::CloseReason reason) {
460   std::unique_ptr<ProtocolConnection> protocol_connection =
461       GetProtocolConnection(connection->endpoint_id());
462 
463   if (!protocol_connection)
464     return Error::Code::kNoActiveConnection;
465 
466   msgs::PresentationConnectionCloseEvent event;
467   event.connection_id = connection->connection_id();
468   event.reason = GetEventCloseReason(reason);
469   event.has_error_message = false;
470   msgs::CborEncodeBuffer buffer;
471   return protocol_connection->WriteMessage(
472       event, msgs::EncodePresentationConnectionCloseEvent);
473 }
474 
OnPresentationTerminated(const std::string & presentation_id,TerminationReason reason)475 Error Receiver::OnPresentationTerminated(const std::string& presentation_id,
476                                          TerminationReason reason) {
477   auto presentation_entry = started_presentations_.find(presentation_id);
478   if (presentation_entry == started_presentations_.end())
479     return Error::Code::kNoStartedPresentation;
480 
481   Presentation& presentation = presentation_entry->second;
482   presentation.terminate_watch = MessageDemuxer::MessageWatch();
483   std::unique_ptr<ProtocolConnection> protocol_connection =
484       GetProtocolConnection(presentation.endpoint_id);
485 
486   if (!protocol_connection)
487     return Error::Code::kNoActiveConnection;
488 
489   for (auto* connection : presentation.connections)
490     connection->OnTerminated();
491 
492   if (presentation.terminate_request_id) {
493     // TODO(btolsch): Also timeout if this point isn't reached.
494     msgs::PresentationTerminationResponse response;
495     response.request_id = presentation.terminate_request_id;
496     response.result = msgs::PresentationTerminationResponse_result::kSuccess;
497     started_presentations_.erase(presentation_entry);
498     return WritePresentationTerminationResponse(response,
499                                                 protocol_connection.get());
500   }
501 
502   msgs::PresentationTerminationEvent event;
503   event.presentation_id = presentation_id;
504   event.reason = GetEventTerminationReason(reason);
505   started_presentations_.erase(presentation_entry);
506   return WritePresentationTerminationEvent(event, protocol_connection.get());
507 }
508 
OnConnectionDestroyed(Connection * connection)509 void Receiver::OnConnectionDestroyed(Connection* connection) {
510   auto presentation_entry =
511       started_presentations_.find(connection->presentation_info().id);
512   if (presentation_entry == started_presentations_.end())
513     return;
514 
515   std::vector<Connection*>& connections =
516       presentation_entry->second.connections;
517 
518   auto past_the_end =
519       std::remove(connections.begin(), connections.end(), connection);
520   // An additional call to "erase" is necessary to actually adjust the size
521   // of the vector.
522   connections.erase(past_the_end, connections.end());
523 
524   connection_manager_->RemoveConnection(connection);
525 }
526 
527 Receiver::Receiver() = default;
528 
529 Receiver::~Receiver() = default;
530 
DeleteQueuedResponse(const std::string & presentation_id,Receiver::QueuedResponseIterator response)531 void Receiver::DeleteQueuedResponse(const std::string& presentation_id,
532                                     Receiver::QueuedResponseIterator response) {
533   auto entry = queued_responses_.find(presentation_id);
534   entry->second.erase(response);
535   if (entry->second.empty())
536     queued_responses_.erase(entry);
537 }
538 
GetQueuedResponse(const std::string & presentation_id,uint64_t request_id) const539 ErrorOr<Receiver::QueuedResponseIterator> Receiver::GetQueuedResponse(
540     const std::string& presentation_id,
541     uint64_t request_id) const {
542   auto entry = queued_responses_.find(presentation_id);
543   if (entry == queued_responses_.end()) {
544     OSP_LOG_WARN << "connection created for unknown request";
545     return Error::Code::kUnknownRequestId;
546   }
547 
548   const std::vector<QueuedResponse>& responses = entry->second;
549   Receiver::QueuedResponseIterator it =
550       std::find_if(responses.begin(), responses.end(),
551                    [request_id](const QueuedResponse& response) {
552                      return response.request_id == request_id;
553                    });
554 
555   if (it == responses.end()) {
556     OSP_LOG_WARN << "connection created for unknown request";
557     return Error::Code::kUnknownRequestId;
558   }
559 
560   return it;
561 }
562 
GetNextConnectionId()563 uint64_t Receiver::GetNextConnectionId() {
564   static uint64_t request_id = 0;
565   return request_id++;
566 }
567 
568 }  // namespace osp
569 }  // namespace openscreen
570