xref: /aosp_15_r20/external/pigweed/pw_transfer/chunk.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #define PW_LOG_MODULE_NAME "TRN"
16 #define PW_LOG_LEVEL PW_TRANSFER_CONFIG_LOG_LEVEL
17 
18 #include "pw_transfer/internal/chunk.h"
19 
20 #include "pw_assert/check.h"
21 #include "pw_log/log.h"
22 #include "pw_log/rate_limited.h"
23 #include "pw_protobuf/decoder.h"
24 #include "pw_protobuf/serialized_size.h"
25 #include "pw_status/try.h"
26 #include "pw_transfer/internal/config.h"
27 
28 namespace pw::transfer::internal {
29 
30 namespace ProtoChunk = transfer::pwpb::Chunk;
31 
ExtractIdentifier(ConstByteSpan message)32 Result<Chunk::Identifier> Chunk::ExtractIdentifier(ConstByteSpan message) {
33   protobuf::Decoder decoder(message);
34 
35   uint32_t session_id = 0;
36   uint32_t desired_session_id = 0;
37   bool legacy = true;
38 
39   Status status;
40 
41   while ((status = decoder.Next()).ok()) {
42     ProtoChunk::Fields field =
43         static_cast<ProtoChunk::Fields>(decoder.FieldNumber());
44 
45     if (field == ProtoChunk::Fields::kTransferId) {
46       // Interpret a legacy transfer_id field as a session ID if an explicit
47       // session_id field has not already been seen.
48       if (session_id == 0) {
49         PW_TRY(decoder.ReadUint32(&session_id));
50       }
51     } else if (field == ProtoChunk::Fields::kSessionId) {
52       // A session_id field always takes precedence over transfer_id.
53       PW_TRY(decoder.ReadUint32(&session_id));
54       legacy = false;
55     } else if (field == ProtoChunk::Fields::kDesiredSessionId) {
56       PW_TRY(decoder.ReadUint32(&desired_session_id));
57     }
58   }
59 
60   if (!status.IsOutOfRange()) {
61     return Status::DataLoss();
62   }
63 
64   if (desired_session_id != 0) {
65     // Can't have both a desired and regular session_id.
66     if (!legacy && session_id != 0) {
67       return Status::DataLoss();
68     }
69     return Identifier::Desired(desired_session_id);
70   }
71 
72   if (session_id != 0) {
73     return legacy ? Identifier::Legacy(session_id)
74                   : Identifier::Session(session_id);
75   }
76 
77   return Status::DataLoss();
78 }
79 
Parse(ConstByteSpan message)80 Result<Chunk> Chunk::Parse(ConstByteSpan message) {
81   protobuf::Decoder decoder(message);
82   Status status;
83   uint32_t value;
84 
85   Chunk chunk;
86 
87   // Determine the protocol version of the chunk depending on field presence in
88   // the serialized message.
89   chunk.protocol_version_ = ProtocolVersion::kUnknown;
90 
91   // Some older versions of the protocol set the deprecated pending_bytes field
92   // in their chunks. The newer transfer handling code does not process this
93   // field, instead working only in terms of window_end_offset. If pending_bytes
94   // is encountered in the serialized message, save its value, then calculate
95   // window_end_offset from it once parsing is complete.
96   uint32_t pending_bytes = 0;
97 
98   bool has_session_id = false;
99 
100   while ((status = decoder.Next()).ok()) {
101     ProtoChunk::Fields field =
102         static_cast<ProtoChunk::Fields>(decoder.FieldNumber());
103 
104     switch (field) {
105       case ProtoChunk::Fields::kTransferId:
106         // transfer_id is a legacy field. session_id will always take precedence
107         // over it, so it should only be read if session_id has not yet been
108         // encountered.
109         if (chunk.session_id_ == 0) {
110           PW_TRY(decoder.ReadUint32(&chunk.session_id_));
111         }
112         break;
113 
114       case ProtoChunk::Fields::kSessionId:
115         // The existence of a session_id field indicates that a newer protocol
116         // is running. Update the deduced protocol unless it was explicitly
117         // specified.
118         if (chunk.protocol_version_ == ProtocolVersion::kUnknown) {
119           chunk.protocol_version_ = ProtocolVersion::kVersionTwo;
120         }
121         has_session_id = true;
122         PW_TRY(decoder.ReadUint32(&chunk.session_id_));
123         break;
124 
125       case ProtoChunk::Fields::kPendingBytes:
126         PW_TRY(decoder.ReadUint32(&pending_bytes));
127         break;
128 
129       case ProtoChunk::Fields::kMaxChunkSizeBytes:
130         PW_TRY(decoder.ReadUint32(&value));
131         chunk.set_max_chunk_size_bytes(value);
132         break;
133 
134       case ProtoChunk::Fields::kMinDelayMicroseconds:
135         PW_TRY(decoder.ReadUint32(&value));
136         chunk.set_min_delay_microseconds(value);
137         break;
138 
139       case ProtoChunk::Fields::kOffset:
140         PW_TRY(decoder.ReadUint32(&chunk.offset_));
141         break;
142 
143       case ProtoChunk::Fields::kData:
144         PW_TRY(decoder.ReadBytes(&chunk.payload_));
145         break;
146 
147       case ProtoChunk::Fields::kRemainingBytes: {
148         uint64_t remaining_bytes;
149         PW_TRY(decoder.ReadUint64(&remaining_bytes));
150         chunk.set_remaining_bytes(remaining_bytes);
151         break;
152       }
153 
154       case ProtoChunk::Fields::kStatus:
155         PW_TRY(decoder.ReadUint32(&value));
156         chunk.set_status(static_cast<Status::Code>(value));
157         break;
158 
159       case ProtoChunk::Fields::kWindowEndOffset:
160         PW_TRY(decoder.ReadUint32(&chunk.window_end_offset_));
161         break;
162 
163       case ProtoChunk::Fields::kType: {
164         uint32_t type;
165         PW_TRY(decoder.ReadUint32(&type));
166         chunk.type_ = static_cast<Chunk::Type>(type);
167         break;
168       }
169 
170       case ProtoChunk::Fields::kResourceId:
171         PW_TRY(decoder.ReadUint32(&value));
172         chunk.set_resource_id(value);
173         break;
174 
175       case ProtoChunk::Fields::kProtocolVersion:
176         // The protocol_version field is added as part of the initial handshake
177         // starting from version 2. If provided, it should override any deduced
178         // protocol version.
179         PW_TRY(decoder.ReadUint32(&value));
180         if (!ValidProtocolVersion(value)) {
181           return Status::DataLoss();
182         }
183         chunk.protocol_version_ = static_cast<ProtocolVersion>(value);
184         break;
185 
186       case ProtoChunk::Fields::kDesiredSessionId:
187         PW_TRY(decoder.ReadUint32(&value));
188         chunk.desired_session_id_ = value;
189         break;
190 
191       case ProtoChunk::Fields::kInitialOffset:
192         PW_TRY(decoder.ReadUint32(&value));
193         chunk.set_initial_offset(value);
194         break;
195 
196         // Silently ignore any unrecognized fields.
197     }
198   }
199 
200   if (chunk.desired_session_id_.has_value() && has_session_id) {
201     // Setting both session_id and desired_session_id is not permitted.
202     return Status::DataLoss();
203   }
204 
205   if (chunk.protocol_version_ == ProtocolVersion::kUnknown) {
206     // If no fields in the chunk specified its protocol version, assume it is a
207     // legacy chunk.
208     chunk.protocol_version_ = ProtocolVersion::kLegacy;
209   }
210 
211   if (pending_bytes != 0) {
212     // Compute window_end_offset if it isn't explicitly provided (in older
213     // protocol versions).
214     chunk.set_window_end_offset(chunk.offset() + pending_bytes);
215   }
216 
217   if (status.ok() || status.IsOutOfRange()) {
218     return chunk;
219   }
220 
221   return status;
222 }
223 
Encode(ByteSpan buffer) const224 Result<ConstByteSpan> Chunk::Encode(ByteSpan buffer) const {
225   PW_CHECK(protocol_version_ != ProtocolVersion::kUnknown,
226            "Cannot encode a transfer chunk with an unknown protocol version");
227 
228   ProtoChunk::MemoryEncoder encoder(buffer);
229 
230   // Write the payload first to avoid clobbering it if it shares the same buffer
231   // as the encode buffer.
232   if (has_payload()) {
233     encoder.WriteData(payload_).IgnoreError();
234   }
235 
236   if (protocol_version_ >= ProtocolVersion::kVersionTwo) {
237     if (session_id_ != 0) {
238       PW_CHECK(!desired_session_id_.has_value(),
239                "A chunk cannot set both a desired and regular session ID");
240       encoder.WriteSessionId(session_id_).IgnoreError();
241     }
242 
243     if (desired_session_id_.has_value()) {
244       encoder.WriteDesiredSessionId(desired_session_id_.value()).IgnoreError();
245     }
246 
247     if (resource_id_.has_value()) {
248       encoder.WriteResourceId(resource_id_.value()).IgnoreError();
249     }
250   }
251 
252   // During the initial handshake, the chunk's configured protocol version is
253   // explicitly serialized to the wire.
254   if (IsInitialHandshakeChunk()) {
255     encoder.WriteProtocolVersion(static_cast<uint32_t>(protocol_version_))
256         .IgnoreError();
257   }
258 
259   if (type_.has_value()) {
260     encoder.WriteType(static_cast<ProtoChunk::Type>(type_.value()))
261         .IgnoreError();
262   }
263 
264   if (window_end_offset_ != 0) {
265     encoder.WriteWindowEndOffset(window_end_offset_).IgnoreError();
266   }
267 
268   // Encode additional fields from the legacy protocol.
269   if (ShouldEncodeLegacyFields()) {
270     // The legacy protocol uses the transfer_id field instead of session_id or
271     // resource_id.
272     if (resource_id_.has_value()) {
273       encoder.WriteTransferId(resource_id_.value()).IgnoreError();
274     } else {
275       encoder.WriteTransferId(session_id_).IgnoreError();
276     }
277 
278     // In the legacy protocol, the pending_bytes field must be set alongside
279     // window_end_offset, as some transfer implementations require it.
280     if (window_end_offset_ != 0) {
281       encoder.WritePendingBytes(window_end_offset_ - offset_).IgnoreError();
282     }
283   }
284 
285   if (max_chunk_size_bytes_.has_value()) {
286     encoder.WriteMaxChunkSizeBytes(max_chunk_size_bytes_.value()).IgnoreError();
287   }
288   if (min_delay_microseconds_.has_value()) {
289     encoder.WriteMinDelayMicroseconds(min_delay_microseconds_.value())
290         .IgnoreError();
291   }
292 
293   if (offset_ != 0) {
294     encoder.WriteOffset(offset_).IgnoreError();
295   }
296 
297   if (initial_offset_ != 0) {
298     encoder.WriteInitialOffset(initial_offset_).IgnoreError();
299   }
300 
301   if (remaining_bytes_.has_value()) {
302     encoder.WriteRemainingBytes(remaining_bytes_.value()).IgnoreError();
303   }
304 
305   if (status_.has_value()) {
306     encoder.WriteStatus(status_.value().code()).IgnoreError();
307   }
308 
309   PW_TRY(encoder.status());
310   return ConstByteSpan(encoder);
311 }
312 
EncodedSize() const313 size_t Chunk::EncodedSize() const {
314   size_t size = 0;
315 
316   if (session_id_ != 0) {
317     if (protocol_version_ >= ProtocolVersion::kVersionTwo) {
318       size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kSessionId,
319                                           session_id_);
320     }
321 
322     if (ShouldEncodeLegacyFields()) {
323       if (resource_id_.has_value()) {
324         size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kTransferId,
325                                             resource_id_.value());
326       } else {
327         size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kTransferId,
328                                             session_id_);
329       }
330     }
331   }
332 
333   if (IsInitialHandshakeChunk()) {
334     size +=
335         protobuf::SizeOfVarintField(ProtoChunk::Fields::kProtocolVersion,
336                                     static_cast<uint32_t>(protocol_version_));
337   }
338 
339   if (protocol_version_ >= ProtocolVersion::kVersionTwo) {
340     if (resource_id_.has_value()) {
341       size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kResourceId,
342                                           resource_id_.value());
343     }
344     if (desired_session_id_.has_value()) {
345       size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kDesiredSessionId,
346                                           desired_session_id_.value());
347     }
348   }
349 
350   if (offset_ != 0) {
351     size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kOffset, offset_);
352   }
353 
354   if (window_end_offset_ != 0) {
355     size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kWindowEndOffset,
356                                         window_end_offset_);
357 
358     if (ShouldEncodeLegacyFields()) {
359       size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kPendingBytes,
360                                           window_end_offset_ - offset_);
361     }
362   }
363 
364   if (type_.has_value()) {
365     size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kType,
366                                         static_cast<uint32_t>(type_.value()));
367   }
368 
369   if (has_payload()) {
370     size += protobuf::SizeOfDelimitedField(ProtoChunk::Fields::kData,
371                                            payload_.size());
372   }
373 
374   if (max_chunk_size_bytes_.has_value()) {
375     size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kMaxChunkSizeBytes,
376                                         max_chunk_size_bytes_.value());
377   }
378 
379   if (min_delay_microseconds_.has_value()) {
380     size +=
381         protobuf::SizeOfVarintField(ProtoChunk::Fields::kMinDelayMicroseconds,
382                                     min_delay_microseconds_.value());
383   }
384 
385   if (remaining_bytes_.has_value()) {
386     size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kRemainingBytes,
387                                         remaining_bytes_.value());
388   }
389 
390   if (status_.has_value()) {
391     size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kStatus,
392                                         status_.value().code());
393   }
394 
395   return size;
396 }
397 
LogChunk(bool received,pw::chrono::SystemClock::duration rate_limit) const398 void Chunk::LogChunk(bool received,
399                      pw::chrono::SystemClock::duration rate_limit) const {
400   // Log in two different spots so the rate limiting applies separately to sent
401   // and received
402   if (received) {
403     PW_LOG_EVERY_N_DURATION(
404         PW_LOG_LEVEL_DEBUG,
405         rate_limit,
406         "Chunk received, type: %u, session id: %u, protocol version: %u,\n"
407         "resource id: %d, desired session id: %d, offset: %u, size: %u,\n"
408         "window end offset: %u, remaining bytes: %d, status: %d",
409         type_.has_value() ? static_cast<unsigned>(type_.value()) : 0,
410         static_cast<unsigned>(session_id_),
411         static_cast<unsigned>(protocol_version_),
412         resource_id_.has_value() ? static_cast<unsigned>(resource_id_.value())
413                                  : -1,
414         desired_session_id_.has_value()
415             ? static_cast<int>(desired_session_id_.value())
416             : -1,
417         static_cast<unsigned>(offset_),
418         has_payload() ? static_cast<unsigned>(payload_.size()) : 0,
419         static_cast<unsigned>(window_end_offset_),
420         remaining_bytes_.has_value()
421             ? static_cast<unsigned>(remaining_bytes_.value())
422             : -1,
423         status_.has_value() ? static_cast<unsigned>(status_.value().code())
424                             : -1);
425   } else {
426     PW_LOG_EVERY_N_DURATION(
427         PW_LOG_LEVEL_DEBUG,
428         rate_limit,
429         "Chunk sent, type: %u, session id: %u, protocol version: %u,\n"
430         "resource id: %d, desired session id: %d, offset: %u, size: %u,\n"
431         "window end offset: %u, remaining bytes: %d, status: %d",
432         type_.has_value() ? static_cast<unsigned>(type_.value()) : 0,
433         static_cast<unsigned>(session_id_),
434         static_cast<unsigned>(protocol_version_),
435         resource_id_.has_value() ? static_cast<unsigned>(resource_id_.value())
436                                  : -1,
437         desired_session_id_.has_value()
438             ? static_cast<int>(desired_session_id_.value())
439             : -1,
440         static_cast<unsigned>(offset_),
441         has_payload() ? static_cast<unsigned>(payload_.size()) : 0,
442         static_cast<unsigned>(window_end_offset_),
443         remaining_bytes_.has_value()
444             ? static_cast<unsigned>(remaining_bytes_.value())
445             : -1,
446         status_.has_value() ? static_cast<unsigned>(status_.value().code())
447                             : -1);
448   }
449 }
450 
451 }  // namespace pw::transfer::internal
452