xref: /aosp_15_r20/external/pigweed/pw_transfer/public/pw_transfer/transfer_thread.h (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
16 #include <cstdint>
17 
18 #include "pw_assert/assert.h"
19 #include "pw_bytes/span.h"
20 #include "pw_chrono/system_clock.h"
21 #include "pw_function/function.h"
22 #include "pw_rpc/raw/client_reader_writer.h"
23 #include "pw_rpc/raw/server_reader_writer.h"
24 #include "pw_span/span.h"
25 #include "pw_sync/binary_semaphore.h"
26 #include "pw_sync/timed_thread_notification.h"
27 #include "pw_thread/thread_core.h"
28 #include "pw_transfer/handler.h"
29 #include "pw_transfer/internal/client_context.h"
30 #include "pw_transfer/internal/config.h"
31 #include "pw_transfer/internal/context.h"
32 #include "pw_transfer/internal/event.h"
33 #include "pw_transfer/internal/server_context.h"
34 
35 namespace pw::transfer {
36 
37 class Client;
38 
39 namespace internal {
40 
41 class TransferThread : public thread::ThreadCore {
42  public:
TransferThread(span<ClientContext> client_transfers,span<ServerContext> server_transfers,ByteSpan chunk_buffer,ByteSpan encode_buffer)43   TransferThread(span<ClientContext> client_transfers,
44                  span<ServerContext> server_transfers,
45                  ByteSpan chunk_buffer,
46                  ByteSpan encode_buffer)
47       : client_transfers_(client_transfers),
48         server_transfers_(server_transfers),
49         next_session_id_(1),
50         chunk_buffer_(chunk_buffer),
51         encode_buffer_(encode_buffer) {}
52 
53   void StartClientTransfer(TransferType type,
54                            ProtocolVersion version,
55                            uint32_t resource_id,
56                            uint32_t handle_id,
57                            stream::Stream* stream,
58                            const TransferParameters& max_parameters,
59                            Function<void(Status)>&& on_completion,
60                            chrono::SystemClock::duration timeout,
61                            chrono::SystemClock::duration initial_timeout,
62                            uint8_t max_retries,
63                            uint32_t max_lifetime_retries,
64                            uint32_t initial_offset = 0) {
65     StartTransfer(type,
66                   version,
67                   Context::kUnassignedSessionId,  // Assigned later.
68                   resource_id,
69                   handle_id,
70                   /*raw_chunk=*/{},
71                   stream,
72                   max_parameters,
73                   std::move(on_completion),
74                   timeout,
75                   initial_timeout,
76                   max_retries,
77                   max_lifetime_retries,
78                   initial_offset);
79   }
80 
81   void StartServerTransfer(TransferType type,
82                            ProtocolVersion version,
83                            uint32_t session_id,
84                            uint32_t resource_id,
85                            ConstByteSpan raw_chunk,
86                            const TransferParameters& max_parameters,
87                            chrono::SystemClock::duration timeout,
88                            uint8_t max_retries,
89                            uint32_t max_lifetime_retries,
90                            uint32_t initial_offset = 0) {
91     StartTransfer(type,
92                   version,
93                   session_id,
94                   resource_id,
95                   /*handle_id=*/0,
96                   raw_chunk,
97                   /*stream=*/nullptr,
98                   max_parameters,
99                   /*on_completion=*/nullptr,
100                   timeout,
101                   timeout,
102                   max_retries,
103                   max_lifetime_retries,
104                   initial_offset);
105   }
106 
ProcessClientChunk(ConstByteSpan chunk)107   void ProcessClientChunk(ConstByteSpan chunk) {
108     ProcessChunk(EventType::kClientChunk, chunk);
109   }
110 
ProcessServerChunk(ConstByteSpan chunk)111   void ProcessServerChunk(ConstByteSpan chunk) {
112     ProcessChunk(EventType::kServerChunk, chunk);
113   }
114 
SendServerStatus(TransferType type,uint32_t session_id,ProtocolVersion version,Status status)115   void SendServerStatus(TransferType type,
116                         uint32_t session_id,
117                         ProtocolVersion version,
118                         Status status) {
119     SendStatus(type == TransferType::kTransmit ? TransferStream::kServerRead
120                                                : TransferStream::kServerWrite,
121                session_id,
122                version,
123                status);
124   }
125 
CancelClientTransfer(uint32_t handle_id)126   void CancelClientTransfer(uint32_t handle_id) {
127     EndTransfer(EventType::kClientEndTransfer,
128                 IdentifierType::Handle,
129                 handle_id,
130                 Status::Cancelled(),
131                 /*send_status_chunk=*/true);
132   }
133 
134   void EndClientTransfer(uint32_t session_id,
135                          Status status,
136                          bool send_status_chunk = false) {
137     EndTransfer(EventType::kClientEndTransfer,
138                 IdentifierType::Session,
139                 session_id,
140                 status,
141                 send_status_chunk);
142   }
143 
144   void EndServerTransfer(uint32_t session_id,
145                          Status status,
146                          bool send_status_chunk = false) {
147     EndTransfer(EventType::kServerEndTransfer,
148                 IdentifierType::Session,
149                 session_id,
150                 status,
151                 send_status_chunk);
152   }
153 
154   /// Updates the transfer thread's client read stream.
155   ///
156   /// The provided stream should not have an on_next function set. Instead,
157   /// on_next is passed separately to ensure that it is only set when the new
158   /// stream becomes the transfer thread's primary stream.
159   ///
160   /// If the thread has an existing active client read stream, closes it and
161   /// terminates any transfers running on it.
SetClientReadStream(rpc::RawClientReaderWriter & read_stream,Function<void (ConstByteSpan)> && on_next)162   void SetClientReadStream(rpc::RawClientReaderWriter& read_stream,
163                            Function<void(ConstByteSpan)>&& on_next) {
164     // Clear the existing callback to prevent incoming chunks from blocking on
165     // the transfer thread and preventing the call's cleanup.
166     client_read_stream_.set_on_next(nullptr);
167     staged_client_stream_ = std::move(read_stream);
168     staged_client_on_next_ = std::move(on_next);
169     SetStream(TransferStream::kClientRead);
170   }
171 
172   /// Updates the transfer thread's client write stream.
173   ///
174   /// The provided stream should not have an on_next function set. Instead,
175   /// on_next is passed separately to ensure that it is only set when the new
176   /// stream becomes the transfer thread's primary stream.
177   ///
178   /// If the thread has an existing active client write stream, closes it and
179   /// terminates any transfers running on it.
SetClientWriteStream(rpc::RawClientReaderWriter & write_stream,Function<void (ConstByteSpan)> && on_next)180   void SetClientWriteStream(rpc::RawClientReaderWriter& write_stream,
181                             Function<void(ConstByteSpan)>&& on_next) {
182     // Clear the existing callback to prevent incoming chunks from blocking on
183     // the transfer thread and preventing the call's cleanup.
184     client_write_stream_.set_on_next(nullptr);
185     staged_client_stream_ = std::move(write_stream);
186     staged_client_on_next_ = std::move(on_next);
187     SetStream(TransferStream::kClientWrite);
188   }
189 
190   /// Updates the transfer thread's server read stream.
191   ///
192   /// The provided stream should not have an on_next function set. Instead,
193   /// on_next is passed separately to ensure that it is only set when the new
194   /// stream becomes the transfer thread's primary stream.
195   ///
196   /// If the thread has an existing active server read stream, closes it and
197   /// terminates any transfers running on it.
SetServerReadStream(rpc::RawServerReaderWriter & read_stream,Function<void (ConstByteSpan)> && on_next)198   void SetServerReadStream(rpc::RawServerReaderWriter& read_stream,
199                            Function<void(ConstByteSpan)>&& on_next) {
200     // Clear the existing callback to prevent incoming chunks from blocking on
201     // the transfer thread and preventing the call's cleanup.
202     server_read_stream_.set_on_next(nullptr);
203     staged_server_stream_ = std::move(read_stream);
204     staged_server_on_next_ = std::move(on_next);
205     SetStream(TransferStream::kServerRead);
206   }
207 
208   /// Updates the transfer thread's server write stream.
209   ///
210   /// The provided stream should not have an on_next function set. Instead,
211   /// on_next is passed separately to ensure that it is only set when the new
212   /// stream becomes the transfer thread's primary stream.
213   ///
214   /// If the thread has an existing active server write stream, closes it and
215   /// terminates any transfers running on it.
SetServerWriteStream(rpc::RawServerReaderWriter & write_stream,Function<void (ConstByteSpan)> && on_next)216   void SetServerWriteStream(rpc::RawServerReaderWriter& write_stream,
217                             Function<void(ConstByteSpan)>&& on_next) {
218     // Clear the existing callback to prevent incoming chunks from blocking on
219     // the transfer thread and preventing the call's cleanup.
220     server_write_stream_.set_on_next(nullptr);
221     staged_server_stream_ = std::move(write_stream);
222     staged_server_on_next_ = std::move(on_next);
223     SetStream(TransferStream::kServerWrite);
224   }
225 
AddTransferHandler(Handler & handler)226   bool AddTransferHandler(Handler& handler) {
227     return TransferHandlerEvent(EventType::kAddTransferHandler, handler);
228   }
229 
RemoveTransferHandler(Handler & handler)230   bool RemoveTransferHandler(Handler& handler) {
231     if (!TransferHandlerEvent(EventType::kRemoveTransferHandler, handler)) {
232       return false;
233     }
234     // Ensure this function blocks until the transfer handler is fully cleaned
235     // up.
236     WaitUntilEventIsProcessed();
237     return true;
238   }
239 
max_chunk_size()240   size_t max_chunk_size() const { return chunk_buffer_.size(); }
241 
242   // For testing only: terminates the transfer thread with a kTerminate event.
243   void Terminate();
244 
245   // For testing only: blocks until the next event can be acquired, which means
246   // a previously enqueued event has been processed.
WaitUntilEventIsProcessed()247   void WaitUntilEventIsProcessed() {
248     next_event_ownership_.acquire();
249     next_event_ownership_.release();
250   }
251 
252   // For testing only: simulates a timeout event for a client transfer.
SimulateClientTimeout(uint32_t session_id)253   void SimulateClientTimeout(uint32_t session_id) {
254     SimulateTimeout(EventType::kClientTimeout, session_id);
255   }
256 
257   // For testing only: simulates a timeout event for a server transfer.
SimulateServerTimeout(uint32_t session_id)258   void SimulateServerTimeout(uint32_t session_id) {
259     SimulateTimeout(EventType::kServerTimeout, session_id);
260   }
261 
262   void EnqueueResourceEvent(uint32_t resource_id,
263                             ResourceStatusCallback&& callback);
264 
265  private:
266   friend class transfer::Client;
267   friend class Context;
268 
269   // Maximum amount of time between transfer thread runs.
270   static constexpr chrono::SystemClock::duration kMaxTimeout =
271       std::chrono::seconds(2);
272 
273   void UpdateClientTransfer(uint32_t handle_id, size_t transfer_size_bytes);
274 
275   // Finds an active server or client transfer, matching against its legacy ID.
276   template <typename T>
FindActiveTransferByLegacyId(const span<T> & transfers,uint32_t session_id)277   static Context* FindActiveTransferByLegacyId(const span<T>& transfers,
278                                                uint32_t session_id) {
279     auto transfer =
280         std::find_if(transfers.begin(), transfers.end(), [session_id](auto& c) {
281           return c.initialized() && c.session_id() == session_id;
282         });
283     return transfer != transfers.end() ? &*transfer : nullptr;
284   }
285 
286   // Finds an active server or client transfer, matching against resource ID.
287   template <typename T>
FindActiveTransferByResourceId(const span<T> & transfers,uint32_t resource_id)288   static Context* FindActiveTransferByResourceId(const span<T>& transfers,
289                                                  uint32_t resource_id) {
290     auto transfer = std::find_if(
291         transfers.begin(), transfers.end(), [resource_id](auto& c) {
292           return c.initialized() && c.resource_id() == resource_id;
293         });
294     return transfer != transfers.end() ? &*transfer : nullptr;
295   }
296 
FindClientTransferByHandleId(uint32_t handle_id)297   Context* FindClientTransferByHandleId(uint32_t handle_id) const {
298     auto transfer =
299         std::find_if(client_transfers_.begin(),
300                      client_transfers_.end(),
301                      [handle_id](auto& c) {
302                        return c.initialized() && c.handle_id() == handle_id;
303                      });
304     return transfer != client_transfers_.end() ? &*transfer : nullptr;
305   }
306 
307   void SimulateTimeout(EventType type, uint32_t session_id);
308 
309   // Finds an new server or client transfer.
310   template <typename T>
FindNewTransfer(const span<T> & transfers,uint32_t session_id)311   static Context* FindNewTransfer(const span<T>& transfers,
312                                   uint32_t session_id) {
313     Context* new_transfer = nullptr;
314 
315     for (Context& context : transfers) {
316       if (context.active()) {
317         if (context.session_id() == session_id) {
318           // Restart an already active transfer.
319           return &context;
320         }
321       } else {
322         // Store the inactive context as an option, but keep checking for the
323         // restart case.
324         new_transfer = &context;
325       }
326     }
327 
328     return new_transfer;
329   }
330 
encode_buffer()331   const ByteSpan& encode_buffer() const { return encode_buffer_; }
332 
333   void Run() final;
334 
335   void HandleTimeouts();
336 
stream_for(TransferStream stream)337   rpc::Writer& stream_for(TransferStream stream) {
338     switch (stream) {
339       case TransferStream::kClientRead:
340         return client_read_stream_.as_writer();
341       case TransferStream::kClientWrite:
342         return client_write_stream_.as_writer();
343       case TransferStream::kServerRead:
344         return server_read_stream_.as_writer();
345       case TransferStream::kServerWrite:
346         return server_write_stream_.as_writer();
347     }
348     // An unknown TransferStream value was passed, which means this function
349     // was passed an invalid enum value.
350     PW_ASSERT(false);
351   }
352 
TryWaitForEventToProcess()353   bool TryWaitForEventToProcess() {
354     if constexpr (cfg::kWaitForEventProcessingIndefinitely) {
355       next_event_ownership_.acquire();
356       return true;
357     }
358     return next_event_ownership_.try_acquire_for(cfg::kEventProcessingTimeout);
359   }
360 
361   // Returns the earliest timeout among all active transfers, up to kMaxTimeout.
362   chrono::SystemClock::time_point GetNextTransferTimeout() const;
363 
364   uint32_t AssignSessionId();
365 
366   void StartTransfer(TransferType type,
367                      ProtocolVersion version,
368                      uint32_t session_id,
369                      uint32_t resource_id,
370                      uint32_t handle_id,
371                      ConstByteSpan raw_chunk,
372                      stream::Stream* stream,
373                      const TransferParameters& max_parameters,
374                      Function<void(Status)>&& on_completion,
375                      chrono::SystemClock::duration timeout,
376                      chrono::SystemClock::duration initial_timeout,
377                      uint8_t max_retries,
378                      uint32_t max_lifetime_retries,
379                      uint32_t initial_offset);
380 
381   void ProcessChunk(EventType type, ConstByteSpan chunk);
382 
383   void SendStatus(TransferStream stream,
384                   uint32_t session_id,
385                   ProtocolVersion version,
386                   Status status);
387 
388   void EndTransfer(EventType type,
389                    IdentifierType id_type,
390                    uint32_t session_id,
391                    Status status,
392                    bool send_status_chunk);
393 
394   void SetStream(TransferStream stream);
395   void HandleSetStreamEvent(TransferStream stream);
396 
397   bool TransferHandlerEvent(EventType type, Handler& handler);
398 
399   void HandleEvent(const Event& event);
400   Context* FindContextForEvent(const Event& event) const;
401 
402   void SendStatusChunk(const SendStatusChunkEvent& event);
403 
404   void GetResourceState(uint32_t resource_id);
405 
406   sync::TimedThreadNotification event_notification_;
407   sync::BinarySemaphore next_event_ownership_;
408 
409   Event next_event_;
410   Function<void(Status)> staged_on_completion_;
411 
412   rpc::RawClientReaderWriter client_read_stream_;
413   rpc::RawClientReaderWriter client_write_stream_;
414   rpc::RawClientReaderWriter staged_client_stream_;
415   Function<void(ConstByteSpan)> staged_client_on_next_;
416 
417   rpc::RawServerReaderWriter server_read_stream_;
418   rpc::RawServerReaderWriter server_write_stream_;
419   rpc::RawServerReaderWriter staged_server_stream_;
420   Function<void(ConstByteSpan)> staged_server_on_next_;
421 
422   span<ClientContext> client_transfers_;
423   span<ServerContext> server_transfers_;
424 
425   // Identifier to use for the next started transfer, unique over the RPC
426   // channel between the transfer client and server.
427   //
428   // TODO(frolv): If we ever support changing the RPC channel, this should be
429   // reset to 1.
430   uint32_t next_session_id_;
431 
432   // All registered transfer handlers.
433   IntrusiveList<Handler> handlers_;
434 
435   // Buffer in which chunk data is staged for CHUNK events.
436   ByteSpan chunk_buffer_;
437 
438   // Buffer into which responses are encoded. Only ever used from within the
439   // transfer thread, so no locking is required.
440   ByteSpan encode_buffer_;
441 
442   ResourceStatusCallback resource_status_callback_ = nullptr;
443 };
444 
445 }  // namespace internal
446 
447 using TransferThread = internal::TransferThread;
448 
449 template <size_t kMaxConcurrentClientTransfers,
450           size_t kMaxConcurrentServerTransfers>
451 class Thread final : public internal::TransferThread {
452  public:
Thread(ByteSpan chunk_buffer,ByteSpan encode_buffer)453   Thread(ByteSpan chunk_buffer, ByteSpan encode_buffer)
454       : internal::TransferThread(
455             client_contexts_, server_contexts_, chunk_buffer, encode_buffer) {}
456 
457  private:
458   std::array<internal::ClientContext, kMaxConcurrentClientTransfers>
459       client_contexts_;
460   std::array<internal::ServerContext, kMaxConcurrentServerTransfers>
461       server_contexts_;
462 };
463 
464 }  // namespace pw::transfer
465