1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #define PW_LOG_MODULE_NAME "TRN"
16 #define PW_LOG_LEVEL PW_TRANSFER_CONFIG_LOG_LEVEL
17
18 #include "pw_transfer/internal/chunk.h"
19
20 #include "pw_assert/check.h"
21 #include "pw_log/log.h"
22 #include "pw_log/rate_limited.h"
23 #include "pw_protobuf/decoder.h"
24 #include "pw_protobuf/serialized_size.h"
25 #include "pw_status/try.h"
26 #include "pw_transfer/internal/config.h"
27
28 namespace pw::transfer::internal {
29
30 namespace ProtoChunk = transfer::pwpb::Chunk;
31
ExtractIdentifier(ConstByteSpan message)32 Result<Chunk::Identifier> Chunk::ExtractIdentifier(ConstByteSpan message) {
33 protobuf::Decoder decoder(message);
34
35 uint32_t session_id = 0;
36 uint32_t desired_session_id = 0;
37 bool legacy = true;
38
39 Status status;
40
41 while ((status = decoder.Next()).ok()) {
42 ProtoChunk::Fields field =
43 static_cast<ProtoChunk::Fields>(decoder.FieldNumber());
44
45 if (field == ProtoChunk::Fields::kTransferId) {
46 // Interpret a legacy transfer_id field as a session ID if an explicit
47 // session_id field has not already been seen.
48 if (session_id == 0) {
49 PW_TRY(decoder.ReadUint32(&session_id));
50 }
51 } else if (field == ProtoChunk::Fields::kSessionId) {
52 // A session_id field always takes precedence over transfer_id.
53 PW_TRY(decoder.ReadUint32(&session_id));
54 legacy = false;
55 } else if (field == ProtoChunk::Fields::kDesiredSessionId) {
56 PW_TRY(decoder.ReadUint32(&desired_session_id));
57 }
58 }
59
60 if (!status.IsOutOfRange()) {
61 return Status::DataLoss();
62 }
63
64 if (desired_session_id != 0) {
65 // Can't have both a desired and regular session_id.
66 if (!legacy && session_id != 0) {
67 return Status::DataLoss();
68 }
69 return Identifier::Desired(desired_session_id);
70 }
71
72 if (session_id != 0) {
73 return legacy ? Identifier::Legacy(session_id)
74 : Identifier::Session(session_id);
75 }
76
77 return Status::DataLoss();
78 }
79
Parse(ConstByteSpan message)80 Result<Chunk> Chunk::Parse(ConstByteSpan message) {
81 protobuf::Decoder decoder(message);
82 Status status;
83 uint32_t value;
84
85 Chunk chunk;
86
87 // Determine the protocol version of the chunk depending on field presence in
88 // the serialized message.
89 chunk.protocol_version_ = ProtocolVersion::kUnknown;
90
91 // Some older versions of the protocol set the deprecated pending_bytes field
92 // in their chunks. The newer transfer handling code does not process this
93 // field, instead working only in terms of window_end_offset. If pending_bytes
94 // is encountered in the serialized message, save its value, then calculate
95 // window_end_offset from it once parsing is complete.
96 uint32_t pending_bytes = 0;
97
98 bool has_session_id = false;
99
100 while ((status = decoder.Next()).ok()) {
101 ProtoChunk::Fields field =
102 static_cast<ProtoChunk::Fields>(decoder.FieldNumber());
103
104 switch (field) {
105 case ProtoChunk::Fields::kTransferId:
106 // transfer_id is a legacy field. session_id will always take precedence
107 // over it, so it should only be read if session_id has not yet been
108 // encountered.
109 if (chunk.session_id_ == 0) {
110 PW_TRY(decoder.ReadUint32(&chunk.session_id_));
111 }
112 break;
113
114 case ProtoChunk::Fields::kSessionId:
115 // The existence of a session_id field indicates that a newer protocol
116 // is running. Update the deduced protocol unless it was explicitly
117 // specified.
118 if (chunk.protocol_version_ == ProtocolVersion::kUnknown) {
119 chunk.protocol_version_ = ProtocolVersion::kVersionTwo;
120 }
121 has_session_id = true;
122 PW_TRY(decoder.ReadUint32(&chunk.session_id_));
123 break;
124
125 case ProtoChunk::Fields::kPendingBytes:
126 PW_TRY(decoder.ReadUint32(&pending_bytes));
127 break;
128
129 case ProtoChunk::Fields::kMaxChunkSizeBytes:
130 PW_TRY(decoder.ReadUint32(&value));
131 chunk.set_max_chunk_size_bytes(value);
132 break;
133
134 case ProtoChunk::Fields::kMinDelayMicroseconds:
135 PW_TRY(decoder.ReadUint32(&value));
136 chunk.set_min_delay_microseconds(value);
137 break;
138
139 case ProtoChunk::Fields::kOffset:
140 PW_TRY(decoder.ReadUint32(&chunk.offset_));
141 break;
142
143 case ProtoChunk::Fields::kData:
144 PW_TRY(decoder.ReadBytes(&chunk.payload_));
145 break;
146
147 case ProtoChunk::Fields::kRemainingBytes: {
148 uint64_t remaining_bytes;
149 PW_TRY(decoder.ReadUint64(&remaining_bytes));
150 chunk.set_remaining_bytes(remaining_bytes);
151 break;
152 }
153
154 case ProtoChunk::Fields::kStatus:
155 PW_TRY(decoder.ReadUint32(&value));
156 chunk.set_status(static_cast<Status::Code>(value));
157 break;
158
159 case ProtoChunk::Fields::kWindowEndOffset:
160 PW_TRY(decoder.ReadUint32(&chunk.window_end_offset_));
161 break;
162
163 case ProtoChunk::Fields::kType: {
164 uint32_t type;
165 PW_TRY(decoder.ReadUint32(&type));
166 chunk.type_ = static_cast<Chunk::Type>(type);
167 break;
168 }
169
170 case ProtoChunk::Fields::kResourceId:
171 PW_TRY(decoder.ReadUint32(&value));
172 chunk.set_resource_id(value);
173 break;
174
175 case ProtoChunk::Fields::kProtocolVersion:
176 // The protocol_version field is added as part of the initial handshake
177 // starting from version 2. If provided, it should override any deduced
178 // protocol version.
179 PW_TRY(decoder.ReadUint32(&value));
180 if (!ValidProtocolVersion(value)) {
181 return Status::DataLoss();
182 }
183 chunk.protocol_version_ = static_cast<ProtocolVersion>(value);
184 break;
185
186 case ProtoChunk::Fields::kDesiredSessionId:
187 PW_TRY(decoder.ReadUint32(&value));
188 chunk.desired_session_id_ = value;
189 break;
190
191 case ProtoChunk::Fields::kInitialOffset:
192 PW_TRY(decoder.ReadUint32(&value));
193 chunk.set_initial_offset(value);
194 break;
195
196 // Silently ignore any unrecognized fields.
197 }
198 }
199
200 if (chunk.desired_session_id_.has_value() && has_session_id) {
201 // Setting both session_id and desired_session_id is not permitted.
202 return Status::DataLoss();
203 }
204
205 if (chunk.protocol_version_ == ProtocolVersion::kUnknown) {
206 // If no fields in the chunk specified its protocol version, assume it is a
207 // legacy chunk.
208 chunk.protocol_version_ = ProtocolVersion::kLegacy;
209 }
210
211 if (pending_bytes != 0) {
212 // Compute window_end_offset if it isn't explicitly provided (in older
213 // protocol versions).
214 chunk.set_window_end_offset(chunk.offset() + pending_bytes);
215 }
216
217 if (status.ok() || status.IsOutOfRange()) {
218 return chunk;
219 }
220
221 return status;
222 }
223
Encode(ByteSpan buffer) const224 Result<ConstByteSpan> Chunk::Encode(ByteSpan buffer) const {
225 PW_CHECK(protocol_version_ != ProtocolVersion::kUnknown,
226 "Cannot encode a transfer chunk with an unknown protocol version");
227
228 ProtoChunk::MemoryEncoder encoder(buffer);
229
230 // Write the payload first to avoid clobbering it if it shares the same buffer
231 // as the encode buffer.
232 if (has_payload()) {
233 encoder.WriteData(payload_).IgnoreError();
234 }
235
236 if (protocol_version_ >= ProtocolVersion::kVersionTwo) {
237 if (session_id_ != 0) {
238 PW_CHECK(!desired_session_id_.has_value(),
239 "A chunk cannot set both a desired and regular session ID");
240 encoder.WriteSessionId(session_id_).IgnoreError();
241 }
242
243 if (desired_session_id_.has_value()) {
244 encoder.WriteDesiredSessionId(desired_session_id_.value()).IgnoreError();
245 }
246
247 if (resource_id_.has_value()) {
248 encoder.WriteResourceId(resource_id_.value()).IgnoreError();
249 }
250 }
251
252 // During the initial handshake, the chunk's configured protocol version is
253 // explicitly serialized to the wire.
254 if (IsInitialHandshakeChunk()) {
255 encoder.WriteProtocolVersion(static_cast<uint32_t>(protocol_version_))
256 .IgnoreError();
257 }
258
259 if (type_.has_value()) {
260 encoder.WriteType(static_cast<ProtoChunk::Type>(type_.value()))
261 .IgnoreError();
262 }
263
264 if (window_end_offset_ != 0) {
265 encoder.WriteWindowEndOffset(window_end_offset_).IgnoreError();
266 }
267
268 // Encode additional fields from the legacy protocol.
269 if (ShouldEncodeLegacyFields()) {
270 // The legacy protocol uses the transfer_id field instead of session_id or
271 // resource_id.
272 if (resource_id_.has_value()) {
273 encoder.WriteTransferId(resource_id_.value()).IgnoreError();
274 } else {
275 encoder.WriteTransferId(session_id_).IgnoreError();
276 }
277
278 // In the legacy protocol, the pending_bytes field must be set alongside
279 // window_end_offset, as some transfer implementations require it.
280 if (window_end_offset_ != 0) {
281 encoder.WritePendingBytes(window_end_offset_ - offset_).IgnoreError();
282 }
283 }
284
285 if (max_chunk_size_bytes_.has_value()) {
286 encoder.WriteMaxChunkSizeBytes(max_chunk_size_bytes_.value()).IgnoreError();
287 }
288 if (min_delay_microseconds_.has_value()) {
289 encoder.WriteMinDelayMicroseconds(min_delay_microseconds_.value())
290 .IgnoreError();
291 }
292
293 if (offset_ != 0) {
294 encoder.WriteOffset(offset_).IgnoreError();
295 }
296
297 if (initial_offset_ != 0) {
298 encoder.WriteInitialOffset(initial_offset_).IgnoreError();
299 }
300
301 if (remaining_bytes_.has_value()) {
302 encoder.WriteRemainingBytes(remaining_bytes_.value()).IgnoreError();
303 }
304
305 if (status_.has_value()) {
306 encoder.WriteStatus(status_.value().code()).IgnoreError();
307 }
308
309 PW_TRY(encoder.status());
310 return ConstByteSpan(encoder);
311 }
312
EncodedSize() const313 size_t Chunk::EncodedSize() const {
314 size_t size = 0;
315
316 if (session_id_ != 0) {
317 if (protocol_version_ >= ProtocolVersion::kVersionTwo) {
318 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kSessionId,
319 session_id_);
320 }
321
322 if (ShouldEncodeLegacyFields()) {
323 if (resource_id_.has_value()) {
324 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kTransferId,
325 resource_id_.value());
326 } else {
327 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kTransferId,
328 session_id_);
329 }
330 }
331 }
332
333 if (IsInitialHandshakeChunk()) {
334 size +=
335 protobuf::SizeOfVarintField(ProtoChunk::Fields::kProtocolVersion,
336 static_cast<uint32_t>(protocol_version_));
337 }
338
339 if (protocol_version_ >= ProtocolVersion::kVersionTwo) {
340 if (resource_id_.has_value()) {
341 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kResourceId,
342 resource_id_.value());
343 }
344 if (desired_session_id_.has_value()) {
345 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kDesiredSessionId,
346 desired_session_id_.value());
347 }
348 }
349
350 if (offset_ != 0) {
351 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kOffset, offset_);
352 }
353
354 if (window_end_offset_ != 0) {
355 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kWindowEndOffset,
356 window_end_offset_);
357
358 if (ShouldEncodeLegacyFields()) {
359 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kPendingBytes,
360 window_end_offset_ - offset_);
361 }
362 }
363
364 if (type_.has_value()) {
365 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kType,
366 static_cast<uint32_t>(type_.value()));
367 }
368
369 if (has_payload()) {
370 size += protobuf::SizeOfDelimitedField(ProtoChunk::Fields::kData,
371 payload_.size());
372 }
373
374 if (max_chunk_size_bytes_.has_value()) {
375 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kMaxChunkSizeBytes,
376 max_chunk_size_bytes_.value());
377 }
378
379 if (min_delay_microseconds_.has_value()) {
380 size +=
381 protobuf::SizeOfVarintField(ProtoChunk::Fields::kMinDelayMicroseconds,
382 min_delay_microseconds_.value());
383 }
384
385 if (remaining_bytes_.has_value()) {
386 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kRemainingBytes,
387 remaining_bytes_.value());
388 }
389
390 if (status_.has_value()) {
391 size += protobuf::SizeOfVarintField(ProtoChunk::Fields::kStatus,
392 status_.value().code());
393 }
394
395 return size;
396 }
397
LogChunk(bool received,pw::chrono::SystemClock::duration rate_limit) const398 void Chunk::LogChunk(bool received,
399 pw::chrono::SystemClock::duration rate_limit) const {
400 // Log in two different spots so the rate limiting applies separately to sent
401 // and received
402 if (received) {
403 PW_LOG_EVERY_N_DURATION(
404 PW_LOG_LEVEL_DEBUG,
405 rate_limit,
406 "Chunk received, type: %u, session id: %u, protocol version: %u,\n"
407 "resource id: %d, desired session id: %d, offset: %u, size: %u,\n"
408 "window end offset: %u, remaining bytes: %d, status: %d",
409 type_.has_value() ? static_cast<unsigned>(type_.value()) : 0,
410 static_cast<unsigned>(session_id_),
411 static_cast<unsigned>(protocol_version_),
412 resource_id_.has_value() ? static_cast<unsigned>(resource_id_.value())
413 : -1,
414 desired_session_id_.has_value()
415 ? static_cast<int>(desired_session_id_.value())
416 : -1,
417 static_cast<unsigned>(offset_),
418 has_payload() ? static_cast<unsigned>(payload_.size()) : 0,
419 static_cast<unsigned>(window_end_offset_),
420 remaining_bytes_.has_value()
421 ? static_cast<unsigned>(remaining_bytes_.value())
422 : -1,
423 status_.has_value() ? static_cast<unsigned>(status_.value().code())
424 : -1);
425 } else {
426 PW_LOG_EVERY_N_DURATION(
427 PW_LOG_LEVEL_DEBUG,
428 rate_limit,
429 "Chunk sent, type: %u, session id: %u, protocol version: %u,\n"
430 "resource id: %d, desired session id: %d, offset: %u, size: %u,\n"
431 "window end offset: %u, remaining bytes: %d, status: %d",
432 type_.has_value() ? static_cast<unsigned>(type_.value()) : 0,
433 static_cast<unsigned>(session_id_),
434 static_cast<unsigned>(protocol_version_),
435 resource_id_.has_value() ? static_cast<unsigned>(resource_id_.value())
436 : -1,
437 desired_session_id_.has_value()
438 ? static_cast<int>(desired_session_id_.value())
439 : -1,
440 static_cast<unsigned>(offset_),
441 has_payload() ? static_cast<unsigned>(payload_.size()) : 0,
442 static_cast<unsigned>(window_end_offset_),
443 remaining_bytes_.has_value()
444 ? static_cast<unsigned>(remaining_bytes_.value())
445 : -1,
446 status_.has_value() ? static_cast<unsigned>(status_.value().code())
447 : -1);
448 }
449 }
450
451 } // namespace pw::transfer::internal
452