xref: /aosp_15_r20/external/pigweed/pw_transfer/transfer_thread.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #define PW_LOG_MODULE_NAME "TRN"
16 #define PW_LOG_LEVEL PW_TRANSFER_CONFIG_LOG_LEVEL
17 
18 #include "pw_transfer/transfer_thread.h"
19 
20 #include "pw_assert/check.h"
21 #include "pw_log/log.h"
22 #include "pw_transfer/internal/chunk.h"
23 #include "pw_transfer/internal/client_context.h"
24 #include "pw_transfer/internal/config.h"
25 #include "pw_transfer/internal/event.h"
26 
27 PW_MODIFY_DIAGNOSTICS_PUSH();
28 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
29 
30 namespace pw::transfer::internal {
31 
Terminate()32 void TransferThread::Terminate() {
33   next_event_ownership_.acquire();
34   next_event_.type = EventType::kTerminate;
35   event_notification_.release();
36 }
37 
SimulateTimeout(EventType type,uint32_t session_id)38 void TransferThread::SimulateTimeout(EventType type, uint32_t session_id) {
39   next_event_ownership_.acquire();
40 
41   next_event_.type = type;
42   next_event_.chunk = {};
43   next_event_.chunk.context_identifier = session_id;
44 
45   event_notification_.release();
46 
47   WaitUntilEventIsProcessed();
48 }
49 
Run()50 void TransferThread::Run() {
51   // Next event starts freed.
52   next_event_ownership_.release();
53 
54   while (true) {
55     if (event_notification_.try_acquire_until(GetNextTransferTimeout())) {
56       HandleEvent(next_event_);
57 
58       // Sample event type before we release ownership of next_event_.
59       bool is_terminating = next_event_.type == EventType::kTerminate;
60 
61       // Finished processing the event. Allow the next_event struct to be
62       // overwritten.
63       next_event_ownership_.release();
64 
65       if (is_terminating) {
66         return;
67       }
68     }
69 
70     // Regardless of whether an event was received or not, check for any
71     // transfers which have timed out and process them if so.
72     for (Context& context : client_transfers_) {
73       if (context.timed_out()) {
74         context.HandleEvent({.type = EventType::kClientTimeout});
75       }
76     }
77     for (Context& context : server_transfers_) {
78       if (context.timed_out()) {
79         context.HandleEvent({.type = EventType::kServerTimeout});
80       }
81     }
82   }
83 }
84 
GetNextTransferTimeout() const85 chrono::SystemClock::time_point TransferThread::GetNextTransferTimeout() const {
86   chrono::SystemClock::time_point timeout =
87       chrono::SystemClock::TimePointAfterAtLeast(kMaxTimeout);
88 
89   for (Context& context : client_transfers_) {
90     auto ctx_timeout = context.timeout();
91     if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
92       timeout = ctx_timeout.value();
93     }
94   }
95   for (Context& context : server_transfers_) {
96     auto ctx_timeout = context.timeout();
97     if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
98       timeout = ctx_timeout.value();
99     }
100   }
101 
102   return timeout;
103 }
104 
StartTransfer(TransferType type,ProtocolVersion version,uint32_t session_id,uint32_t resource_id,uint32_t handle_id,ConstByteSpan raw_chunk,stream::Stream * stream,const TransferParameters & max_parameters,Function<void (Status)> && on_completion,chrono::SystemClock::duration timeout,chrono::SystemClock::duration initial_timeout,uint8_t max_retries,uint32_t max_lifetime_retries,uint32_t initial_offset)105 void TransferThread::StartTransfer(
106     TransferType type,
107     ProtocolVersion version,
108     uint32_t session_id,
109     uint32_t resource_id,
110     uint32_t handle_id,
111     ConstByteSpan raw_chunk,
112     stream::Stream* stream,
113     const TransferParameters& max_parameters,
114     Function<void(Status)>&& on_completion,
115     chrono::SystemClock::duration timeout,
116     chrono::SystemClock::duration initial_timeout,
117     uint8_t max_retries,
118     uint32_t max_lifetime_retries,
119     uint32_t initial_offset) {
120   if (!TryWaitForEventToProcess()) {
121     return;
122   }
123 
124   bool is_client_transfer = stream != nullptr;
125 
126   if (is_client_transfer) {
127     if (version == ProtocolVersion::kLegacy) {
128       session_id = resource_id;
129     } else if (session_id == Context::kUnassignedSessionId) {
130       session_id = AssignSessionId();
131     }
132   }
133 
134   next_event_.type = is_client_transfer ? EventType::kNewClientTransfer
135                                         : EventType::kNewServerTransfer;
136 
137   if (!raw_chunk.empty()) {
138     std::memcpy(chunk_buffer_.data(), raw_chunk.data(), raw_chunk.size());
139   }
140 
141   next_event_.new_transfer = {
142       .type = type,
143       .protocol_version = version,
144       .session_id = session_id,
145       .resource_id = resource_id,
146       .handle_id = handle_id,
147       .max_parameters = &max_parameters,
148       .timeout = timeout,
149       .initial_timeout = initial_timeout,
150       .max_retries = max_retries,
151       .max_lifetime_retries = max_lifetime_retries,
152       .transfer_thread = this,
153       .raw_chunk_data = chunk_buffer_.data(),
154       .raw_chunk_size = raw_chunk.size(),
155       .initial_offset = initial_offset,
156   };
157 
158   staged_on_completion_ = std::move(on_completion);
159 
160   // The transfer is initialized with either a stream (client-side) or a handler
161   // (server-side). If no stream is provided, try to find a registered handler
162   // with the specified ID.
163   if (is_client_transfer) {
164     next_event_.new_transfer.stream = stream;
165     next_event_.new_transfer.rpc_writer =
166         &(type == TransferType::kTransmit ? client_write_stream_
167                                           : client_read_stream_)
168              .as_writer();
169   } else {
170     auto handler = std::find_if(handlers_.begin(),
171                                 handlers_.end(),
172                                 [&](auto& h) { return h.id() == resource_id; });
173     if (handler != handlers_.end()) {
174       next_event_.new_transfer.handler = &*handler;
175       next_event_.new_transfer.rpc_writer =
176           &(type == TransferType::kTransmit ? server_read_stream_
177                                             : server_write_stream_)
178                .as_writer();
179     } else {
180       // No handler exists for the transfer: return a NOT_FOUND.
181       next_event_.type = EventType::kSendStatusChunk;
182       next_event_.send_status_chunk = {
183           .session_id = session_id,
184           .protocol_version = version,
185           .status = Status::NotFound().code(),
186           .stream = type == TransferType::kTransmit
187                         ? TransferStream::kServerRead
188                         : TransferStream::kServerWrite,
189       };
190     }
191   }
192 
193   event_notification_.release();
194 }
195 
ProcessChunk(EventType type,ConstByteSpan chunk)196 void TransferThread::ProcessChunk(EventType type, ConstByteSpan chunk) {
197   // If this assert is hit, there is a bug in the transfer implementation.
198   // Contexts' max_chunk_size_bytes fields should be set based on the size of
199   // chunk_buffer_.
200   PW_CHECK(chunk.size() <= chunk_buffer_.size(),
201            "Transfer received a larger chunk than it can handle.");
202 
203   Result<Chunk::Identifier> identifier = Chunk::ExtractIdentifier(chunk);
204   if (!identifier.ok()) {
205     PW_LOG_ERROR("Received a malformed chunk without a context identifier");
206     return;
207   }
208 
209   if (!TryWaitForEventToProcess()) {
210     return;
211   }
212 
213   std::memcpy(chunk_buffer_.data(), chunk.data(), chunk.size());
214 
215   next_event_.type = type;
216   next_event_.chunk = {
217       .context_identifier = identifier->value(),
218       .match_resource_id = identifier->is_legacy(),
219       .data = chunk_buffer_.data(),
220       .size = chunk.size(),
221   };
222 
223   event_notification_.release();
224 }
225 
SendStatus(TransferStream stream,uint32_t session_id,ProtocolVersion version,Status status)226 void TransferThread::SendStatus(TransferStream stream,
227                                 uint32_t session_id,
228                                 ProtocolVersion version,
229                                 Status status) {
230   if (!TryWaitForEventToProcess()) {
231     return;
232   }
233 
234   next_event_.type = EventType::kSendStatusChunk;
235   next_event_.send_status_chunk = {
236       .session_id = session_id,
237       .protocol_version = version,
238       .status = status.code(),
239       .stream = stream,
240   };
241 
242   event_notification_.release();
243 }
244 
EndTransfer(EventType type,IdentifierType id_type,uint32_t id,Status status,bool send_status_chunk)245 void TransferThread::EndTransfer(EventType type,
246                                  IdentifierType id_type,
247                                  uint32_t id,
248                                  Status status,
249                                  bool send_status_chunk) {
250   if (!TryWaitForEventToProcess()) {
251     return;
252   }
253 
254   next_event_.type = type;
255   next_event_.end_transfer = {
256       .id_type = id_type,
257       .id = id,
258       .status = status.code(),
259       .send_status_chunk = send_status_chunk,
260   };
261 
262   event_notification_.release();
263 }
264 
SetStream(TransferStream stream)265 void TransferThread::SetStream(TransferStream stream) {
266   if (!TryWaitForEventToProcess()) {
267     return;
268   }
269 
270   next_event_.type = EventType::kSetStream;
271   next_event_.set_stream = {
272       .stream = stream,
273   };
274 
275   event_notification_.release();
276 }
277 
UpdateClientTransfer(uint32_t handle_id,size_t transfer_size_bytes)278 void TransferThread::UpdateClientTransfer(uint32_t handle_id,
279                                           size_t transfer_size_bytes) {
280   if (!TryWaitForEventToProcess()) {
281     return;
282   }
283 
284   next_event_.type = EventType::kUpdateClientTransfer;
285   next_event_.update_transfer.handle_id = handle_id;
286   next_event_.update_transfer.transfer_size_bytes = transfer_size_bytes;
287 
288   event_notification_.release();
289 }
290 
TransferHandlerEvent(EventType type,Handler & handler)291 bool TransferThread::TransferHandlerEvent(EventType type, Handler& handler) {
292   if (!TryWaitForEventToProcess()) {
293     return false;
294   }
295 
296   next_event_.type = type;
297   if (type == EventType::kAddTransferHandler) {
298     next_event_.add_transfer_handler = &handler;
299   } else {
300     next_event_.remove_transfer_handler = &handler;
301   }
302 
303   event_notification_.release();
304   return true;
305 }
306 
HandleEvent(const internal::Event & event)307 void TransferThread::HandleEvent(const internal::Event& event) {
308   switch (event.type) {
309     case EventType::kTerminate:
310       // Terminate server contexts.
311       for (ServerContext& server_context : server_transfers_) {
312         server_context.HandleEvent(Event{
313             .type = EventType::kServerEndTransfer,
314             .end_transfer =
315                 EndTransferEvent{
316                     .id_type = IdentifierType::Session,
317                     .id = server_context.session_id(),
318                     .status = Status::Aborted().code(),
319                     .send_status_chunk = false,
320                 },
321         });
322       }
323 
324       // Terminate client contexts.
325       for (ClientContext& client_context : client_transfers_) {
326         client_context.HandleEvent(Event{
327             .type = EventType::kClientEndTransfer,
328             .end_transfer =
329                 EndTransferEvent{
330                     .id_type = IdentifierType::Session,
331                     .id = client_context.session_id(),
332                     .status = Status::Aborted().code(),
333                     .send_status_chunk = false,
334                 },
335         });
336       }
337 
338       // Cancel/Finish streams.
339       client_read_stream_.Cancel().IgnoreError();
340       client_write_stream_.Cancel().IgnoreError();
341       server_read_stream_.Finish(Status::Aborted()).IgnoreError();
342       server_write_stream_.Finish(Status::Aborted()).IgnoreError();
343       return;
344 
345     case EventType::kSendStatusChunk:
346       SendStatusChunk(event.send_status_chunk);
347       break;
348 
349     case EventType::kAddTransferHandler:
350       handlers_.push_front(*event.add_transfer_handler);
351       return;
352 
353     case EventType::kRemoveTransferHandler:
354       for (ServerContext& server_context : server_transfers_) {
355         if (server_context.handler() == event.remove_transfer_handler) {
356           server_context.HandleEvent(Event{
357               .type = EventType::kServerEndTransfer,
358               .end_transfer =
359                   EndTransferEvent{
360                       .id_type = IdentifierType::Session,
361                       .id = server_context.session_id(),
362                       .status = Status::Aborted().code(),
363                       .send_status_chunk = false,
364                   },
365           });
366         }
367       }
368       handlers_.remove(*event.remove_transfer_handler);
369       return;
370 
371     case EventType::kSetStream:
372       HandleSetStreamEvent(event.set_stream.stream);
373       return;
374 
375     case EventType::kGetResourceStatus:
376       GetResourceState(event.resource_status.resource_id);
377       return;
378 
379     case EventType::kNewClientTransfer:
380     case EventType::kNewServerTransfer:
381     case EventType::kClientChunk:
382     case EventType::kServerChunk:
383     case EventType::kClientTimeout:
384     case EventType::kServerTimeout:
385     case EventType::kClientEndTransfer:
386     case EventType::kServerEndTransfer:
387     case EventType::kUpdateClientTransfer:
388     default:
389       // Other events are handled by individual transfer contexts.
390       break;
391   }
392 
393   Context* ctx = FindContextForEvent(event);
394   if (ctx == nullptr) {
395     // No context was found. For new transfer events, report a
396     // RESOURCE_EXHAUSTED error with starting the transfer.
397     if (event.type == EventType::kNewClientTransfer) {
398       // On the client, invoke the completion callback directly.
399       staged_on_completion_(Status::ResourceExhausted());
400     } else if (event.type == EventType::kNewServerTransfer) {
401       // On the server, send a status chunk back to the client.
402       SendStatusChunk(
403           {.session_id = event.new_transfer.session_id,
404            .protocol_version = event.new_transfer.protocol_version,
405            .status = Status::ResourceExhausted().code(),
406            .stream = event.new_transfer.type == TransferType::kTransmit
407                          ? TransferStream::kServerRead
408                          : TransferStream::kServerWrite});
409     }
410     return;
411   }
412 
413   if (event.type == EventType::kNewClientTransfer) {
414     // TODO(frolv): This is terrible.
415     ClientContext* cctx = static_cast<ClientContext*>(ctx);
416     cctx->set_on_completion(std::move(staged_on_completion_));
417     cctx->set_handle_id(event.new_transfer.handle_id);
418   }
419 
420   if (event.type == EventType::kUpdateClientTransfer) {
421     static_cast<ClientContext&>(*ctx).set_transfer_size_bytes(
422         event.update_transfer.transfer_size_bytes);
423     return;
424   }
425 
426   ctx->HandleEvent(event);
427 }
428 
FindContextForEvent(const internal::Event & event) const429 Context* TransferThread::FindContextForEvent(
430     const internal::Event& event) const {
431   switch (event.type) {
432     case EventType::kNewClientTransfer:
433       return FindNewTransfer(client_transfers_, event.new_transfer.session_id);
434     case EventType::kNewServerTransfer:
435       return FindNewTransfer(server_transfers_, event.new_transfer.session_id);
436 
437     case EventType::kClientChunk:
438       if (event.chunk.match_resource_id) {
439         return FindActiveTransferByResourceId(client_transfers_,
440                                               event.chunk.context_identifier);
441       }
442       return FindActiveTransferByLegacyId(client_transfers_,
443                                           event.chunk.context_identifier);
444 
445     case EventType::kServerChunk:
446       if (event.chunk.match_resource_id) {
447         return FindActiveTransferByResourceId(server_transfers_,
448                                               event.chunk.context_identifier);
449       }
450       return FindActiveTransferByLegacyId(server_transfers_,
451                                           event.chunk.context_identifier);
452 
453     case EventType::kClientTimeout:  // Manually triggered client timeout
454       return FindActiveTransferByLegacyId(client_transfers_,
455                                           event.chunk.context_identifier);
456     case EventType::kServerTimeout:  // Manually triggered server timeout
457       return FindActiveTransferByLegacyId(server_transfers_,
458                                           event.chunk.context_identifier);
459 
460     case EventType::kClientEndTransfer:
461       if (event.end_transfer.id_type == IdentifierType::Handle) {
462         return FindClientTransferByHandleId(event.end_transfer.id);
463       }
464       return FindActiveTransferByLegacyId(client_transfers_,
465                                           event.end_transfer.id);
466     case EventType::kServerEndTransfer:
467       PW_DCHECK(event.end_transfer.id_type != IdentifierType::Handle);
468       return FindActiveTransferByLegacyId(server_transfers_,
469                                           event.end_transfer.id);
470 
471     case EventType::kUpdateClientTransfer:
472       return FindClientTransferByHandleId(event.update_transfer.handle_id);
473 
474     case EventType::kSendStatusChunk:
475     case EventType::kAddTransferHandler:
476     case EventType::kRemoveTransferHandler:
477     case EventType::kSetStream:
478     case EventType::kTerminate:
479     case EventType::kGetResourceStatus:
480     default:
481       return nullptr;
482   }
483 }
484 
SendStatusChunk(const internal::SendStatusChunkEvent & event)485 void TransferThread::SendStatusChunk(
486     const internal::SendStatusChunkEvent& event) {
487   rpc::Writer& destination = stream_for(event.stream);
488 
489   Chunk chunk =
490       Chunk::Final(event.protocol_version, event.session_id, event.status);
491 
492   Result<ConstByteSpan> result = chunk.Encode(chunk_buffer_);
493   if (!result.ok()) {
494     PW_LOG_ERROR("Failed to encode final chunk for transfer %u",
495                  static_cast<unsigned>(event.session_id));
496     return;
497   }
498 
499   if (!destination.Write(result.value()).ok()) {
500     PW_LOG_ERROR("Failed to send final chunk for transfer %u",
501                  static_cast<unsigned>(event.session_id));
502     return;
503   }
504 }
505 
506 // Should only be called with the `next_event_ownership_` lock held.
AssignSessionId()507 uint32_t TransferThread::AssignSessionId() {
508   uint32_t session_id = next_session_id_++;
509   if (session_id == 0) {
510     session_id = next_session_id_++;
511   }
512   return session_id;
513 }
514 
515 template <typename T>
TerminateTransfers(span<T> contexts,TransferType type,EventType event_type,Status status)516 void TerminateTransfers(span<T> contexts,
517                         TransferType type,
518                         EventType event_type,
519                         Status status) {
520   for (Context& context : contexts) {
521     if (context.active() && context.type() == type) {
522       context.HandleEvent(Event{
523           .type = event_type,
524           .end_transfer =
525               EndTransferEvent{
526                   .id_type = IdentifierType::Session,
527                   .id = context.session_id(),
528                   .status = status.code(),
529                   .send_status_chunk = false,
530               },
531       });
532     }
533   }
534 }
535 
HandleSetStreamEvent(TransferStream stream)536 void TransferThread::HandleSetStreamEvent(TransferStream stream) {
537   switch (stream) {
538     case TransferStream::kClientRead:
539       TerminateTransfers(client_transfers_,
540                          TransferType::kReceive,
541                          EventType::kClientEndTransfer,
542                          Status::Aborted());
543       client_read_stream_ = std::move(staged_client_stream_);
544       client_read_stream_.set_on_next(std::move(staged_client_on_next_));
545       client_read_stream_.set_on_error([](Status status) {
546         PW_LOG_WARN("Client read stream closed unexpectedly: %s", status.str());
547       });
548       break;
549     case TransferStream::kClientWrite:
550       TerminateTransfers(client_transfers_,
551                          TransferType::kTransmit,
552                          EventType::kClientEndTransfer,
553                          Status::Aborted());
554       client_write_stream_ = std::move(staged_client_stream_);
555       client_write_stream_.set_on_next(std::move(staged_client_on_next_));
556       client_write_stream_.set_on_error([](Status status) {
557         PW_LOG_WARN("Client write stream closed unexpectedly: %s",
558                     status.str());
559       });
560       break;
561     case TransferStream::kServerRead:
562       TerminateTransfers(server_transfers_,
563                          TransferType::kTransmit,
564                          EventType::kServerEndTransfer,
565                          Status::Aborted());
566       server_read_stream_ = std::move(staged_server_stream_);
567       server_read_stream_.set_on_next(std::move(staged_server_on_next_));
568       server_read_stream_.set_on_error([](Status status) {
569         PW_LOG_WARN("Server read stream closed unexpectedly: %s", status.str());
570       });
571       break;
572     case TransferStream::kServerWrite:
573       TerminateTransfers(server_transfers_,
574                          TransferType::kReceive,
575                          EventType::kServerEndTransfer,
576                          Status::Aborted());
577       server_write_stream_ = std::move(staged_server_stream_);
578       server_write_stream_.set_on_next(std::move(staged_server_on_next_));
579       server_write_stream_.set_on_error([](Status status) {
580         PW_LOG_WARN("Server write stream closed unexpectedly: %s",
581                     status.str());
582       });
583       break;
584   }
585 }
586 
587 // Adds GetResourceStatusEvent to the queue. Will fail if there is already a
588 // GetResourceStatusEvent in process.
EnqueueResourceEvent(uint32_t resource_id,ResourceStatusCallback && callback)589 void TransferThread::EnqueueResourceEvent(uint32_t resource_id,
590                                           ResourceStatusCallback&& callback) {
591   if (!TryWaitForEventToProcess()) {
592     return;
593   }
594 
595   next_event_.type = EventType::kGetResourceStatus;
596 
597   resource_status_callback_ = std::move(callback);
598 
599   next_event_.resource_status.resource_id = resource_id;
600 
601   event_notification_.release();
602 }
603 
604 // Should only be called when we got a valid callback and RPC responder from
605 // GetResourceStatus transfer RPC.
GetResourceState(uint32_t resource_id)606 void TransferThread::GetResourceState(uint32_t resource_id) {
607   PW_ASSERT(resource_status_callback_ != nullptr);
608 
609   auto handler = std::find_if(handlers_.begin(), handlers_.end(), [&](auto& h) {
610     return h.id() == resource_id;
611   });
612   internal::ResourceStatus stats;
613   stats.resource_id = resource_id;
614 
615   if (handler != handlers_.end()) {
616     Status status = handler->GetStatus(stats.readable_offset,
617                                        stats.writeable_offset,
618                                        stats.read_checksum,
619                                        stats.write_checksum);
620 
621     resource_status_callback_(status, stats);
622   } else {
623     resource_status_callback_(Status::NotFound(), stats);
624   }
625 }
626 
627 }  // namespace pw::transfer::internal
628 
629 PW_MODIFY_DIAGNOSTICS_POP();
630