1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "osp/public/presentation/presentation_receiver.h"
6
7 #include <algorithm>
8 #include <memory>
9
10 #include "osp/impl/presentation/presentation_common.h"
11 #include "osp/msgs/osp_messages.h"
12 #include "osp/public/message_demuxer.h"
13 #include "osp/public/network_service_manager.h"
14 #include "osp/public/protocol_connection_server.h"
15 #include "platform/api/time.h"
16 #include "util/osp_logging.h"
17 #include "util/trace_logging.h"
18
19 namespace openscreen {
20 namespace osp {
21 namespace {
22
GetEventCloseReason(Connection::CloseReason reason)23 msgs::PresentationConnectionCloseEvent_reason GetEventCloseReason(
24 Connection::CloseReason reason) {
25 switch (reason) {
26 case Connection::CloseReason::kDiscarded:
27 return msgs::PresentationConnectionCloseEvent_reason::
28 kConnectionObjectDiscarded;
29
30 case Connection::CloseReason::kError:
31 return msgs::PresentationConnectionCloseEvent_reason::
32 kUnrecoverableErrorWhileSendingOrReceivingMessage;
33
34 case Connection::CloseReason::kClosed: // fallthrough
35 default:
36 return msgs::PresentationConnectionCloseEvent_reason::kCloseMethodCalled;
37 }
38 }
39
GetEventTerminationReason(TerminationReason reason)40 msgs::PresentationTerminationEvent_reason GetEventTerminationReason(
41 TerminationReason reason) {
42 switch (reason) {
43 case TerminationReason::kReceiverUserTerminated:
44 return msgs::PresentationTerminationEvent_reason::
45 kUserTerminatedViaReceiver;
46 case TerminationReason::kReceiverShuttingDown:
47 return msgs::PresentationTerminationEvent_reason::kReceiverPoweringDown;
48 case TerminationReason::kReceiverPresentationUnloaded:
49 return msgs::PresentationTerminationEvent_reason::
50 kReceiverAttemptedToNavigate;
51 case TerminationReason::kReceiverPresentationReplaced:
52 return msgs::PresentationTerminationEvent_reason::
53 kReceiverReplacedPresentation;
54 case TerminationReason::kReceiverIdleTooLong:
55 return msgs::PresentationTerminationEvent_reason::kReceiverIdleTooLong;
56 case TerminationReason::kReceiverError:
57 return msgs::PresentationTerminationEvent_reason::kReceiverCrashed;
58 case TerminationReason::kReceiverTerminateCalled:
59 return msgs::PresentationTerminationEvent_reason::
60 kReceiverCalledTerminate;
61 default:
62 return msgs::PresentationTerminationEvent_reason::kUnknown;
63 }
64 }
65
WritePresentationInitiationResponse(const msgs::PresentationStartResponse & response,ProtocolConnection * connection)66 Error WritePresentationInitiationResponse(
67 const msgs::PresentationStartResponse& response,
68 ProtocolConnection* connection) {
69 return connection->WriteMessage(response,
70 msgs::EncodePresentationStartResponse);
71 }
72
WritePresentationConnectionOpenResponse(const msgs::PresentationConnectionOpenResponse & response,ProtocolConnection * connection)73 Error WritePresentationConnectionOpenResponse(
74 const msgs::PresentationConnectionOpenResponse& response,
75 ProtocolConnection* connection) {
76 return connection->WriteMessage(
77 response, msgs::EncodePresentationConnectionOpenResponse);
78 }
79
WritePresentationTerminationEvent(const msgs::PresentationTerminationEvent & event,ProtocolConnection * connection)80 Error WritePresentationTerminationEvent(
81 const msgs::PresentationTerminationEvent& event,
82 ProtocolConnection* connection) {
83 return connection->WriteMessage(event,
84 msgs::EncodePresentationTerminationEvent);
85 }
86
WritePresentationTerminationResponse(const msgs::PresentationTerminationResponse & response,ProtocolConnection * connection)87 Error WritePresentationTerminationResponse(
88 const msgs::PresentationTerminationResponse& response,
89 ProtocolConnection* connection) {
90 return connection->WriteMessage(response,
91 msgs::EncodePresentationTerminationResponse);
92 }
93
WritePresentationUrlAvailabilityResponse(const msgs::PresentationUrlAvailabilityResponse & response,ProtocolConnection * connection)94 Error WritePresentationUrlAvailabilityResponse(
95 const msgs::PresentationUrlAvailabilityResponse& response,
96 ProtocolConnection* connection) {
97 return connection->WriteMessage(
98 response, msgs::EncodePresentationUrlAvailabilityResponse);
99 }
100
101 } // namespace
102
OnStreamMessage(uint64_t endpoint_id,uint64_t connection_id,msgs::Type message_type,const uint8_t * buffer,size_t buffer_size,Clock::time_point now)103 ErrorOr<size_t> Receiver::OnStreamMessage(uint64_t endpoint_id,
104 uint64_t connection_id,
105 msgs::Type message_type,
106 const uint8_t* buffer,
107 size_t buffer_size,
108 Clock::time_point now) {
109 TRACE_SCOPED(TraceCategory::kPresentation, "Receiver::OnStreamMessage");
110 switch (message_type) {
111 case msgs::Type::kPresentationUrlAvailabilityRequest: {
112 TRACE_SCOPED(TraceCategory::kPresentation,
113 "kPresentationUrlAvailabilityRequest");
114 OSP_VLOG << "got presentation-url-availability-request";
115 msgs::PresentationUrlAvailabilityRequest request;
116 ssize_t decode_result = msgs::DecodePresentationUrlAvailabilityRequest(
117 buffer, buffer_size, &request);
118 if (decode_result < 0) {
119 OSP_LOG_WARN << "Presentation-url-availability-request parse error: "
120 << decode_result;
121 TRACE_SET_RESULT(Error::Code::kParseError);
122 return Error::Code::kParseError;
123 }
124
125 msgs::PresentationUrlAvailabilityResponse response;
126 response.request_id = request.request_id;
127
128 response.url_availabilities = delegate_->OnUrlAvailabilityRequest(
129 request.watch_id, request.watch_duration, std::move(request.urls));
130 msgs::CborEncodeBuffer buffer;
131
132 WritePresentationUrlAvailabilityResponse(
133 response, GetProtocolConnection(endpoint_id).get());
134 return decode_result;
135 }
136
137 case msgs::Type::kPresentationStartRequest: {
138 TRACE_SCOPED(TraceCategory::kPresentation, "kPresentationStartRequest");
139 OSP_VLOG << "got presentation-start-request";
140 msgs::PresentationStartRequest request;
141 const ssize_t result =
142 msgs::DecodePresentationStartRequest(buffer, buffer_size, &request);
143 if (result < 0) {
144 OSP_LOG_WARN << "Presentation-initiation-request parse error: "
145 << result;
146 TRACE_SET_RESULT(Error::Code::kParseError);
147 return Error::Code::kParseError;
148 }
149
150 OSP_LOG_INFO << "Got an initiation request for: " << request.url;
151
152 PresentationID presentation_id(std::move(request.presentation_id));
153 if (!presentation_id) {
154 msgs::PresentationStartResponse response;
155 response.request_id = request.request_id;
156 response.result =
157 msgs::PresentationStartResponse_result::kInvalidPresentationId;
158 Error write_error = WritePresentationInitiationResponse(
159 response, GetProtocolConnection(endpoint_id).get());
160
161 if (!write_error.ok()) {
162 TRACE_SET_RESULT(write_error);
163 return write_error;
164 }
165
166 return result;
167 }
168
169 auto& response_list = queued_responses_[presentation_id];
170 QueuedResponse queued_response{
171 /* .type = */ QueuedResponse::Type::kInitiation,
172 /* .request_id = */ request.request_id,
173 /* .connection_id = */ this->GetNextConnectionId(),
174 /* .endpoint_id = */ endpoint_id};
175 response_list.push_back(std::move(queued_response));
176
177 const bool starting = delegate_->StartPresentation(
178 Connection::PresentationInfo{presentation_id, request.url},
179 endpoint_id, request.headers);
180
181 if (starting)
182 return result;
183
184 queued_responses_.erase(presentation_id);
185 msgs::PresentationStartResponse response;
186 response.request_id = request.request_id;
187 response.result = msgs::PresentationStartResponse_result::kUnknownError;
188 Error write_error = WritePresentationInitiationResponse(
189 response, GetProtocolConnection(endpoint_id).get());
190 if (!write_error.ok()) {
191 TRACE_SET_RESULT(write_error);
192 return write_error;
193 }
194
195 return result;
196 }
197
198 case msgs::Type::kPresentationConnectionOpenRequest: {
199 TRACE_SCOPED(TraceCategory::kPresentation,
200 "kPresentationConnectionOpenRequest");
201 OSP_VLOG << "Got a presentation-connection-open-request";
202 msgs::PresentationConnectionOpenRequest request;
203 const ssize_t result = msgs::DecodePresentationConnectionOpenRequest(
204 buffer, buffer_size, &request);
205 if (result < 0) {
206 OSP_LOG_WARN << "Presentation-connection-open-request parse error: "
207 << result;
208 TRACE_SET_RESULT(Error::Code::kParseError);
209 return Error::Code::kParseError;
210 }
211
212 PresentationID presentation_id(std::move(request.presentation_id));
213
214 // TODO(jophba): add logic to queue presentation connection open
215 // (and terminate connection)
216 // requests to check against when a presentation starts, in case
217 // we get a request right before the beginning of the presentation.
218 if (!presentation_id || started_presentations_.find(presentation_id) ==
219 started_presentations_.end()) {
220 msgs::PresentationConnectionOpenResponse response;
221 response.request_id = request.request_id;
222 response.result = msgs::PresentationConnectionOpenResponse_result::
223 kInvalidPresentationId;
224 Error write_error = WritePresentationConnectionOpenResponse(
225 response, GetProtocolConnection(endpoint_id).get());
226 if (!write_error.ok()) {
227 TRACE_SET_RESULT(write_error);
228 return write_error;
229 }
230
231 return result;
232 }
233
234 // TODO(btolsch): We would also check that connection_id isn't already
235 // requested/in use but since the spec has already shifted to a
236 // receiver-chosen connection ID, we'll ignore that until we change our
237 // CDDL messages.
238 std::vector<QueuedResponse>& responses =
239 queued_responses_[presentation_id];
240 responses.emplace_back(
241 QueuedResponse{QueuedResponse::Type::kConnection, request.request_id,
242 this->GetNextConnectionId(), endpoint_id});
243 bool connecting = delegate_->ConnectToPresentation(
244 request.request_id, presentation_id, endpoint_id);
245 if (connecting)
246 return result;
247
248 responses.pop_back();
249 if (responses.empty())
250 queued_responses_.erase(presentation_id);
251
252 msgs::PresentationConnectionOpenResponse response;
253 response.request_id = request.request_id;
254 response.result =
255 msgs::PresentationConnectionOpenResponse_result::kUnknownError;
256 Error write_error = WritePresentationConnectionOpenResponse(
257 response, GetProtocolConnection(endpoint_id).get());
258 if (!write_error.ok()) {
259 TRACE_SET_RESULT(write_error);
260 return write_error;
261 }
262
263 return result;
264 }
265
266 case msgs::Type::kPresentationTerminationRequest: {
267 TRACE_SCOPED(TraceCategory::kPresentation,
268 "kPresentationTerminationRequest");
269 OSP_VLOG << "got presentation-termination-request";
270 msgs::PresentationTerminationRequest request;
271 const ssize_t result = msgs::DecodePresentationTerminationRequest(
272 buffer, buffer_size, &request);
273 if (result < 0) {
274 OSP_LOG_WARN << "Presentation-termination-request parse error: "
275 << result;
276 TRACE_SET_RESULT(Error::Code::kParseError);
277 return Error::Code::kParseError;
278 }
279
280 PresentationID presentation_id(std::move(request.presentation_id));
281 OSP_LOG_INFO << "Got termination request for: " << presentation_id;
282
283 auto presentation_entry = started_presentations_.find(presentation_id);
284 if (presentation_id &&
285 presentation_entry != started_presentations_.end()) {
286 TerminationReason reason =
287 (request.reason == msgs::PresentationTerminationRequest_reason::
288 kUserTerminatedViaController)
289 ? TerminationReason::kControllerTerminateCalled
290 : TerminationReason::kControllerUserTerminated;
291 presentation_entry->second.terminate_request_id = request.request_id;
292 delegate_->TerminatePresentation(presentation_id, reason);
293
294 msgs::PresentationTerminationResponse response;
295 response.request_id = request.request_id;
296 response.result = msgs::PresentationTerminationResponse_result::
297 kInvalidPresentationId;
298 Error write_error = WritePresentationTerminationResponse(
299 response, GetProtocolConnection(endpoint_id).get());
300 if (!write_error.ok()) {
301 TRACE_SET_RESULT(write_error);
302 return write_error;
303 }
304 return result;
305 }
306
307 TerminationReason reason =
308 (request.reason == msgs::PresentationTerminationRequest_reason::
309 kControllerCalledTerminate)
310 ? TerminationReason::kControllerTerminateCalled
311 : TerminationReason::kControllerUserTerminated;
312 presentation_entry->second.terminate_request_id = request.request_id;
313 delegate_->TerminatePresentation(presentation_id, reason);
314
315 return result;
316 }
317
318 default:
319 TRACE_SET_RESULT(Error::Code::kUnknownMessageType);
320 return Error::Code::kUnknownMessageType;
321 }
322 }
323
324 // TODO(crbug.com/openscreen/31): Remove singletons in the embedder API and
325 // protocol implementation layers and in presentation_connection, as well as
326 // unit tests. static
Get()327 Receiver* Receiver::Get() {
328 static Receiver& receiver = *new Receiver();
329 return &receiver;
330 }
331
Init()332 void Receiver::Init() {
333 if (!connection_manager_) {
334 connection_manager_ =
335 std::make_unique<ConnectionManager>(GetServerDemuxer());
336 }
337 }
338
Deinit()339 void Receiver::Deinit() {
340 connection_manager_.reset();
341 }
342
SetReceiverDelegate(ReceiverDelegate * delegate)343 void Receiver::SetReceiverDelegate(ReceiverDelegate* delegate) {
344 OSP_DCHECK(!delegate_ || !delegate);
345 delegate_ = delegate;
346
347 MessageDemuxer* demuxer = GetServerDemuxer();
348 if (delegate_) {
349 availability_watch_ = demuxer->SetDefaultMessageTypeWatch(
350 msgs::Type::kPresentationUrlAvailabilityRequest, this);
351 initiation_watch_ = demuxer->SetDefaultMessageTypeWatch(
352 msgs::Type::kPresentationStartRequest, this);
353 connection_watch_ = demuxer->SetDefaultMessageTypeWatch(
354 msgs::Type::kPresentationConnectionOpenRequest, this);
355 return;
356 }
357
358 StopWatching(&availability_watch_);
359 StopWatching(&initiation_watch_);
360 StopWatching(&connection_watch_);
361
362 std::vector<std::string> presentations_to_remove(
363 started_presentations_.size());
364 for (auto& it : started_presentations_) {
365 presentations_to_remove.push_back(it.first);
366 }
367
368 for (auto& presentation_id : presentations_to_remove) {
369 OnPresentationTerminated(presentation_id,
370 TerminationReason::kReceiverShuttingDown);
371 }
372 }
373
OnPresentationStarted(const std::string & presentation_id,Connection * connection,ResponseResult result)374 Error Receiver::OnPresentationStarted(const std::string& presentation_id,
375 Connection* connection,
376 ResponseResult result) {
377 auto queued_responses_entry = queued_responses_.find(presentation_id);
378 if (queued_responses_entry == queued_responses_.end())
379 return Error::Code::kNoStartedPresentation;
380
381 auto& responses = queued_responses_entry->second;
382 if ((responses.size() != 1) ||
383 (responses.front().type != QueuedResponse::Type::kInitiation)) {
384 return Error::Code::kPresentationAlreadyStarted;
385 }
386
387 QueuedResponse& initiation_response = responses.front();
388 msgs::PresentationStartResponse response;
389 response.request_id = initiation_response.request_id;
390 auto protocol_connection =
391 GetProtocolConnection(initiation_response.endpoint_id);
392 auto* raw_protocol_connection_ptr = protocol_connection.get();
393
394 OSP_VLOG << "presentation started with protocol_connection id: "
395 << protocol_connection->id();
396 if (result != ResponseResult::kSuccess) {
397 response.result = msgs::PresentationStartResponse_result::kUnknownError;
398
399 queued_responses_.erase(queued_responses_entry);
400 return WritePresentationInitiationResponse(response,
401 raw_protocol_connection_ptr);
402 }
403
404 response.result = msgs::PresentationStartResponse_result::kSuccess;
405 response.connection_id = connection->connection_id();
406
407 Presentation& presentation = started_presentations_[presentation_id];
408 presentation.endpoint_id = initiation_response.endpoint_id;
409 connection->OnConnected(initiation_response.connection_id,
410 initiation_response.endpoint_id,
411 std::move(protocol_connection));
412 presentation.connections.push_back(connection);
413 connection_manager_->AddConnection(connection);
414
415 presentation.terminate_watch = GetServerDemuxer()->WatchMessageType(
416 initiation_response.endpoint_id,
417 msgs::Type::kPresentationTerminationRequest, this);
418
419 queued_responses_.erase(queued_responses_entry);
420 return WritePresentationInitiationResponse(response,
421 raw_protocol_connection_ptr);
422 }
423
OnConnectionCreated(uint64_t request_id,Connection * connection,ResponseResult result)424 Error Receiver::OnConnectionCreated(uint64_t request_id,
425 Connection* connection,
426 ResponseResult result) {
427 const auto presentation_id = connection->presentation_info().id;
428
429 ErrorOr<QueuedResponseIterator> connection_response =
430 GetQueuedResponse(presentation_id, request_id);
431 if (connection_response.is_error()) {
432 return connection_response.error();
433 }
434 connection->OnConnected(
435 connection_response.value()->connection_id,
436 connection_response.value()->endpoint_id,
437 NetworkServiceManager::Get()
438 ->GetProtocolConnectionServer()
439 ->CreateProtocolConnection(connection_response.value()->endpoint_id));
440
441 started_presentations_[presentation_id].connections.push_back(connection);
442 connection_manager_->AddConnection(connection);
443
444 msgs::PresentationConnectionOpenResponse response;
445 response.request_id = request_id;
446 response.result = msgs::PresentationConnectionOpenResponse_result::kSuccess;
447 response.connection_id = connection->connection_id();
448
449 auto protocol_connection =
450 GetProtocolConnection(connection_response.value()->endpoint_id);
451
452 WritePresentationConnectionOpenResponse(response, protocol_connection.get());
453
454 DeleteQueuedResponse(presentation_id, connection_response.value());
455 return Error::None();
456 }
457
CloseConnection(Connection * connection,Connection::CloseReason reason)458 Error Receiver::CloseConnection(Connection* connection,
459 Connection::CloseReason reason) {
460 std::unique_ptr<ProtocolConnection> protocol_connection =
461 GetProtocolConnection(connection->endpoint_id());
462
463 if (!protocol_connection)
464 return Error::Code::kNoActiveConnection;
465
466 msgs::PresentationConnectionCloseEvent event;
467 event.connection_id = connection->connection_id();
468 event.reason = GetEventCloseReason(reason);
469 event.has_error_message = false;
470 msgs::CborEncodeBuffer buffer;
471 return protocol_connection->WriteMessage(
472 event, msgs::EncodePresentationConnectionCloseEvent);
473 }
474
OnPresentationTerminated(const std::string & presentation_id,TerminationReason reason)475 Error Receiver::OnPresentationTerminated(const std::string& presentation_id,
476 TerminationReason reason) {
477 auto presentation_entry = started_presentations_.find(presentation_id);
478 if (presentation_entry == started_presentations_.end())
479 return Error::Code::kNoStartedPresentation;
480
481 Presentation& presentation = presentation_entry->second;
482 presentation.terminate_watch = MessageDemuxer::MessageWatch();
483 std::unique_ptr<ProtocolConnection> protocol_connection =
484 GetProtocolConnection(presentation.endpoint_id);
485
486 if (!protocol_connection)
487 return Error::Code::kNoActiveConnection;
488
489 for (auto* connection : presentation.connections)
490 connection->OnTerminated();
491
492 if (presentation.terminate_request_id) {
493 // TODO(btolsch): Also timeout if this point isn't reached.
494 msgs::PresentationTerminationResponse response;
495 response.request_id = presentation.terminate_request_id;
496 response.result = msgs::PresentationTerminationResponse_result::kSuccess;
497 started_presentations_.erase(presentation_entry);
498 return WritePresentationTerminationResponse(response,
499 protocol_connection.get());
500 }
501
502 msgs::PresentationTerminationEvent event;
503 event.presentation_id = presentation_id;
504 event.reason = GetEventTerminationReason(reason);
505 started_presentations_.erase(presentation_entry);
506 return WritePresentationTerminationEvent(event, protocol_connection.get());
507 }
508
OnConnectionDestroyed(Connection * connection)509 void Receiver::OnConnectionDestroyed(Connection* connection) {
510 auto presentation_entry =
511 started_presentations_.find(connection->presentation_info().id);
512 if (presentation_entry == started_presentations_.end())
513 return;
514
515 std::vector<Connection*>& connections =
516 presentation_entry->second.connections;
517
518 auto past_the_end =
519 std::remove(connections.begin(), connections.end(), connection);
520 // An additional call to "erase" is necessary to actually adjust the size
521 // of the vector.
522 connections.erase(past_the_end, connections.end());
523
524 connection_manager_->RemoveConnection(connection);
525 }
526
// The constructor and destructor are defined in this file and use the
// compiler-generated behavior.
Receiver::Receiver() = default;

Receiver::~Receiver() = default;
530
DeleteQueuedResponse(const std::string & presentation_id,Receiver::QueuedResponseIterator response)531 void Receiver::DeleteQueuedResponse(const std::string& presentation_id,
532 Receiver::QueuedResponseIterator response) {
533 auto entry = queued_responses_.find(presentation_id);
534 entry->second.erase(response);
535 if (entry->second.empty())
536 queued_responses_.erase(entry);
537 }
538
GetQueuedResponse(const std::string & presentation_id,uint64_t request_id) const539 ErrorOr<Receiver::QueuedResponseIterator> Receiver::GetQueuedResponse(
540 const std::string& presentation_id,
541 uint64_t request_id) const {
542 auto entry = queued_responses_.find(presentation_id);
543 if (entry == queued_responses_.end()) {
544 OSP_LOG_WARN << "connection created for unknown request";
545 return Error::Code::kUnknownRequestId;
546 }
547
548 const std::vector<QueuedResponse>& responses = entry->second;
549 Receiver::QueuedResponseIterator it =
550 std::find_if(responses.begin(), responses.end(),
551 [request_id](const QueuedResponse& response) {
552 return response.request_id == request_id;
553 });
554
555 if (it == responses.end()) {
556 OSP_LOG_WARN << "connection created for unknown request";
557 return Error::Code::kUnknownRequestId;
558 }
559
560 return it;
561 }
562
GetNextConnectionId()563 uint64_t Receiver::GetNextConnectionId() {
564 static uint64_t request_id = 0;
565 return request_id++;
566 }
567
568 } // namespace osp
569 } // namespace openscreen
570