1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #define PW_LOG_MODULE_NAME "TRN"
16 #define PW_LOG_LEVEL PW_TRANSFER_CONFIG_LOG_LEVEL
17
18 #include "pw_transfer/transfer_thread.h"
19
20 #include "pw_assert/check.h"
21 #include "pw_log/log.h"
22 #include "pw_transfer/internal/chunk.h"
23 #include "pw_transfer/internal/client_context.h"
24 #include "pw_transfer/internal/config.h"
25 #include "pw_transfer/internal/event.h"
26
27 PW_MODIFY_DIAGNOSTICS_PUSH();
28 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
29
30 namespace pw::transfer::internal {
31
Terminate()32 void TransferThread::Terminate() {
33 next_event_ownership_.acquire();
34 next_event_.type = EventType::kTerminate;
35 event_notification_.release();
36 }
37
SimulateTimeout(EventType type,uint32_t session_id)38 void TransferThread::SimulateTimeout(EventType type, uint32_t session_id) {
39 next_event_ownership_.acquire();
40
41 next_event_.type = type;
42 next_event_.chunk = {};
43 next_event_.chunk.context_identifier = session_id;
44
45 event_notification_.release();
46
47 WaitUntilEventIsProcessed();
48 }
49
Run()50 void TransferThread::Run() {
51 // Next event starts freed.
52 next_event_ownership_.release();
53
54 while (true) {
55 if (event_notification_.try_acquire_until(GetNextTransferTimeout())) {
56 HandleEvent(next_event_);
57
58 // Sample event type before we release ownership of next_event_.
59 bool is_terminating = next_event_.type == EventType::kTerminate;
60
61 // Finished processing the event. Allow the next_event struct to be
62 // overwritten.
63 next_event_ownership_.release();
64
65 if (is_terminating) {
66 return;
67 }
68 }
69
70 // Regardless of whether an event was received or not, check for any
71 // transfers which have timed out and process them if so.
72 for (Context& context : client_transfers_) {
73 if (context.timed_out()) {
74 context.HandleEvent({.type = EventType::kClientTimeout});
75 }
76 }
77 for (Context& context : server_transfers_) {
78 if (context.timed_out()) {
79 context.HandleEvent({.type = EventType::kServerTimeout});
80 }
81 }
82 }
83 }
84
GetNextTransferTimeout() const85 chrono::SystemClock::time_point TransferThread::GetNextTransferTimeout() const {
86 chrono::SystemClock::time_point timeout =
87 chrono::SystemClock::TimePointAfterAtLeast(kMaxTimeout);
88
89 for (Context& context : client_transfers_) {
90 auto ctx_timeout = context.timeout();
91 if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
92 timeout = ctx_timeout.value();
93 }
94 }
95 for (Context& context : server_transfers_) {
96 auto ctx_timeout = context.timeout();
97 if (ctx_timeout.has_value() && ctx_timeout.value() < timeout) {
98 timeout = ctx_timeout.value();
99 }
100 }
101
102 return timeout;
103 }
104
StartTransfer(TransferType type,ProtocolVersion version,uint32_t session_id,uint32_t resource_id,uint32_t handle_id,ConstByteSpan raw_chunk,stream::Stream * stream,const TransferParameters & max_parameters,Function<void (Status)> && on_completion,chrono::SystemClock::duration timeout,chrono::SystemClock::duration initial_timeout,uint8_t max_retries,uint32_t max_lifetime_retries,uint32_t initial_offset)105 void TransferThread::StartTransfer(
106 TransferType type,
107 ProtocolVersion version,
108 uint32_t session_id,
109 uint32_t resource_id,
110 uint32_t handle_id,
111 ConstByteSpan raw_chunk,
112 stream::Stream* stream,
113 const TransferParameters& max_parameters,
114 Function<void(Status)>&& on_completion,
115 chrono::SystemClock::duration timeout,
116 chrono::SystemClock::duration initial_timeout,
117 uint8_t max_retries,
118 uint32_t max_lifetime_retries,
119 uint32_t initial_offset) {
120 if (!TryWaitForEventToProcess()) {
121 return;
122 }
123
124 bool is_client_transfer = stream != nullptr;
125
126 if (is_client_transfer) {
127 if (version == ProtocolVersion::kLegacy) {
128 session_id = resource_id;
129 } else if (session_id == Context::kUnassignedSessionId) {
130 session_id = AssignSessionId();
131 }
132 }
133
134 next_event_.type = is_client_transfer ? EventType::kNewClientTransfer
135 : EventType::kNewServerTransfer;
136
137 if (!raw_chunk.empty()) {
138 std::memcpy(chunk_buffer_.data(), raw_chunk.data(), raw_chunk.size());
139 }
140
141 next_event_.new_transfer = {
142 .type = type,
143 .protocol_version = version,
144 .session_id = session_id,
145 .resource_id = resource_id,
146 .handle_id = handle_id,
147 .max_parameters = &max_parameters,
148 .timeout = timeout,
149 .initial_timeout = initial_timeout,
150 .max_retries = max_retries,
151 .max_lifetime_retries = max_lifetime_retries,
152 .transfer_thread = this,
153 .raw_chunk_data = chunk_buffer_.data(),
154 .raw_chunk_size = raw_chunk.size(),
155 .initial_offset = initial_offset,
156 };
157
158 staged_on_completion_ = std::move(on_completion);
159
160 // The transfer is initialized with either a stream (client-side) or a handler
161 // (server-side). If no stream is provided, try to find a registered handler
162 // with the specified ID.
163 if (is_client_transfer) {
164 next_event_.new_transfer.stream = stream;
165 next_event_.new_transfer.rpc_writer =
166 &(type == TransferType::kTransmit ? client_write_stream_
167 : client_read_stream_)
168 .as_writer();
169 } else {
170 auto handler = std::find_if(handlers_.begin(),
171 handlers_.end(),
172 [&](auto& h) { return h.id() == resource_id; });
173 if (handler != handlers_.end()) {
174 next_event_.new_transfer.handler = &*handler;
175 next_event_.new_transfer.rpc_writer =
176 &(type == TransferType::kTransmit ? server_read_stream_
177 : server_write_stream_)
178 .as_writer();
179 } else {
180 // No handler exists for the transfer: return a NOT_FOUND.
181 next_event_.type = EventType::kSendStatusChunk;
182 next_event_.send_status_chunk = {
183 .session_id = session_id,
184 .protocol_version = version,
185 .status = Status::NotFound().code(),
186 .stream = type == TransferType::kTransmit
187 ? TransferStream::kServerRead
188 : TransferStream::kServerWrite,
189 };
190 }
191 }
192
193 event_notification_.release();
194 }
195
ProcessChunk(EventType type,ConstByteSpan chunk)196 void TransferThread::ProcessChunk(EventType type, ConstByteSpan chunk) {
197 // If this assert is hit, there is a bug in the transfer implementation.
198 // Contexts' max_chunk_size_bytes fields should be set based on the size of
199 // chunk_buffer_.
200 PW_CHECK(chunk.size() <= chunk_buffer_.size(),
201 "Transfer received a larger chunk than it can handle.");
202
203 Result<Chunk::Identifier> identifier = Chunk::ExtractIdentifier(chunk);
204 if (!identifier.ok()) {
205 PW_LOG_ERROR("Received a malformed chunk without a context identifier");
206 return;
207 }
208
209 if (!TryWaitForEventToProcess()) {
210 return;
211 }
212
213 std::memcpy(chunk_buffer_.data(), chunk.data(), chunk.size());
214
215 next_event_.type = type;
216 next_event_.chunk = {
217 .context_identifier = identifier->value(),
218 .match_resource_id = identifier->is_legacy(),
219 .data = chunk_buffer_.data(),
220 .size = chunk.size(),
221 };
222
223 event_notification_.release();
224 }
225
SendStatus(TransferStream stream,uint32_t session_id,ProtocolVersion version,Status status)226 void TransferThread::SendStatus(TransferStream stream,
227 uint32_t session_id,
228 ProtocolVersion version,
229 Status status) {
230 if (!TryWaitForEventToProcess()) {
231 return;
232 }
233
234 next_event_.type = EventType::kSendStatusChunk;
235 next_event_.send_status_chunk = {
236 .session_id = session_id,
237 .protocol_version = version,
238 .status = status.code(),
239 .stream = stream,
240 };
241
242 event_notification_.release();
243 }
244
EndTransfer(EventType type,IdentifierType id_type,uint32_t id,Status status,bool send_status_chunk)245 void TransferThread::EndTransfer(EventType type,
246 IdentifierType id_type,
247 uint32_t id,
248 Status status,
249 bool send_status_chunk) {
250 if (!TryWaitForEventToProcess()) {
251 return;
252 }
253
254 next_event_.type = type;
255 next_event_.end_transfer = {
256 .id_type = id_type,
257 .id = id,
258 .status = status.code(),
259 .send_status_chunk = send_status_chunk,
260 };
261
262 event_notification_.release();
263 }
264
SetStream(TransferStream stream)265 void TransferThread::SetStream(TransferStream stream) {
266 if (!TryWaitForEventToProcess()) {
267 return;
268 }
269
270 next_event_.type = EventType::kSetStream;
271 next_event_.set_stream = {
272 .stream = stream,
273 };
274
275 event_notification_.release();
276 }
277
UpdateClientTransfer(uint32_t handle_id,size_t transfer_size_bytes)278 void TransferThread::UpdateClientTransfer(uint32_t handle_id,
279 size_t transfer_size_bytes) {
280 if (!TryWaitForEventToProcess()) {
281 return;
282 }
283
284 next_event_.type = EventType::kUpdateClientTransfer;
285 next_event_.update_transfer.handle_id = handle_id;
286 next_event_.update_transfer.transfer_size_bytes = transfer_size_bytes;
287
288 event_notification_.release();
289 }
290
TransferHandlerEvent(EventType type,Handler & handler)291 bool TransferThread::TransferHandlerEvent(EventType type, Handler& handler) {
292 if (!TryWaitForEventToProcess()) {
293 return false;
294 }
295
296 next_event_.type = type;
297 if (type == EventType::kAddTransferHandler) {
298 next_event_.add_transfer_handler = &handler;
299 } else {
300 next_event_.remove_transfer_handler = &handler;
301 }
302
303 event_notification_.release();
304 return true;
305 }
306
HandleEvent(const internal::Event & event)307 void TransferThread::HandleEvent(const internal::Event& event) {
308 switch (event.type) {
309 case EventType::kTerminate:
310 // Terminate server contexts.
311 for (ServerContext& server_context : server_transfers_) {
312 server_context.HandleEvent(Event{
313 .type = EventType::kServerEndTransfer,
314 .end_transfer =
315 EndTransferEvent{
316 .id_type = IdentifierType::Session,
317 .id = server_context.session_id(),
318 .status = Status::Aborted().code(),
319 .send_status_chunk = false,
320 },
321 });
322 }
323
324 // Terminate client contexts.
325 for (ClientContext& client_context : client_transfers_) {
326 client_context.HandleEvent(Event{
327 .type = EventType::kClientEndTransfer,
328 .end_transfer =
329 EndTransferEvent{
330 .id_type = IdentifierType::Session,
331 .id = client_context.session_id(),
332 .status = Status::Aborted().code(),
333 .send_status_chunk = false,
334 },
335 });
336 }
337
338 // Cancel/Finish streams.
339 client_read_stream_.Cancel().IgnoreError();
340 client_write_stream_.Cancel().IgnoreError();
341 server_read_stream_.Finish(Status::Aborted()).IgnoreError();
342 server_write_stream_.Finish(Status::Aborted()).IgnoreError();
343 return;
344
345 case EventType::kSendStatusChunk:
346 SendStatusChunk(event.send_status_chunk);
347 break;
348
349 case EventType::kAddTransferHandler:
350 handlers_.push_front(*event.add_transfer_handler);
351 return;
352
353 case EventType::kRemoveTransferHandler:
354 for (ServerContext& server_context : server_transfers_) {
355 if (server_context.handler() == event.remove_transfer_handler) {
356 server_context.HandleEvent(Event{
357 .type = EventType::kServerEndTransfer,
358 .end_transfer =
359 EndTransferEvent{
360 .id_type = IdentifierType::Session,
361 .id = server_context.session_id(),
362 .status = Status::Aborted().code(),
363 .send_status_chunk = false,
364 },
365 });
366 }
367 }
368 handlers_.remove(*event.remove_transfer_handler);
369 return;
370
371 case EventType::kSetStream:
372 HandleSetStreamEvent(event.set_stream.stream);
373 return;
374
375 case EventType::kGetResourceStatus:
376 GetResourceState(event.resource_status.resource_id);
377 return;
378
379 case EventType::kNewClientTransfer:
380 case EventType::kNewServerTransfer:
381 case EventType::kClientChunk:
382 case EventType::kServerChunk:
383 case EventType::kClientTimeout:
384 case EventType::kServerTimeout:
385 case EventType::kClientEndTransfer:
386 case EventType::kServerEndTransfer:
387 case EventType::kUpdateClientTransfer:
388 default:
389 // Other events are handled by individual transfer contexts.
390 break;
391 }
392
393 Context* ctx = FindContextForEvent(event);
394 if (ctx == nullptr) {
395 // No context was found. For new transfer events, report a
396 // RESOURCE_EXHAUSTED error with starting the transfer.
397 if (event.type == EventType::kNewClientTransfer) {
398 // On the client, invoke the completion callback directly.
399 staged_on_completion_(Status::ResourceExhausted());
400 } else if (event.type == EventType::kNewServerTransfer) {
401 // On the server, send a status chunk back to the client.
402 SendStatusChunk(
403 {.session_id = event.new_transfer.session_id,
404 .protocol_version = event.new_transfer.protocol_version,
405 .status = Status::ResourceExhausted().code(),
406 .stream = event.new_transfer.type == TransferType::kTransmit
407 ? TransferStream::kServerRead
408 : TransferStream::kServerWrite});
409 }
410 return;
411 }
412
413 if (event.type == EventType::kNewClientTransfer) {
414 // TODO(frolv): This is terrible.
415 ClientContext* cctx = static_cast<ClientContext*>(ctx);
416 cctx->set_on_completion(std::move(staged_on_completion_));
417 cctx->set_handle_id(event.new_transfer.handle_id);
418 }
419
420 if (event.type == EventType::kUpdateClientTransfer) {
421 static_cast<ClientContext&>(*ctx).set_transfer_size_bytes(
422 event.update_transfer.transfer_size_bytes);
423 return;
424 }
425
426 ctx->HandleEvent(event);
427 }
428
FindContextForEvent(const internal::Event & event) const429 Context* TransferThread::FindContextForEvent(
430 const internal::Event& event) const {
431 switch (event.type) {
432 case EventType::kNewClientTransfer:
433 return FindNewTransfer(client_transfers_, event.new_transfer.session_id);
434 case EventType::kNewServerTransfer:
435 return FindNewTransfer(server_transfers_, event.new_transfer.session_id);
436
437 case EventType::kClientChunk:
438 if (event.chunk.match_resource_id) {
439 return FindActiveTransferByResourceId(client_transfers_,
440 event.chunk.context_identifier);
441 }
442 return FindActiveTransferByLegacyId(client_transfers_,
443 event.chunk.context_identifier);
444
445 case EventType::kServerChunk:
446 if (event.chunk.match_resource_id) {
447 return FindActiveTransferByResourceId(server_transfers_,
448 event.chunk.context_identifier);
449 }
450 return FindActiveTransferByLegacyId(server_transfers_,
451 event.chunk.context_identifier);
452
453 case EventType::kClientTimeout: // Manually triggered client timeout
454 return FindActiveTransferByLegacyId(client_transfers_,
455 event.chunk.context_identifier);
456 case EventType::kServerTimeout: // Manually triggered server timeout
457 return FindActiveTransferByLegacyId(server_transfers_,
458 event.chunk.context_identifier);
459
460 case EventType::kClientEndTransfer:
461 if (event.end_transfer.id_type == IdentifierType::Handle) {
462 return FindClientTransferByHandleId(event.end_transfer.id);
463 }
464 return FindActiveTransferByLegacyId(client_transfers_,
465 event.end_transfer.id);
466 case EventType::kServerEndTransfer:
467 PW_DCHECK(event.end_transfer.id_type != IdentifierType::Handle);
468 return FindActiveTransferByLegacyId(server_transfers_,
469 event.end_transfer.id);
470
471 case EventType::kUpdateClientTransfer:
472 return FindClientTransferByHandleId(event.update_transfer.handle_id);
473
474 case EventType::kSendStatusChunk:
475 case EventType::kAddTransferHandler:
476 case EventType::kRemoveTransferHandler:
477 case EventType::kSetStream:
478 case EventType::kTerminate:
479 case EventType::kGetResourceStatus:
480 default:
481 return nullptr;
482 }
483 }
484
SendStatusChunk(const internal::SendStatusChunkEvent & event)485 void TransferThread::SendStatusChunk(
486 const internal::SendStatusChunkEvent& event) {
487 rpc::Writer& destination = stream_for(event.stream);
488
489 Chunk chunk =
490 Chunk::Final(event.protocol_version, event.session_id, event.status);
491
492 Result<ConstByteSpan> result = chunk.Encode(chunk_buffer_);
493 if (!result.ok()) {
494 PW_LOG_ERROR("Failed to encode final chunk for transfer %u",
495 static_cast<unsigned>(event.session_id));
496 return;
497 }
498
499 if (!destination.Write(result.value()).ok()) {
500 PW_LOG_ERROR("Failed to send final chunk for transfer %u",
501 static_cast<unsigned>(event.session_id));
502 return;
503 }
504 }
505
506 // Should only be called with the `next_event_ownership_` lock held.
AssignSessionId()507 uint32_t TransferThread::AssignSessionId() {
508 uint32_t session_id = next_session_id_++;
509 if (session_id == 0) {
510 session_id = next_session_id_++;
511 }
512 return session_id;
513 }
514
515 template <typename T>
TerminateTransfers(span<T> contexts,TransferType type,EventType event_type,Status status)516 void TerminateTransfers(span<T> contexts,
517 TransferType type,
518 EventType event_type,
519 Status status) {
520 for (Context& context : contexts) {
521 if (context.active() && context.type() == type) {
522 context.HandleEvent(Event{
523 .type = event_type,
524 .end_transfer =
525 EndTransferEvent{
526 .id_type = IdentifierType::Session,
527 .id = context.session_id(),
528 .status = status.code(),
529 .send_status_chunk = false,
530 },
531 });
532 }
533 }
534 }
535
HandleSetStreamEvent(TransferStream stream)536 void TransferThread::HandleSetStreamEvent(TransferStream stream) {
537 switch (stream) {
538 case TransferStream::kClientRead:
539 TerminateTransfers(client_transfers_,
540 TransferType::kReceive,
541 EventType::kClientEndTransfer,
542 Status::Aborted());
543 client_read_stream_ = std::move(staged_client_stream_);
544 client_read_stream_.set_on_next(std::move(staged_client_on_next_));
545 client_read_stream_.set_on_error([](Status status) {
546 PW_LOG_WARN("Client read stream closed unexpectedly: %s", status.str());
547 });
548 break;
549 case TransferStream::kClientWrite:
550 TerminateTransfers(client_transfers_,
551 TransferType::kTransmit,
552 EventType::kClientEndTransfer,
553 Status::Aborted());
554 client_write_stream_ = std::move(staged_client_stream_);
555 client_write_stream_.set_on_next(std::move(staged_client_on_next_));
556 client_write_stream_.set_on_error([](Status status) {
557 PW_LOG_WARN("Client write stream closed unexpectedly: %s",
558 status.str());
559 });
560 break;
561 case TransferStream::kServerRead:
562 TerminateTransfers(server_transfers_,
563 TransferType::kTransmit,
564 EventType::kServerEndTransfer,
565 Status::Aborted());
566 server_read_stream_ = std::move(staged_server_stream_);
567 server_read_stream_.set_on_next(std::move(staged_server_on_next_));
568 server_read_stream_.set_on_error([](Status status) {
569 PW_LOG_WARN("Server read stream closed unexpectedly: %s", status.str());
570 });
571 break;
572 case TransferStream::kServerWrite:
573 TerminateTransfers(server_transfers_,
574 TransferType::kReceive,
575 EventType::kServerEndTransfer,
576 Status::Aborted());
577 server_write_stream_ = std::move(staged_server_stream_);
578 server_write_stream_.set_on_next(std::move(staged_server_on_next_));
579 server_write_stream_.set_on_error([](Status status) {
580 PW_LOG_WARN("Server write stream closed unexpectedly: %s",
581 status.str());
582 });
583 break;
584 }
585 }
586
587 // Adds GetResourceStatusEvent to the queue. Will fail if there is already a
588 // GetResourceStatusEvent in process.
EnqueueResourceEvent(uint32_t resource_id,ResourceStatusCallback && callback)589 void TransferThread::EnqueueResourceEvent(uint32_t resource_id,
590 ResourceStatusCallback&& callback) {
591 if (!TryWaitForEventToProcess()) {
592 return;
593 }
594
595 next_event_.type = EventType::kGetResourceStatus;
596
597 resource_status_callback_ = std::move(callback);
598
599 next_event_.resource_status.resource_id = resource_id;
600
601 event_notification_.release();
602 }
603
604 // Should only be called when we got a valid callback and RPC responder from
605 // GetResourceStatus transfer RPC.
GetResourceState(uint32_t resource_id)606 void TransferThread::GetResourceState(uint32_t resource_id) {
607 PW_ASSERT(resource_status_callback_ != nullptr);
608
609 auto handler = std::find_if(handlers_.begin(), handlers_.end(), [&](auto& h) {
610 return h.id() == resource_id;
611 });
612 internal::ResourceStatus stats;
613 stats.resource_id = resource_id;
614
615 if (handler != handlers_.end()) {
616 Status status = handler->GetStatus(stats.readable_offset,
617 stats.writeable_offset,
618 stats.read_checksum,
619 stats.write_checksum);
620
621 resource_status_callback_(status, stats);
622 } else {
623 resource_status_callback_(Status::NotFound(), stats);
624 }
625 }
626
627 } // namespace pw::transfer::internal
628
629 PW_MODIFY_DIAGNOSTICS_POP();
630