1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define TLOG_TAG "apploader-cose"
18 
19 #include <apploader/cbor.h>
20 #include <apploader/cose.h>
21 #include <assert.h>
22 #include <inttypes.h>
23 #include <openssl/bn.h>
24 #include <openssl/crypto.h>
25 #include <openssl/ec.h>
26 #include <openssl/err.h>
27 #include <openssl/evp.h>
28 #include <openssl/rand.h>
29 #include <openssl/x509.h>
30 #include <stddef.h>
31 #include <trusty_log.h>
32 #include <array>
33 #include <optional>
34 #include <vector>
35 
/*
 * COSE_PRINT_ERROR(fmt, ...) reports an error unless gSilenceErrors is set.
 * Builds with __COSE_HOST__ defined print to stderr; all other builds log
 * through TLOGE.
 *
 * The body is wrapped in do { } while (0) so that the macro expands to exactly
 * one statement: it can be used as the unbraced body of an if/else without
 * capturing the following `else`. Every use must be terminated with `;`.
 */
#ifdef __COSE_HOST__
#define COSE_PRINT_ERROR(...)             \
    do {                                  \
        if (!gSilenceErrors) {            \
            fprintf(stderr, __VA_ARGS__); \
        }                                 \
    } while (0)
#else
#define COSE_PRINT_ERROR(...)   \
    do {                        \
        if (!gSilenceErrors) {  \
            TLOGE(__VA_ARGS__); \
        }                       \
    } while (0)
#endif
47 
/*
 * Signature algorithm selection. When APPLOADER_PACKAGE_SIGN_P384 is defined
 * the SHA-384 digest length and the secp384r1 curve NID are used; otherwise
 * the SHA-256 digest length and the prime256v1 curve NID are used.
 */
#if defined(APPLOADER_PACKAGE_SIGN_P384)
#define APPLOADER_DSA_LENGTH SHA384_DIGEST_LENGTH
#define APPLOADER_DSA_NID NID_secp384r1
#else
#define APPLOADER_DSA_LENGTH SHA256_DIGEST_LENGTH
#define APPLOADER_DSA_NID NID_X9_62_prime256v1
#endif

/*
 * AES-GCM cipher selection: AES-256-GCM when APPLOADER_PACKAGE_CIPHER_A256 is
 * defined, AES-128-GCM otherwise.
 */
#if defined(APPLOADER_PACKAGE_CIPHER_A256)
#define EVP_aes_trusty_gcm() EVP_aes_256_gcm()
#else
#define EVP_aes_trusty_gcm() EVP_aes_128_gcm()
#endif
61 
62 static bool gSilenceErrors = false;
63 
64 constexpr size_t kEcdsaValueSize = APPLOADER_DSA_LENGTH;
65 constexpr size_t kEcdsaSignatureSize = 2 * kEcdsaValueSize;
66 
coseSetSilenceErrors(bool value)67 bool coseSetSilenceErrors(bool value) {
68     bool old = gSilenceErrors;
69     gSilenceErrors = value;
70     return old;
71 }
72 
73 using BIGNUM_Ptr = std::unique_ptr<BIGNUM, std::function<void(BIGNUM*)>>;
74 using EC_KEY_Ptr = std::unique_ptr<EC_KEY, std::function<void(EC_KEY*)>>;
75 using ECDSA_SIG_Ptr =
76         std::unique_ptr<ECDSA_SIG, std::function<void(ECDSA_SIG*)>>;
77 using EVP_CIPHER_CTX_Ptr =
78         std::unique_ptr<EVP_CIPHER_CTX, std::function<void(EVP_CIPHER_CTX*)>>;
79 
80 using SHADigest = std::array<uint8_t, kEcdsaValueSize>;
81 
coseBuildToBeSigned(const std::span<const uint8_t> & encodedProtectedHeaders,const std::vector<uint8_t> & data)82 static std::vector<uint8_t> coseBuildToBeSigned(
83         const std::span<const uint8_t>& encodedProtectedHeaders,
84         const std::vector<uint8_t>& data) {
85     cbor::VectorCborEncoder enc;
86     enc.encodeArray([&](auto& enc) {
87         enc.encodeTstr("Signature1");
88         enc.encodeBstr(encodedProtectedHeaders);
89         // We currently don't support Externally Supplied Data (RFC 8152
90         // section 4.3) so external_aad is the empty bstr
91         enc.encodeEmptyBstr();
92         enc.encodeBstr(data);
93     });
94 
95     return enc.intoVec();
96 }
97 
getRandom(size_t numBytes)98 static std::optional<std::vector<uint8_t>> getRandom(size_t numBytes) {
99     std::vector<uint8_t> output;
100     output.resize(numBytes);
101     if (RAND_bytes(output.data(), numBytes) != 1) {
102         COSE_PRINT_ERROR("RAND_bytes: failed getting %zu random\n", numBytes);
103         return {};
104     }
105     return output;
106 }
107 
108 #ifdef APPLOADER_PACKAGE_SIGN_P384
sha(const std::vector<std::tuple<const void *,size_t>> & data_list)109 static SHADigest sha(
110         const std::vector<std::tuple<const void*, size_t>>& data_list) {
111     SHADigest ret;
112     SHA512_CTX ctx;  // Note that SHA384 functions use a SHA512 context
113 
114     SHA384_Init(&ctx);
115 
116     for (auto data : data_list) {
117         SHA384_Update(&ctx, std::get<0>(data), std::get<1>(data));
118     }
119 
120     SHA384_Final((unsigned char*)ret.data(), &ctx);
121 
122     return ret;
123 }
124 #else
sha(const std::vector<std::tuple<const void *,size_t>> & data_list)125 static SHADigest sha(
126         const std::vector<std::tuple<const void*, size_t>>& data_list) {
127     SHADigest ret;
128     SHA256_CTX ctx;
129 
130     SHA256_Init(&ctx);
131 
132     for (auto data : data_list) {
133         SHA256_Update(&ctx, std::get<0>(data), std::get<1>(data));
134     }
135 
136     SHA256_Final((unsigned char*)ret.data(), &ctx);
137 
138     return ret;
139 }
140 #endif
141 
sha(const std::vector<uint8_t> & data)142 static SHADigest sha(const std::vector<uint8_t>& data) {
143     return sha({{data.data(), data.size()}});
144 }
145 
signEcDsaDigest(const std::vector<uint8_t> & key,const SHADigest & dataDigest)146 static std::optional<std::vector<uint8_t>> signEcDsaDigest(
147         const std::vector<uint8_t>& key,
148         const SHADigest& dataDigest) {
149     const unsigned char* k = key.data();
150     auto ecKey =
151             EC_KEY_Ptr(d2i_ECPrivateKey(nullptr, &k, key.size()), EC_KEY_free);
152     if (!ecKey) {
153         COSE_PRINT_ERROR("Error parsing EC private key\n");
154         return {};
155     }
156 
157     if (EC_KEY_check_key(ecKey.get()) == 0) {
158         COSE_PRINT_ERROR("Error checking EC private key\n");
159         return {};
160     }
161 
162     const EC_GROUP* ecGroup = EC_KEY_get0_group(ecKey.get());
163     if (EC_GROUP_get_curve_name(ecGroup) != APPLOADER_DSA_NID) {
164         COSE_PRINT_ERROR("Error checking EC group (not secp384r1)\n");
165         return {};
166     }
167 
168     auto sig = ECDSA_SIG_Ptr(
169             ECDSA_do_sign(dataDigest.data(), dataDigest.size(), ecKey.get()),
170             ECDSA_SIG_free);
171     if (!sig) {
172         COSE_PRINT_ERROR("Error signing digest:\n");
173         return {};
174     }
175     size_t len = i2d_ECDSA_SIG(sig.get(), nullptr);
176     std::vector<uint8_t> signature;
177     signature.resize(len);
178     unsigned char* p = (unsigned char*)signature.data();
179     i2d_ECDSA_SIG(sig.get(), &p);
180     return signature;
181 }
182 
signEcDsa(const std::vector<uint8_t> & key,const std::vector<uint8_t> & data)183 static std::optional<std::vector<uint8_t>> signEcDsa(
184         const std::vector<uint8_t>& key,
185         const std::vector<uint8_t>& data) {
186     return signEcDsaDigest(key, sha(data));
187 }
188 
ecdsaSignatureDerToCose(const std::vector<uint8_t> & ecdsaDerSignature,std::vector<uint8_t> & ecdsaCoseSignature)189 static bool ecdsaSignatureDerToCose(
190         const std::vector<uint8_t>& ecdsaDerSignature,
191         std::vector<uint8_t>& ecdsaCoseSignature) {
192     const unsigned char* p = ecdsaDerSignature.data();
193     auto sig =
194             ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, ecdsaDerSignature.size()),
195                           ECDSA_SIG_free);
196     if (!sig) {
197         COSE_PRINT_ERROR("Error decoding DER signature\n");
198         return false;
199     }
200 
201     ecdsaCoseSignature.clear();
202     ecdsaCoseSignature.resize(kEcdsaSignatureSize);
203     if (!BN_bn2bin_padded(ecdsaCoseSignature.data(), kEcdsaValueSize,
204                           ECDSA_SIG_get0_r(sig.get()))) {
205         COSE_PRINT_ERROR("Error encoding r\n");
206         return false;
207     }
208     if (!BN_bn2bin_padded(ecdsaCoseSignature.data() + kEcdsaValueSize,
209                           kEcdsaValueSize, ECDSA_SIG_get0_s(sig.get()))) {
210         COSE_PRINT_ERROR("Error encoding s\n");
211         return false;
212     }
213     return true;
214 }
215 
coseSignEcDsa(const std::vector<uint8_t> & key,uint8_t keyId,const std::vector<uint8_t> & data,const std::span<const uint8_t> & encodedProtectedHeaders,std::span<const uint8_t> & unprotectedHeaders,bool detachContent,bool tagged)216 std::optional<std::vector<uint8_t>> coseSignEcDsa(
217         const std::vector<uint8_t>& key,
218         uint8_t keyId,
219         const std::vector<uint8_t>& data,
220         const std::span<const uint8_t>& encodedProtectedHeaders,
221         std::span<const uint8_t>& unprotectedHeaders,
222         bool detachContent,
223         bool tagged) {
224     cbor::VectorCborEncoder addnHeadersEnc;
225     addnHeadersEnc.encodeMap([&](auto& enc) {
226         enc.encodeKeyValue(COSE_LABEL_KID, [&](auto& enc) {
227             enc.encodeBstr(std::span(&keyId, 1));
228         });
229     });
230     auto updatedUnprotectedHeaders =
231             cbor::mergeMaps(unprotectedHeaders, addnHeadersEnc.view());
232     if (!updatedUnprotectedHeaders.has_value()) {
233         COSE_PRINT_ERROR("Error updating unprotected headers\n");
234         return {};
235     }
236 
237     std::vector<uint8_t> toBeSigned =
238             coseBuildToBeSigned(encodedProtectedHeaders, data);
239 
240     std::optional<std::vector<uint8_t>> derSignature =
241             signEcDsa(key, toBeSigned);
242     if (!derSignature) {
243         COSE_PRINT_ERROR("Error signing toBeSigned data\n");
244         return {};
245     }
246     std::vector<uint8_t> coseSignature;
247     if (!ecdsaSignatureDerToCose(derSignature.value(), coseSignature)) {
248         COSE_PRINT_ERROR(
249                 "Error converting ECDSA signature from DER to COSE format\n");
250         return {};
251     }
252 
253     auto arrayEncodingFn = [&](auto& enc) {
254         enc.encodeArray([&](auto& enc) {
255             /* 1: protected:empty_or_serialized_map */
256             enc.encodeBstr(encodedProtectedHeaders);
257 
258             /* 2: unprotected:map */
259             enc.copyBytes(updatedUnprotectedHeaders.value());
260 
261             /* 3: payload:bstr_or_nil */
262             if (detachContent) {
263                 enc.encodeNull();
264             } else {
265                 enc.encodeBstr(data);
266             }
267 
268             /* 4: signature:bstr */
269             enc.encodeBstr(coseSignature);
270         });
271     };
272 
273     cbor::VectorCborEncoder enc;
274     if (tagged) {
275         enc.encodeTag(COSE_TAG_SIGN1, arrayEncodingFn);
276     } else {
277         arrayEncodingFn(enc);
278     }
279 
280     return enc.intoVec();
281 }
282 
coseIsSigned(CoseByteView data,size_t * signatureLength)283 bool coseIsSigned(CoseByteView data, size_t* signatureLength) {
284     struct CborIn in;
285     uint64_t tag;
286 
287     CborInInit(data.data(), data.size(), &in);
288     while (!CborInAtEnd(&in)) {
289         if (CborReadTag(&in, &tag) == CBOR_READ_RESULT_OK) {
290             if (tag == COSE_TAG_SIGN1) {
291                 if (signatureLength) {
292                     /* read tag item to get its size */
293                     CborReadSkip(&in);
294                     *signatureLength = CborInOffset(&in);
295                 }
296                 return true;
297             }
298         } else if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
299             /*
300              * CborReadSkip uses a stack to track nested content so parsing can
301              * fail if nesting of CBOR items causes stack exhaustion. The COSE
302              * format does not cause stack exhaustion so the input must be bad.
303              */
304             return false;
305         }
306     }
307 
308     return false;
309 }
310 
checkEcDsaSignature(const SHADigest & digest,const uint8_t * signature,const uint8_t * publicKey,size_t publicKeySize)311 static bool checkEcDsaSignature(const SHADigest& digest,
312                                 const uint8_t* signature,
313                                 const uint8_t* publicKey,
314                                 size_t publicKeySize) {
315     auto rBn =
316             BIGNUM_Ptr(BN_bin2bn(signature, kEcdsaValueSize, nullptr), BN_free);
317     if (rBn.get() == nullptr) {
318         COSE_PRINT_ERROR("Error creating BIGNUM for r\n");
319         return false;
320     }
321 
322     auto sBn = BIGNUM_Ptr(
323             BN_bin2bn(signature + kEcdsaValueSize, kEcdsaValueSize, nullptr),
324             BN_free);
325     if (sBn.get() == nullptr) {
326         COSE_PRINT_ERROR("Error creating BIGNUM for s\n");
327         return false;
328     }
329 
330     auto sig = ECDSA_SIG_Ptr(ECDSA_SIG_new(), ECDSA_SIG_free);
331     if (!sig) {
332         COSE_PRINT_ERROR("Error allocating ECDSA_SIG\n");
333         return false;
334     }
335 
336     ECDSA_SIG_set0(sig.get(), rBn.release(), sBn.release());
337 
338     const unsigned char* k = publicKey;
339     auto ecKey =
340             EC_KEY_Ptr(d2i_EC_PUBKEY(nullptr, &k, publicKeySize), EC_KEY_free);
341     if (!ecKey) {
342         COSE_PRINT_ERROR("Error parsing EC public key\n");
343         return false;
344     }
345 
346     int rc = ECDSA_do_verify(digest.data(), digest.size(), sig.get(),
347                              ecKey.get());
348     if (rc != 1) {
349         COSE_PRINT_ERROR("Error verifying signature (rc=%d)\n", rc);
350         return false;
351     }
352 
353     return true;
354 }
355 
coseCheckEcDsaSignature(const std::vector<uint8_t> & signatureCoseSign1,const std::vector<uint8_t> & detachedContent,const std::vector<uint8_t> & publicKey)356 bool coseCheckEcDsaSignature(const std::vector<uint8_t>& signatureCoseSign1,
357                              const std::vector<uint8_t>& detachedContent,
358                              const std::vector<uint8_t>& publicKey) {
359     struct CborIn in;
360     CborInInit(signatureCoseSign1.data(), signatureCoseSign1.size(), &in);
361 
362     uint64_t tag;
363     /* COSE message tag is optional */
364     if (CborReadTag(&in, &tag) == CBOR_READ_RESULT_OK) {
365         if (tag != COSE_TAG_SIGN1) {
366             COSE_PRINT_ERROR("Passed-in COSE_Sign1 contained invalid tag\n");
367             return false;
368         }
369     }
370 
371     size_t arraySize;
372     if (CborReadArray(&in, &arraySize) != CBOR_READ_RESULT_OK) {
373         COSE_PRINT_ERROR("Value for COSE_Sign1 is not an array\n");
374         return false;
375     }
376 
377     if (arraySize != 4) {
378         COSE_PRINT_ERROR("Value for COSE_Sign1 is not an array of size 4\n");
379         return false;
380     }
381 
382     const uint8_t* encodedProtectedHeadersPtr;
383     size_t encodedProtectedHeadersSize;
384     if (CborReadBstr(&in, &encodedProtectedHeadersSize,
385                      &encodedProtectedHeadersPtr) != CBOR_READ_RESULT_OK) {
386         COSE_PRINT_ERROR("Value for encodedProtectedHeaders is not a bstr\n");
387         return false;
388     }
389     std::span encodedProtectedHeaders(encodedProtectedHeadersPtr,
390                                       encodedProtectedHeadersSize);
391 
392     size_t unprotectedHeadersSize;
393     if (CborReadMap(&in, &unprotectedHeadersSize) != CBOR_READ_RESULT_OK) {
394         COSE_PRINT_ERROR("Value for unprotectedHeaders is not a map\n");
395         return false;
396     }
397 
398     /* skip past unprotected headers by reading two items per map entry */
399     for (size_t item = 0; item < 2 * unprotectedHeadersSize; item++) {
400         if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
401             COSE_PRINT_ERROR("Passed-in COSE_Sign1 is not valid CBOR\n");
402             return false;
403         }
404     }
405 
406     const uint8_t* dataPtr;
407     size_t dataSize = 0;
408     if (CborReadBstr(&in, &dataSize, &dataPtr) != CBOR_READ_RESULT_OK) {
409         if (CborReadNull(&in) != CBOR_READ_RESULT_OK) {
410             COSE_PRINT_ERROR("Value for payload is not null or a bstr\n");
411             return false;
412         }
413     }
414     std::vector<uint8_t> data(dataPtr, dataPtr + dataSize);
415 
416     if (data.size() > 0 && detachedContent.size() > 0) {
417         COSE_PRINT_ERROR("data and detachedContent cannot both be non-empty\n");
418         return false;
419     }
420 
421     const uint8_t* coseSignatureData;
422     size_t coseSignatureSize;
423     if (CborReadBstr(&in, &coseSignatureSize, &coseSignatureData) !=
424         CBOR_READ_RESULT_OK) {
425         COSE_PRINT_ERROR("Value for signature is not a bstr\n");
426         return false;
427     }
428 
429     if (coseSignatureSize != kEcdsaSignatureSize) {
430         COSE_PRINT_ERROR("COSE signature length is %zu, expected %zu\n",
431                          coseSignatureSize, kEcdsaSignatureSize);
432         return false;
433     }
434 
435     // The last field is the payload, independently of how it's transported (RFC
436     // 8152 section 4.4). Since our API specifies only one of |data| and
437     // |detachedContent| can be non-empty, it's simply just the non-empty one.
438     auto& signaturePayload = data.size() > 0 ? data : detachedContent;
439 
440     std::vector<uint8_t> toBeSigned =
441             coseBuildToBeSigned(encodedProtectedHeaders, signaturePayload);
442     if (!checkEcDsaSignature(sha(toBeSigned), coseSignatureData,
443                              publicKey.data(), publicKey.size())) {
444         COSE_PRINT_ERROR("Signature check failed\n");
445         return false;
446     }
447 
448     return true;
449 }
450 
451 /*
452  * Strict signature verification code
453  */
/*
 * Exact bytes expected at the start of a strictly formatted COSE_Sign1
 * package, up to (but not including) the one-byte key id. In the comments
 * below the first byte of each item is shown as
 * (CBOR major type bits | additional information).
 */
static const uint8_t kSignatureHeader[] = {
        /* clang-format off */
    0xD2,   // tag:   0xC0 | 18 = COSE_TAG_SIGN1
    0x84,   // array: 0x80 | 4 items

    // Array item 1: protected headers, a bstr wrapping an encoded map
#ifdef APPLOADER_PACKAGE_SIGN_P384
    0x55,   // bstr:  0x40 | 21 bytes
#else
    0x54,   // bstr:  0x40 | 20 bytes
#endif

        0xA2,   // map: 0xA0 | 2 entries

        // Map entry 1: algorithm
        0x01,   // uint 1 = COSE_LABEL_ALG
#ifdef APPLOADER_PACKAGE_SIGN_P384
        0x38, 0x22,     // nint: 0x20 | 24 (1-byte argument follows) = 34
                        // -1 - 34 = -35 = COSE_ALG_ECDSA_384
#else
        0x26,           // nint: 0x20 | 6
                        // -1 - 6 = -7 = COSE_ALG_ECDSA_256
#endif

        // Map entry 2: Trusty-specific label -> [name, version]
        0x3A, 0x00, 0x01, 0x00, 0x00,
                // nint: 0x20 | 26 (4-byte argument follows) = 0x00010000
                // -1 - 65536 = -65537; expected to equal COSE_LABEL_TRUSTY
                // (confirm against its definition)
        0x82,   // array: 0x80 | 2 items
            0x69,   // tstr: 0x60 | 9 bytes
            0x54, 0x72, 0x75, 0x73, 0x74, 0x79,     // "Trusty"
            0x41, 0x70, 0x70,                       // "App"
            0x01,   // uint 1 = APPLOADER_SIGNATURE_FORMAT_VERSION_CURRENT

    // Array item 2: unprotected headers
    0xA1,   // map:  0xA0 | 1 entry
    0x04,   // uint 4 = COSE_LABEL_KID
    0x41,   // bstr: 0x40 | 1 byte
    /* The next octet in the package is the key id */

        /* clang-format on */
};
507 
/*
 * Exact bytes expected immediately after the key id octet: the null payload
 * (detached content) and the header of the signature bstr.
 */
static const uint8_t kSignatureHeaderPart2[] = {
        /* clang-format off */
    // Array item 3: payload
    0xF6,   // simple value: 0xE0 | 22 = null

    // Array item 4: signature
    0x58,   // bstr: 0x40 | 24 (1-byte length follows)
#ifdef APPLOADER_PACKAGE_SIGN_P384
    0x60    // length = 96
#else
    0x40    // length = 64
#endif
        /* clang-format on */
};
519 
/*
 * Fixed prefix of the Signature1 structure that is hashed during strict
 * verification: the array header, the "Signature1" context string, the
 * protected headers and the empty external_aad. The payload bstr (array
 * item 4) is hashed separately, right after these bytes. In the comments
 * below the first byte of each item is shown as
 * (CBOR major type bits | additional information).
 */
static const uint8_t kSignature1Header[] = {
        /* clang-format off */
    0x84,   // array: 0x80 | 4 items

    // Array item 1: context string
    0x6A,   // tstr: 0x60 | 10 bytes
        0x53, 0x69, 0x67, 0x6E, 0x61,   // "Signa"
        0x74, 0x75, 0x72, 0x65, 0x31,   // "ture1"

    // Array item 2: protected headers, a bstr wrapping an encoded map
#ifdef APPLOADER_PACKAGE_SIGN_P384
    0x55,   // bstr: 0x40 | 21 bytes
#else
    0x54,   // bstr: 0x40 | 20 bytes
#endif
        0xA2,   // map: 0xA0 | 2 entries

        // Map entry 1: algorithm
        0x01,   // uint 1 = COSE_LABEL_ALG
#ifdef APPLOADER_PACKAGE_SIGN_P384
        0x38, 0x22,     // nint: 0x20 | 24 (1-byte argument follows) = 34
                        // -1 - 34 = -35 = COSE_ALG_ECDSA_384
#else
        0x26,           // nint: 0x20 | 6
                        // -1 - 6 = -7 = COSE_ALG_ECDSA_256
#endif

        // Map entry 2: Trusty-specific label -> [name, version]
        0x3A, 0x00, 0x01, 0x00, 0x00,
                // nint: 0x20 | 26 (4-byte argument follows) = 0x00010000
                // -1 - 65536 = -65537; expected to equal COSE_LABEL_TRUSTY
                // (confirm against its definition)
        0x82,   // array: 0x80 | 2 items
                0x69,   // tstr: 0x60 | 9 bytes
                0x54, 0x72, 0x75, 0x73, 0x74, 0x79,     // "Trusty"
                0x41, 0x70, 0x70,                       // "App"
                0x01,   // uint 1
                        //    = APPLOADER_SIGNATURE_FORMAT_VERSION_CURRENT

    // Array item 3: external_aad
    0x40,   // bstr: 0x40 | 0 bytes (empty)

        /* clang-format on */
};
582 
583 /*
584  * Fixed offset constants
585  */
586 constexpr size_t kSignatureKeyIdOffset = sizeof(kSignatureHeader);
587 constexpr size_t kSignatureHeaderPart2Offset = kSignatureKeyIdOffset + 1;
588 constexpr size_t kSignatureOffset =
589         kSignatureHeaderPart2Offset + sizeof(kSignatureHeaderPart2);
590 constexpr size_t kPayloadOffset = kSignatureOffset + kEcdsaSignatureSize;
591 
strictCheckEcDsaSignature(const uint8_t * packageStart,size_t packageSize,GetKeyFn keyFn,const uint8_t ** outPackageStart,size_t * outPackageSize)592 bool strictCheckEcDsaSignature(const uint8_t* packageStart,
593                                size_t packageSize,
594                                GetKeyFn keyFn,
595                                const uint8_t** outPackageStart,
596                                size_t* outPackageSize) {
597     if (packageSize < kPayloadOffset) {
598         COSE_PRINT_ERROR("Passed-in COSE_Sign1 is not large enough\n");
599         return false;
600     }
601 
602     if (CRYPTO_memcmp(packageStart, kSignatureHeader,
603                       sizeof(kSignatureHeader))) {
604         COSE_PRINT_ERROR("Passed-in COSE_Sign1 is not valid CBOR\n");
605         return false;
606     }
607 
608     uint8_t kid = packageStart[kSignatureKeyIdOffset];
609     auto [publicKey, publicKeySize] = keyFn(kid);
610     if (!publicKey) {
611         COSE_PRINT_ERROR("Failed to retrieve public key\n");
612         return false;
613     }
614 
615     if (CRYPTO_memcmp(packageStart + kSignatureHeaderPart2Offset,
616                       kSignatureHeaderPart2, sizeof(kSignatureHeaderPart2))) {
617         COSE_PRINT_ERROR("Passed-in COSE_Sign1 is not valid CBOR\n");
618         return false;
619     }
620 
621     // The Signature1 structure encodes the payload as a bstr wrapping the
622     // actual contents (even if they already are CBOR), so we need to manually
623     // prepend a CBOR bstr header to the payload
624     constexpr size_t kMaxPayloadSizeHeaderSize = 9;
625     size_t payloadSize = packageSize - kPayloadOffset;
626     size_t payloadSizeHeaderSize = cbor::encodedSizeOf(payloadSize);
627     assert(payloadSizeHeaderSize <= kMaxPayloadSizeHeaderSize);
628 
629     uint8_t payloadSizeHeader[kMaxPayloadSizeHeaderSize];
630 
631     cbor::encodeBstrHeader(payloadSize, kMaxPayloadSizeHeaderSize,
632                            payloadSizeHeader);
633 
634     SHADigest digest = sha({{kSignature1Header, sizeof(kSignature1Header)},
635                             {payloadSizeHeader, payloadSizeHeaderSize},
636                             {packageStart + kPayloadOffset, payloadSize}});
637 
638     if (!checkEcDsaSignature(digest, packageStart + kSignatureOffset,
639                              publicKey.get(), publicKeySize)) {
640         COSE_PRINT_ERROR("Signature check failed\n");
641         return false;
642     }
643 
644     if (outPackageStart != nullptr) {
645         *outPackageStart = packageStart + kPayloadOffset;
646     }
647     if (outPackageSize != nullptr) {
648         *outPackageSize = payloadSize;
649     }
650     return true;
651 }
652 
coseBuildGcmAad(const std::string_view context,const std::span<const uint8_t> encodedProtectedHeaders,const std::span<const uint8_t> externalAad)653 static std::tuple<std::unique_ptr<uint8_t[]>, size_t> coseBuildGcmAad(
654         const std::string_view context,
655         const std::span<const uint8_t> encodedProtectedHeaders,
656         const std::span<const uint8_t> externalAad) {
657     cbor::ArrayCborEncoder enc;
658     enc.encodeArray([&](auto& enc) {
659         enc.encodeTstr(context);
660         enc.encodeBstr(encodedProtectedHeaders);
661         enc.encodeBstr(externalAad);
662     });
663 
664     return {enc.intoVec().arr(), enc.size()};
665 }
666 
encryptAesGcm(const std::vector<uint8_t> & key,const std::vector<uint8_t> & nonce,const CoseByteView & data,std::span<const uint8_t> additionalAuthenticatedData)667 static std::optional<std::vector<uint8_t>> encryptAesGcm(
668         const std::vector<uint8_t>& key,
669         const std::vector<uint8_t>& nonce,
670         const CoseByteView& data,
671         std::span<const uint8_t> additionalAuthenticatedData) {
672     if (key.size() != kAesGcmKeySize) {
673         COSE_PRINT_ERROR("key is not kAesGcmKeySize (%zu) bytes, got %zu\n",
674                          kAesGcmKeySize, key.size());
675         return {};
676     }
677     if (nonce.size() != kAesGcmIvSize) {
678         COSE_PRINT_ERROR("nonce is not kAesGcmIvSize bytes, got %zu\n",
679                          nonce.size());
680         return {};
681     }
682 
683     // The result is the ciphertext followed by the tag (kAesGcmTagSize bytes).
684     std::vector<uint8_t> encryptedData;
685     encryptedData.resize(data.size() + kAesGcmTagSize);
686     unsigned char* ciphertext = (unsigned char*)encryptedData.data();
687     unsigned char* tag = ciphertext + data.size();
688 
689     auto ctx = EVP_CIPHER_CTX_Ptr(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
690     if (ctx.get() == nullptr) {
691         COSE_PRINT_ERROR("EVP_CIPHER_CTX_new: failed, error 0x%lx\n",
692                          static_cast<unsigned long>(ERR_get_error()));
693         return {};
694     }
695 
696     if (EVP_EncryptInit_ex(ctx.get(), EVP_aes_trusty_gcm(), NULL, NULL, NULL) !=
697         1) {
698         COSE_PRINT_ERROR("EVP_EncryptInit_ex: failed, error 0x%lx\n",
699                          static_cast<unsigned long>(ERR_get_error()));
700         return {};
701     }
702 
703     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize,
704                             NULL) != 1) {
705         COSE_PRINT_ERROR(
706                 "EVP_CIPHER_CTX_ctrl: failed setting nonce length, "
707                 "error 0x%lx\n",
708                 static_cast<unsigned long>(ERR_get_error()));
709         return {};
710     }
711 
712     if (EVP_EncryptInit_ex(ctx.get(), NULL, NULL, key.data(), nonce.data()) !=
713         1) {
714         COSE_PRINT_ERROR("EVP_EncryptInit_ex: failed, error 0x%lx\n",
715                          static_cast<unsigned long>(ERR_get_error()));
716         return {};
717     }
718 
719     int numWritten;
720     if (additionalAuthenticatedData.size() > 0) {
721         if (EVP_EncryptUpdate(ctx.get(), NULL, &numWritten,
722                               additionalAuthenticatedData.data(),
723                               additionalAuthenticatedData.size()) != 1) {
724             fprintf(stderr,
725                     "EVP_EncryptUpdate: failed for "
726                     "additionalAuthenticatedData, error 0x%lx\n",
727                     static_cast<unsigned long>(ERR_get_error()));
728             return {};
729         }
730         /*
731          * std::span::size() should return an size_type==size_t
732          * value but older versions of libcxx return an index_type
733          * which is an alias of ptrdiff_t (a signed type).
734          * We cast the size explicitly to a size_t to cover both cases.
735          */
736         if (static_cast<size_t>(numWritten) !=
737             static_cast<size_t>(additionalAuthenticatedData.size())) {
738             fprintf(stderr,
739                     "EVP_EncryptUpdate: Unexpected outl=%d (expected %zu) "
740                     "for additionalAuthenticatedData\n",
741                     numWritten, additionalAuthenticatedData.size());
742             return {};
743         }
744     }
745 
746     if (data.size() > 0) {
747         if (EVP_EncryptUpdate(ctx.get(), ciphertext, &numWritten, data.data(),
748                               data.size()) != 1) {
749             COSE_PRINT_ERROR("EVP_EncryptUpdate: failed, error 0x%lx\n",
750                              static_cast<unsigned long>(ERR_get_error()));
751             return {};
752         }
753         if (static_cast<size_t>(numWritten) !=
754             static_cast<size_t>(data.size())) {
755             fprintf(stderr,
756                     "EVP_EncryptUpdate: Unexpected outl=%d (expected %zu)\n",
757                     numWritten, data.size());
758             ;
759             return {};
760         }
761     }
762 
763     if (EVP_EncryptFinal_ex(ctx.get(), ciphertext + numWritten, &numWritten) !=
764         1) {
765         COSE_PRINT_ERROR("EVP_EncryptFinal_ex: failed, error 0x%lx\n",
766                          static_cast<unsigned long>(ERR_get_error()));
767         return {};
768     }
769     if (numWritten != 0) {
770         COSE_PRINT_ERROR("EVP_EncryptFinal_ex: Unexpected non-zero outl=%d\n",
771                          numWritten);
772         return {};
773     }
774 
775     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_GET_TAG, kAesGcmTagSize,
776                             tag) != 1) {
777         COSE_PRINT_ERROR(
778                 "EVP_CIPHER_CTX_ctrl: failed getting tag, "
779                 "error 0x%lx\n",
780                 static_cast<unsigned long>(ERR_get_error()));
781         return {};
782     }
783 
784     return encryptedData;
785 }
786 
coseEncryptAesGcm(const std::string_view context,const std::vector<uint8_t> & key,const CoseByteView & data,const std::vector<uint8_t> & externalAad,const std::vector<uint8_t> & encodedProtectedHeaders,const CoseByteView & unprotectedHeaders,std::optional<std::vector<uint8_t>> recipients)787 static std::optional<std::vector<uint8_t>> coseEncryptAesGcm(
788         const std::string_view context,
789         const std::vector<uint8_t>& key,
790         const CoseByteView& data,
791         const std::vector<uint8_t>& externalAad,
792         const std::vector<uint8_t>& encodedProtectedHeaders,
793         const CoseByteView& unprotectedHeaders,
794         std::optional<std::vector<uint8_t>> recipients) {
795     std::optional<std::vector<uint8_t>> iv = getRandom(kAesGcmIvSize);
796     if (!iv) {
797         COSE_PRINT_ERROR("Error generating encryption IV\n");
798         return {};
799     }
800 
801     cbor::VectorCborEncoder ivEnc;
802     ivEnc.encodeMap([&](auto& enc) {
803         enc.encodeKeyValue(COSE_LABEL_IV,
804                            [&](auto& enc) { enc.encodeBstr(iv.value()); });
805     });
806 
807     auto finalUnprotectedHeaders =
808             cbor::mergeMaps(unprotectedHeaders, ivEnc.view());
809     if (!finalUnprotectedHeaders) {
810         COSE_PRINT_ERROR("Error updating unprotected headers with IV\n");
811         return {};
812     }
813 
814     std::span encodedProtectedHeadersView(encodedProtectedHeaders.data(),
815                                           encodedProtectedHeaders.size());
816     std::span externalAadView = externalAad;
817     auto [gcmAad, gcmAadSize] = coseBuildGcmAad(
818             context, encodedProtectedHeadersView, externalAadView);
819     std::span gcmAadView(gcmAad.get(), gcmAadSize);
820 
821     std::optional<std::vector<uint8_t>> ciphertext =
822             encryptAesGcm(key, iv.value(), data, gcmAadView);
823     if (!ciphertext) {
824         COSE_PRINT_ERROR("Error encrypting data\n");
825         return {};
826     }
827 
828     cbor::VectorCborEncoder enc;
829     enc.encodeArray([&](auto& enc) {
830         enc.encodeBstr(encodedProtectedHeaders);
831         enc.copyBytes(finalUnprotectedHeaders.value());
832         enc.encodeBstr(ciphertext.value());
833         if (recipients) {
834             enc.copyBytes(recipients.value());
835         }
836     });
837 
838     return enc.intoVec();
839 }
840 
coseEncryptAesGcmKeyWrap(const std::vector<uint8_t> & key,uint8_t keyId,const CoseByteView & data,const std::vector<uint8_t> & externalAad,const std::vector<uint8_t> & encodedProtectedHeaders,const CoseByteView & unprotectedHeaders,bool tagged)841 std::optional<std::vector<uint8_t>> coseEncryptAesGcmKeyWrap(
842         const std::vector<uint8_t>& key,
843         uint8_t keyId,
844         const CoseByteView& data,
845         const std::vector<uint8_t>& externalAad,
846         const std::vector<uint8_t>& encodedProtectedHeaders,
847         const CoseByteView& unprotectedHeaders,
848         bool tagged) {
849     /* Generate and encrypt the CEK */
850     std::optional<std::vector<uint8_t>> contentEncryptionKey =
851             getRandom(kAesGcmKeySize);
852     if (!contentEncryptionKey) {
853         COSE_PRINT_ERROR("Error generating encryption key\n");
854         return {};
855     }
856 
857     cbor::VectorCborEncoder coseKeyEnc;
858     coseKeyEnc.encodeMap([&](auto& enc) {
859         enc.encodeKeyValue(COSE_LABEL_KEY_KTY, COSE_KEY_TYPE_SYMMETRIC);
860         enc.encodeKeyValue(COSE_LABEL_KEY_ALG, COSE_VAL_CIPHER_ALG);
861         enc.encodeKeyValue(COSE_LABEL_KEY_SYMMETRIC_KEY, [&](auto& enc) {
862             enc.encodeBstr(contentEncryptionKey.value());
863         });
864     });
865     CoseByteView coseKeyByteView = coseKeyEnc.view();
866 
867     cbor::VectorCborEncoder keyUnprotectedHeadersEnc;
868     keyUnprotectedHeadersEnc.encodeMap([&](auto& enc) {
869         enc.encodeKeyValue(COSE_LABEL_KID, [&](auto& enc) {
870             enc.encodeBstr(std::span<const uint8_t>(&keyId, 1));
871         });
872     });
873     auto keyUnprotectedHeaders = keyUnprotectedHeadersEnc.view();
874 
875     cbor::VectorCborEncoder encodedProtectedHeadersForEncKey;
876     encodedProtectedHeadersForEncKey.encodeMap([&](auto& enc) {
877         enc.encodeKeyValue(COSE_LABEL_ALG, COSE_VAL_CIPHER_ALG);
878     });
879 
880     auto encContentEncryptionKey =
881             coseEncryptAesGcm(COSE_CONTEXT_ENC_RECIPIENT, key, coseKeyByteView,
882                               {}, encodedProtectedHeadersForEncKey.intoVec(),
883                               keyUnprotectedHeaders, {});
884     if (!encContentEncryptionKey.has_value()) {
885         COSE_PRINT_ERROR("Error wrapping encryption key\n");
886         return {};
887     }
888 
889     cbor::VectorCborEncoder recipientsEnc;
890     recipientsEnc.encodeArray(
891             [&](auto& enc) { enc.copyBytes(encContentEncryptionKey.value()); });
892     auto recipients = recipientsEnc.intoVec();
893 
894     auto coseEncrypt = coseEncryptAesGcm(
895             COSE_CONTEXT_ENCRYPT, std::move(contentEncryptionKey.value()), data,
896             externalAad, encodedProtectedHeaders, unprotectedHeaders,
897             std::move(recipients));
898     if (!coseEncrypt.has_value()) {
899         COSE_PRINT_ERROR("Error encrypting application package\n");
900         return {};
901     }
902 
903     if (tagged) {
904         cbor::VectorCborEncoder enc;
905         enc.encodeTag(COSE_TAG_ENCRYPT,
906                       [&](auto& enc) { enc.copyBytes(coseEncrypt.value()); });
907         return enc.intoVec();
908     } else {
909         return coseEncrypt;
910     }
911 }
912 
decryptAesGcmInPlace(std::span<const uint8_t> key,std::span<const uint8_t> nonce,uint8_t * encryptedData,size_t encryptedDataSize,std::span<const uint8_t> additionalAuthenticatedData,size_t * outPlaintextSize)913 static bool decryptAesGcmInPlace(
914         std::span<const uint8_t> key,
915         std::span<const uint8_t> nonce,
916         uint8_t* encryptedData,
917         size_t encryptedDataSize,
918         std::span<const uint8_t> additionalAuthenticatedData,
919         size_t* outPlaintextSize) {
920     assert(outPlaintextSize != nullptr);
921 
922     int ciphertextSize = int(encryptedDataSize) - kAesGcmTagSize;
923     if (ciphertextSize < 0) {
924         COSE_PRINT_ERROR("encryptedData too small\n");
925         return false;
926     }
927     if (key.size() != kAesGcmKeySize) {
928         COSE_PRINT_ERROR("key is not kAesGcmKeySize (%zu) bytes, got %zu\n",
929                          kAesGcmKeySize, key.size());
930         return {};
931     }
932     if (nonce.size() != kAesGcmIvSize) {
933         COSE_PRINT_ERROR("nonce is not kAesGcmIvSize bytes, got %zu\n",
934                          nonce.size());
935         return false;
936     }
937     unsigned char* ciphertext = encryptedData;
938     unsigned char* tag = ciphertext + ciphertextSize;
939 
940     /*
941      * Decrypt the data in place. OpenSSL and BoringSSL support this as long as
942      * the plaintext buffer completely overlaps the ciphertext.
943      */
944     unsigned char* plaintext = encryptedData;
945 
946     auto ctx = EVP_CIPHER_CTX_Ptr(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free);
947     if (ctx.get() == nullptr) {
948         COSE_PRINT_ERROR("EVP_CIPHER_CTX_new: failed, error 0x%lx\n",
949                          static_cast<unsigned long>(ERR_get_error()));
950         return false;
951     }
952 
953     if (EVP_DecryptInit_ex(ctx.get(), EVP_aes_trusty_gcm(), NULL, NULL, NULL) !=
954         1) {
955         COSE_PRINT_ERROR("EVP_DecryptInit_ex: failed, error 0x%lx\n",
956                          static_cast<unsigned long>(ERR_get_error()));
957         return false;
958     }
959 
960     if (EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_IVLEN, kAesGcmIvSize,
961                             NULL) != 1) {
962         COSE_PRINT_ERROR(
963                 "EVP_CIPHER_CTX_ctrl: failed setting nonce length, "
964                 "error 0x%lx\n",
965                 static_cast<unsigned long>(ERR_get_error()));
966         return false;
967     }
968 
969     if (EVP_DecryptInit_ex(ctx.get(), NULL, NULL, key.data(), nonce.data()) !=
970         1) {
971         COSE_PRINT_ERROR("EVP_DecryptInit_ex: failed, error 0x%lx\n",
972                          static_cast<unsigned long>(ERR_get_error()));
973         return false;
974     }
975 
976     int numWritten;
977     if (additionalAuthenticatedData.size() > 0) {
978         if (EVP_DecryptUpdate(ctx.get(), NULL, &numWritten,
979                               additionalAuthenticatedData.data(),
980                               additionalAuthenticatedData.size()) != 1) {
981             COSE_PRINT_ERROR(
982                     "EVP_DecryptUpdate: failed for "
983                     "additionalAuthenticatedData, error 0x%lx\n",
984                     static_cast<unsigned long>(ERR_get_error()));
985             return false;
986         }
987         if (static_cast<size_t>(numWritten) !=
988             static_cast<size_t>(additionalAuthenticatedData.size())) {
989             COSE_PRINT_ERROR(
990                     "EVP_DecryptUpdate: Unexpected outl=%d "
991                     "(expected %zd) for additionalAuthenticatedData\n",
992                     numWritten, additionalAuthenticatedData.size());
993             return false;
994         }
995     }
996 
997     if (EVP_DecryptUpdate(ctx.get(), plaintext, &numWritten, ciphertext,
998                           ciphertextSize) != 1) {
999         COSE_PRINT_ERROR("EVP_DecryptUpdate: failed, error 0x%lx\n",
1000                          static_cast<unsigned long>(ERR_get_error()));
1001         return false;
1002     }
1003     if (numWritten != ciphertextSize) {
1004         COSE_PRINT_ERROR(
1005                 "EVP_DecryptUpdate: Unexpected outl=%d "
1006                 "(expected %d)\n",
1007                 numWritten, ciphertextSize);
1008         return false;
1009     }
1010 
1011     if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_GCM_SET_TAG, kAesGcmTagSize,
1012                              tag)) {
1013         COSE_PRINT_ERROR(
1014                 "EVP_CIPHER_CTX_ctrl: failed setting expected tag, "
1015                 "error 0x%lx\n",
1016                 static_cast<unsigned long>(ERR_get_error()));
1017         return false;
1018     }
1019 
1020     int ret =
1021             EVP_DecryptFinal_ex(ctx.get(), plaintext + numWritten, &numWritten);
1022     if (ret != 1) {
1023         COSE_PRINT_ERROR("EVP_DecryptFinal_ex: failed, error 0x%lx\n",
1024                          static_cast<unsigned long>(ERR_get_error()));
1025         return false;
1026     }
1027     if (numWritten != 0) {
1028         COSE_PRINT_ERROR("EVP_DecryptFinal_ex: Unexpected non-zero outl=%d\n",
1029                          numWritten);
1030         return false;
1031     }
1032 
1033     *outPlaintextSize = ciphertextSize;
1034     return true;
1035 }
1036 
coseDecryptAesGcmInPlace(const std::string_view context,const CoseByteView & item,const std::span<const uint8_t> key,const std::vector<uint8_t> & externalAad,const uint8_t ** outPlaintextStart,size_t * outPlaintextSize,DecryptFn keyDecryptFn)1037 static bool coseDecryptAesGcmInPlace(const std::string_view context,
1038                                      const CoseByteView& item,
1039                                      const std::span<const uint8_t> key,
1040                                      const std::vector<uint8_t>& externalAad,
1041                                      const uint8_t** outPlaintextStart,
1042                                      size_t* outPlaintextSize,
1043                                      DecryptFn keyDecryptFn) {
1044     assert(outPlaintextStart != nullptr);
1045     assert(outPlaintextSize != nullptr);
1046 
1047     struct CborIn in;
1048     CborInInit(item.data(), item.size(), &in);
1049 
1050     size_t num_elements;
1051     if (CborReadArray(&in, &num_elements) != CBOR_READ_RESULT_OK) {
1052         COSE_PRINT_ERROR("Encrypted data is not a CBOR array\n");
1053         return false;
1054     }
1055 
1056     if (num_elements < 3 || num_elements > 4) {
1057         COSE_PRINT_ERROR("Invalid COSE encryption array size, got %zu\n",
1058                          num_elements);
1059         return false;
1060     }
1061 
1062     const uint8_t* enc_protected_headers_data;
1063     size_t enc_protected_headers_size;
1064     if (CborReadBstr(&in, &enc_protected_headers_size,
1065                      &enc_protected_headers_data) != CBOR_READ_RESULT_OK) {
1066         COSE_PRINT_ERROR(
1067                 "Failed to retrieve protected headers "
1068                 "from COSE encryption structure\n");
1069         return false;
1070     }
1071 
1072     struct CborIn protHdrIn;
1073     CborInInit(enc_protected_headers_data, enc_protected_headers_size,
1074                &protHdrIn);
1075 
1076     size_t numPairs;
1077     if (CborReadMap(&protHdrIn, &numPairs) != CBOR_READ_RESULT_OK) {
1078         COSE_PRINT_ERROR("Invalid protected headers CBOR type\n");
1079         return false;
1080     }
1081 
1082     int64_t label;
1083     std::optional<uint64_t> alg;
1084     for (size_t i = 0; i < numPairs; i++) {
1085         // Read key
1086         if (CborReadInt(&protHdrIn, &label) != CBOR_READ_RESULT_OK) {
1087             COSE_PRINT_ERROR(
1088                     "Failed to read protected headers "
1089                     "in COSE encryption structure\n");
1090             return false;
1091         }
1092 
1093         // Read value
1094         if (label == COSE_LABEL_ALG) {
1095             uint64_t algVal;
1096             if (CborReadUint(&protHdrIn, &algVal) != CBOR_READ_RESULT_OK) {
1097                 COSE_PRINT_ERROR(
1098                         "Wrong CBOR type for alg value in unprotected headers\n");
1099                 return false;
1100             }
1101 
1102             if (algVal != COSE_VAL_CIPHER_ALG) {
1103                 COSE_PRINT_ERROR("Invalid COSE algorithm, got %" PRId64 "\n",
1104                                  algVal);
1105                 return false;
1106             }
1107 
1108             alg = algVal;
1109         } else if (CborReadSkip(&protHdrIn) != CBOR_READ_RESULT_OK) {
1110             COSE_PRINT_ERROR(
1111                     "Failed to read protected headers "
1112                     "in COSE encryption structure\n");
1113             return false;
1114         }
1115     }
1116 
1117     if (CborReadMap(&in, &numPairs) != CBOR_READ_RESULT_OK) {
1118         COSE_PRINT_ERROR(
1119                 "Failed to retrieve unprotected headers "
1120                 "from COSE encryption structure\n");
1121         return false;
1122     }
1123 
1124     const uint8_t* ivData = nullptr;
1125     size_t ivSize;
1126     for (size_t i = 0; i < numPairs; i++) {
1127         // Read key
1128         if (CborReadInt(&in, &label) != CBOR_READ_RESULT_OK) {
1129             COSE_PRINT_ERROR(
1130                     "Failed to read unprotected headers "
1131                     "in COSE encryption structure\n");
1132             return false;
1133         }
1134 
1135         // Read value
1136         if (label == COSE_LABEL_IV) {
1137             if (CborReadBstr(&in, &ivSize, &ivData) != CBOR_READ_RESULT_OK) {
1138                 COSE_PRINT_ERROR(
1139                         "Wrong CBOR type for IV value in unprotected headers\n");
1140                 return false;
1141             }
1142         } else if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
1143             COSE_PRINT_ERROR(
1144                     "Failed to read unprotected headers "
1145                     "in COSE encryption structure\n");
1146             return false;
1147         }
1148     }
1149 
1150     if (ivData == nullptr) {
1151         COSE_PRINT_ERROR("Missing IV field in COSE encryption structure\n");
1152         return false;
1153     }
1154 
1155     const uint8_t* ciphertextData;
1156     size_t ciphertextSize;
1157     if (CborReadBstr(&in, &ciphertextSize, &ciphertextData) !=
1158         CBOR_READ_RESULT_OK) {
1159         COSE_PRINT_ERROR(
1160                 "Failed to retrieve ciphertext "
1161                 "from COSE encryption structure\n");
1162         return false;
1163     }
1164 
1165     std::span externalAadView = externalAad;
1166     std::span encodedProtectedHeaders(enc_protected_headers_data,
1167                                       enc_protected_headers_size);
1168     auto [gcmAad, gcmAadSize] =
1169             coseBuildGcmAad(context, encodedProtectedHeaders, externalAadView);
1170 
1171     std::span gcmAadView(gcmAad.get(), gcmAadSize);
1172     std::span ivView(ivData, ivSize);
1173     if (!keyDecryptFn(key, ivView, const_cast<uint8_t*>(ciphertextData),
1174                       ciphertextSize, gcmAadView, outPlaintextSize)) {
1175         return false;
1176     }
1177 
1178     *outPlaintextStart = ciphertextData;
1179 
1180     return true;
1181 }
1182 
coseDecryptAesGcmKeyWrapInPlace(const CoseByteView & cose_encrypt,GetKeyFn keyFn,const std::vector<uint8_t> & externalAad,bool checkTag,const uint8_t ** outPackageStart,size_t * outPackageSize,DecryptFn keyDecryptFn)1183 bool coseDecryptAesGcmKeyWrapInPlace(const CoseByteView& cose_encrypt,
1184                                      GetKeyFn keyFn,
1185                                      const std::vector<uint8_t>& externalAad,
1186                                      bool checkTag,
1187                                      const uint8_t** outPackageStart,
1188                                      size_t* outPackageSize,
1189                                      DecryptFn keyDecryptFn) {
1190     assert(outPackageStart != nullptr);
1191     assert(outPackageSize != nullptr);
1192 
1193     if (!keyDecryptFn) {
1194         keyDecryptFn = &decryptAesGcmInPlace;
1195     }
1196 
1197     struct CborIn in;
1198     CborInInit(cose_encrypt.data(), cose_encrypt.size(), &in);
1199 
1200     uint64_t tag;
1201     if (CborReadTag(&in, &tag) == CBOR_READ_RESULT_OK) {
1202         if (checkTag && tag != COSE_TAG_ENCRYPT) {
1203             TLOGE("Invalid COSE_Encrypt semantic tag: %" PRIu64 "\n", tag);
1204             return false;
1205         }
1206     } else if (checkTag) {
1207         TLOGE("Expected COSE_Encrypt semantic tag\n");
1208         return false;
1209     }
1210 
1211     size_t num_elements;
1212     if (CborReadArray(&in, &num_elements) != CBOR_READ_RESULT_OK) {
1213         COSE_PRINT_ERROR("Encrypted data is not a CBOR array\n");
1214         return false;
1215     }
1216 
1217     if (num_elements != kCoseEncryptArrayElements) {
1218         COSE_PRINT_ERROR("Invalid COSE_Encrypt array size, got %zu\n",
1219                          num_elements);
1220         return false;
1221     }
1222 
1223     // Skip past the first three array elemements
1224     while (num_elements-- > 1) {
1225         if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
1226             COSE_PRINT_ERROR(
1227                     "Failed to retrieve recipients "
1228                     "from COSE_Encrypt structure\n");
1229             return false;
1230         }
1231     }
1232 
1233     // Read recipients array
1234     if (CborReadArray(&in, &num_elements) != CBOR_READ_RESULT_OK) {
1235         COSE_PRINT_ERROR(
1236                 "Failed to retrieve recipients "
1237                 "from COSE_Encrypt structure\n");
1238         return false;
1239     }
1240 
1241     if (num_elements != 1) {
1242         COSE_PRINT_ERROR("Invalid recipients array size, got %zu\n",
1243                          num_elements);
1244         return false;
1245     }
1246 
1247     const size_t recipientOffset = CborInOffset(&in);
1248     // Read singleton recipient
1249     if (CborReadArray(&in, &num_elements) != CBOR_READ_RESULT_OK) {
1250         COSE_PRINT_ERROR("COSE_Recipient is not a CBOR array\n");
1251         return false;
1252     }
1253 
1254     if (num_elements != 3) {
1255         COSE_PRINT_ERROR(
1256                 "Invalid COSE_Recipient structure array size, "
1257                 "got %zu\n",
1258                 num_elements);
1259         return false;
1260     }
1261 
1262     // Skip to unprotected headers array element
1263     if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
1264         COSE_PRINT_ERROR("Failed to read COSE_Recipient structure\n");
1265         return false;
1266     }
1267 
1268     size_t numPairs;
1269     if (CborReadMap(&in, &numPairs) != CBOR_READ_RESULT_OK) {
1270         COSE_PRINT_ERROR(
1271                 "Failed to retrieve unprotected headers "
1272                 "from COSE_Recipient structure\n");
1273         return false;
1274     }
1275 
1276     uint64_t label;
1277     const uint8_t* keyIdBytes = nullptr;
1278     size_t keyIdSize;
1279     for (size_t i = 0; i < numPairs; i++) {
1280         // Read key
1281         if (CborReadUint(&in, &label) != CBOR_READ_RESULT_OK) {
1282             COSE_PRINT_ERROR(
1283                     "Failed to read unprotected headers "
1284                     "in COSE_Recipient structure\n");
1285             return false;
1286         }
1287 
1288         // Read value
1289         if (label == COSE_LABEL_KID) {
1290             if (CborReadBstr(&in, &keyIdSize, &keyIdBytes) !=
1291                 CBOR_READ_RESULT_OK) {
1292                 COSE_PRINT_ERROR(
1293                         "Failed to extract key id from unprotected headers "
1294                         "in COSE_Recipient structure\n");
1295                 return false;
1296             }
1297         } else if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
1298             COSE_PRINT_ERROR(
1299                     "Failed to read unprotected headers "
1300                     "in COSE_Recipient structure\n");
1301             return false;
1302         }
1303     }
1304 
1305     // Skip over ciphertext
1306     if (CborReadSkip(&in) != CBOR_READ_RESULT_OK) {
1307         COSE_PRINT_ERROR("Failed to read COSE_Recipient structure\n");
1308         return false;
1309     }
1310 
1311     if (!CborInAtEnd(&in)) {
1312         COSE_PRINT_ERROR("Failed to read COSE_Recipient structure\n");
1313         return false;
1314     }
1315 
1316     CoseByteView recipient(cose_encrypt.data() + recipientOffset,
1317                            CborInOffset(&in) - recipientOffset);
1318 
1319     if (keyIdBytes == nullptr) {
1320         COSE_PRINT_ERROR("Missing key id field in COSE_Recipient\n");
1321         return false;
1322     }
1323 
1324     if (keyIdSize != 1) {
1325         COSE_PRINT_ERROR("Invalid key id field length, got %zu\n", keyIdSize);
1326         return false;
1327     }
1328 
1329     auto [keyEncryptionKeyStart, keyEncryptionKeySize] = keyFn(keyIdBytes[0]);
1330     if (!keyEncryptionKeyStart) {
1331         COSE_PRINT_ERROR("Failed to retrieve decryption key\n");
1332         return false;
1333     }
1334 
1335     std::span keyEncryptionKey(keyEncryptionKeyStart.get(),
1336                                keyEncryptionKeySize);
1337 
1338     const uint8_t* coseKeyStart;
1339     size_t coseKeySize;
1340     if (!coseDecryptAesGcmInPlace(COSE_CONTEXT_ENC_RECIPIENT, recipient,
1341                                   keyEncryptionKey, {}, &coseKeyStart,
1342                                   &coseKeySize, keyDecryptFn)) {
1343         COSE_PRINT_ERROR("Failed to decrypt COSE_Key structure\n");
1344         return false;
1345     }
1346 
1347     CborInInit(coseKeyStart, coseKeySize, &in);
1348     if (CborReadMap(&in, &numPairs) != CBOR_READ_RESULT_OK) {
1349         COSE_PRINT_ERROR("COSE_Key structure is not a map\n");
1350         return false;
1351     }
1352 
1353     int64_t keyLabel;
1354     int64_t value;
1355     bool ktyValidated = false;
1356     bool algValidated = false;
1357     const uint8_t* contentEncryptionKeyStart = nullptr;
1358     size_t contentEncryptionKeySize = 0;
1359     for (size_t i = 0; i < numPairs; i++) {
1360         if (CborReadInt(&in, &keyLabel) != CBOR_READ_RESULT_OK) {
1361             COSE_PRINT_ERROR("Failed to parse key in COSE_Key structure\n");
1362             return false;
1363         }
1364 
1365         switch (keyLabel) {
1366         case COSE_LABEL_KEY_KTY:
1367             if (CborReadInt(&in, &value) != CBOR_READ_RESULT_OK) {
1368                 COSE_PRINT_ERROR("Wrong CBOR type for kty field of COSE_Key\n");
1369                 return false;
1370             }
1371             if (value != COSE_KEY_TYPE_SYMMETRIC) {
1372                 COSE_PRINT_ERROR("Invalid COSE_Key key type: %" PRId64 "\n",
1373                                  value);
1374                 return false;
1375             }
1376             ktyValidated = true;
1377             break;
1378         case COSE_LABEL_KEY_ALG:
1379             if (CborReadInt(&in, &value) != CBOR_READ_RESULT_OK) {
1380                 COSE_PRINT_ERROR("Wrong CBOR type for kty field of COSE_Key\n");
1381                 return false;
1382             }
1383             if (value != COSE_VAL_CIPHER_ALG) {
1384                 COSE_PRINT_ERROR("Invalid COSE_Key algorithm value: %" PRId64
1385                                  "\n",
1386                                  value);
1387                 return false;
1388             }
1389             algValidated = true;
1390             break;
1391         case COSE_LABEL_KEY_SYMMETRIC_KEY:
1392             if (CborReadBstr(&in, &contentEncryptionKeySize,
1393                              &contentEncryptionKeyStart)) {
1394                 COSE_PRINT_ERROR("Wrong CBOR type for key field of COSE_Key\n");
1395                 return false;
1396             }
1397             if (contentEncryptionKeySize != kAesGcmKeySize) {
1398                 COSE_PRINT_ERROR(
1399                         "Invalid content encryption key size, got %zu\n",
1400                         contentEncryptionKeySize);
1401                 return false;
1402             }
1403             break;
1404         default:
1405             COSE_PRINT_ERROR("Invalid key field in COSE_Key: %" PRId64 "\n",
1406                              label);
1407             return false;
1408             break;
1409         }
1410     }
1411 
1412     if (!ktyValidated) {
1413         COSE_PRINT_ERROR("Missing kty field of COSE_Key\n");
1414         return false;
1415     } else if (!algValidated) {
1416         COSE_PRINT_ERROR("Missing alg field of COSE_Key\n");
1417         return false;
1418     } else if (!contentEncryptionKeyStart) {
1419         COSE_PRINT_ERROR("Missing key field in COSE_Key\n");
1420         return false;
1421     }
1422 
1423     const CoseByteView contentEncryptionKey(contentEncryptionKeyStart,
1424                                             contentEncryptionKeySize);
1425     if (!coseDecryptAesGcmInPlace(COSE_CONTEXT_ENCRYPT, cose_encrypt,
1426                                   contentEncryptionKey, externalAad,
1427                                   outPackageStart, outPackageSize,
1428                                   decryptAesGcmInPlace)) {
1429         COSE_PRINT_ERROR("Failed to decrypt payload\n");
1430         return false;
1431     }
1432 
1433     return true;
1434 }
1435 
/*
 * Human-readable description of the package cipher chosen at build time:
 * the 256-bit variant when APPLOADER_PACKAGE_CIPHER_A256 is defined, the
 * 128-bit variant otherwise.
 */
const char* coseGetCipherAlg(void) {
#if defined(APPLOADER_PACKAGE_CIPHER_A256)
    static constexpr char kCipherDescription[] =
            "AES-GCM with 256-bit key, 128-bit tag";
#else
    static constexpr char kCipherDescription[] =
            "AES-GCM with 128-bit key, 128-bit tag";
#endif
    return kCipherDescription;
}
1443 
/*
 * Human-readable description of the package signature scheme chosen at build
 * time: P-384 with SHA-384 when APPLOADER_PACKAGE_SIGN_P384 is defined,
 * P-256 with SHA-256 otherwise.
 */
const char* coseGetSigningDsa(void) {
#ifdef APPLOADER_PACKAGE_SIGN_P384
    return "ECDSA P-384 + SHA-384 signatures";
#else
    return "ECDSA P-256 + SHA-256 signatures";
#endif
}
1451