1// Copyright 2024 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package hpke 6 7import ( 8 "crypto" 9 "crypto/aes" 10 "crypto/cipher" 11 "crypto/ecdh" 12 "crypto/rand" 13 "encoding/binary" 14 "errors" 15 "math/bits" 16 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/crypto/hkdf" 19) 20 21// testingOnlyGenerateKey is only used during testing, to provide 22// a fixed test key to use when checking the RFC 9180 vectors. 23var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error) 24 25type hkdfKDF struct { 26 hash crypto.Hash 27} 28 29func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte { 30 labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey)) 31 labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) 32 labeledIKM = append(labeledIKM, suiteID...) 33 labeledIKM = append(labeledIKM, label...) 34 labeledIKM = append(labeledIKM, inputKey...) 35 return hkdf.Extract(kdf.hash.New, labeledIKM, salt) 36} 37 38func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte { 39 labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info)) 40 labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length) 41 labeledInfo = append(labeledInfo, []byte("HPKE-v1")...) 42 labeledInfo = append(labeledInfo, suiteID...) 43 labeledInfo = append(labeledInfo, label...) 44 labeledInfo = append(labeledInfo, info...) 45 out := make([]byte, length) 46 n, err := hkdf.Expand(kdf.hash.New, randomKey, labeledInfo).Read(out) 47 if err != nil || n != int(length) { 48 panic("hpke: LabeledExpand failed unexpectedly") 49 } 50 return out 51} 52 53// dhKEM implements the KEM specified in RFC 9180, Section 4.1. 54type dhKEM struct { 55 dh ecdh.Curve 56 kdf hkdfKDF 57 58 suiteID []byte 59 nSecret uint16 60} 61 62var SupportedKEMs = map[uint16]struct { 63 curve ecdh.Curve 64 hash crypto.Hash 65 nSecret uint16 66}{ 67 // RFC 9180 Section 7.1 68 0x0020: {ecdh.X25519(), crypto.SHA256, 32}, 69} 70 71func newDHKem(kemID uint16) (*dhKEM, error) { 72 suite, ok := SupportedKEMs[kemID] 73 if !ok { 74 return nil, errors.New("unsupported suite ID") 75 } 76 return &dhKEM{ 77 dh: suite.curve, 78 kdf: hkdfKDF{suite.hash}, 79 suiteID: binary.BigEndian.AppendUint16([]byte("KEM"), kemID), 80 nSecret: suite.nSecret, 81 }, nil 82} 83 84func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte { 85 eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey) 86 return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret) 87} 88 89func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) { 90 var privEph *ecdh.PrivateKey 91 if testingOnlyGenerateKey != nil { 92 privEph, err = testingOnlyGenerateKey() 93 } else { 94 privEph, err = dh.dh.GenerateKey(rand.Reader) 95 } 96 if err != nil { 97 return nil, nil, err 98 } 99 dhVal, err := privEph.ECDH(pubRecipient) 100 if err != nil { 101 return nil, nil, err 102 } 103 encPubEph := privEph.PublicKey().Bytes() 104 105 encPubRecip := pubRecipient.Bytes() 106 kemContext := append(encPubEph, encPubRecip...) 107 108 return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil 109} 110 111type Sender struct { 112 aead cipher.AEAD 113 kem *dhKEM 114 115 sharedSecret []byte 116 117 suiteID []byte 118 119 key []byte 120 baseNonce []byte 121 exporterSecret []byte 122 123 seqNum uint128 124} 125 126var aesGCMNew = func(key []byte) (cipher.AEAD, error) { 127 block, err := aes.NewCipher(key) 128 if err != nil { 129 return nil, err 130 } 131 return cipher.NewGCM(block) 132} 133 134var SupportedAEADs = map[uint16]struct { 135 keySize int 136 nonceSize int 137 aead func([]byte) (cipher.AEAD, error) 138}{ 139 // RFC 9180, Section 7.3 140 0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew}, 141 0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew}, 142 0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New}, 143} 144 145var SupportedKDFs = map[uint16]func() *hkdfKDF{ 146 // RFC 9180, Section 7.2 147 0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} }, 148} 149 150func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) { 151 suiteID := SuiteID(kemID, kdfID, aeadID) 152 153 kem, err := newDHKem(kemID) 154 if err != nil { 155 return nil, nil, err 156 } 157 pubRecipient, ok := pub.(*ecdh.PublicKey) 158 if !ok { 159 return nil, nil, errors.New("incorrect public key type") 160 } 161 sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient) 162 if err != nil { 163 return nil, nil, err 164 } 165 166 kdfInit, ok := SupportedKDFs[kdfID] 167 if !ok { 168 return nil, nil, errors.New("unsupported KDF id") 169 } 170 kdf := kdfInit() 171 172 aeadInfo, ok := SupportedAEADs[aeadID] 173 if !ok { 174 return nil, nil, errors.New("unsupported AEAD id") 175 } 176 177 pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil) 178 infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info) 179 ksContext := append([]byte{0}, pskIDHash...) 180 ksContext = append(ksContext, infoHash...) 181 182 secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil) 183 184 key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) 185 baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) 186 exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) 187 188 aead, err := aeadInfo.aead(key) 189 if err != nil { 190 return nil, nil, err 191 } 192 193 return encapsulatedKey, &Sender{ 194 kem: kem, 195 aead: aead, 196 sharedSecret: sharedSecret, 197 suiteID: suiteID, 198 key: key, 199 baseNonce: baseNonce, 200 exporterSecret: exporterSecret, 201 }, nil 202} 203 204func (s *Sender) nextNonce() []byte { 205 nonce := s.seqNum.bytes()[16-s.aead.NonceSize():] 206 for i := range s.baseNonce { 207 nonce[i] ^= s.baseNonce[i] 208 } 209 // Message limit is, according to the RFC, 2^95+1, which 210 // is somewhat confusing, but we do as we're told. 211 if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 { 212 panic("message limit reached") 213 } 214 s.seqNum = s.seqNum.addOne() 215 return nonce 216} 217 218func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) { 219 220 ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad) 221 return ciphertext, nil 222} 223 224func SuiteID(kemID, kdfID, aeadID uint16) []byte { 225 suiteID := make([]byte, 0, 4+2+2+2) 226 suiteID = append(suiteID, []byte("HPKE")...) 227 suiteID = binary.BigEndian.AppendUint16(suiteID, kemID) 228 suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID) 229 suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID) 230 return suiteID 231} 232 233func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) { 234 kemInfo, ok := SupportedKEMs[kemID] 235 if !ok { 236 return nil, errors.New("unsupported KEM id") 237 } 238 return kemInfo.curve.NewPublicKey(bytes) 239} 240 241type uint128 struct { 242 hi, lo uint64 243} 244 245func (u uint128) addOne() uint128 { 246 lo, carry := bits.Add64(u.lo, 1, 0) 247 return uint128{u.hi + carry, lo} 248} 249 250func (u uint128) bitLen() int { 251 return bits.Len64(u.hi) + bits.Len64(u.lo) 252} 253 254func (u uint128) bytes() []byte { 255 b := make([]byte, 16) 256 binary.BigEndian.PutUint64(b[0:], u.hi) 257 binary.BigEndian.PutUint64(b[8:], u.lo) 258 return b 259} 260