1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package hpke
6
7import (
8	"crypto"
9	"crypto/aes"
10	"crypto/cipher"
11	"crypto/ecdh"
12	"crypto/rand"
13	"encoding/binary"
14	"errors"
15	"math/bits"
16
17	"golang.org/x/crypto/chacha20poly1305"
18	"golang.org/x/crypto/hkdf"
19)
20
// testingOnlyGenerateKey, when non-nil, is called by dhKEM.Encap instead of
// generating a random ephemeral key. It exists only so tests can inject the
// fixed ephemeral keys of the RFC 9180 test vectors; outside of tests it must
// stay nil.
var testingOnlyGenerateKey func() (priv *ecdh.PrivateKey, err error)
24
// hkdfKDF is an HKDF-based KDF that provides the labeled Extract/Expand
// operations of RFC 9180, Section 4, on top of a configurable hash function.
type hkdfKDF struct {
	hash crypto.Hash // hash function underlying HKDF-Extract and HKDF-Expand
}
28
29func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte {
30	labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey))
31	labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
32	labeledIKM = append(labeledIKM, suiteID...)
33	labeledIKM = append(labeledIKM, label...)
34	labeledIKM = append(labeledIKM, inputKey...)
35	return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
36}
37
38func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte {
39	labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
40	labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length)
41	labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
42	labeledInfo = append(labeledInfo, suiteID...)
43	labeledInfo = append(labeledInfo, label...)
44	labeledInfo = append(labeledInfo, info...)
45	out := make([]byte, length)
46	n, err := hkdf.Expand(kdf.hash.New, randomKey, labeledInfo).Read(out)
47	if err != nil || n != int(length) {
48		panic("hpke: LabeledExpand failed unexpectedly")
49	}
50	return out
51}
52
53// dhKEM implements the KEM specified in RFC 9180, Section 4.1.
54type dhKEM struct {
55	dh  ecdh.Curve
56	kdf hkdfKDF
57
58	suiteID []byte
59	nSecret uint16
60}
61
// SupportedKEMs maps RFC 9180 KEM identifiers to the parameters needed to
// instantiate the corresponding DHKEM.
var SupportedKEMs = map[uint16]struct {
	curve   ecdh.Curve  // Diffie-Hellman group
	hash    crypto.Hash // hash underlying the KEM's HKDF
	nSecret uint16      // Nsecret: shared secret length in bytes
}{
	// DHKEM(X25519, HKDF-SHA256), RFC 9180, Section 7.1.
	0x0020: {curve: ecdh.X25519(), hash: crypto.SHA256, nSecret: 32},
}
70
71func newDHKem(kemID uint16) (*dhKEM, error) {
72	suite, ok := SupportedKEMs[kemID]
73	if !ok {
74		return nil, errors.New("unsupported suite ID")
75	}
76	return &dhKEM{
77		dh:      suite.curve,
78		kdf:     hkdfKDF{suite.hash},
79		suiteID: binary.BigEndian.AppendUint16([]byte("KEM"), kemID),
80		nSecret: suite.nSecret,
81	}, nil
82}
83
84func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte {
85	eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
86	return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
87}
88
89func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
90	var privEph *ecdh.PrivateKey
91	if testingOnlyGenerateKey != nil {
92		privEph, err = testingOnlyGenerateKey()
93	} else {
94		privEph, err = dh.dh.GenerateKey(rand.Reader)
95	}
96	if err != nil {
97		return nil, nil, err
98	}
99	dhVal, err := privEph.ECDH(pubRecipient)
100	if err != nil {
101		return nil, nil, err
102	}
103	encPubEph := privEph.PublicKey().Bytes()
104
105	encPubRecip := pubRecipient.Bytes()
106	kemContext := append(encPubEph, encPubRecip...)
107
108	return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
109}
110
111type Sender struct {
112	aead cipher.AEAD
113	kem  *dhKEM
114
115	sharedSecret []byte
116
117	suiteID []byte
118
119	key            []byte
120	baseNonce      []byte
121	exporterSecret []byte
122
123	seqNum uint128
124}
125
// aesGCMNew returns an AES-GCM AEAD keyed with key, built by wrapping
// aes.NewCipher in cipher.NewGCM. Errors from aes.NewCipher (such as an
// unsupported key length) are returned unchanged.
var aesGCMNew = func(key []byte) (aead cipher.AEAD, err error) {
	var block cipher.Block
	if block, err = aes.NewCipher(key); err != nil {
		return nil, err
	}
	return cipher.NewGCM(block)
}
133
134var SupportedAEADs = map[uint16]struct {
135	keySize   int
136	nonceSize int
137	aead      func([]byte) (cipher.AEAD, error)
138}{
139	// RFC 9180, Section 7.3
140	0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew},
141	0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew},
142	0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
143}
144
145var SupportedKDFs = map[uint16]func() *hkdfKDF{
146	// RFC 9180, Section 7.2
147	0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
148}
149
150func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) {
151	suiteID := SuiteID(kemID, kdfID, aeadID)
152
153	kem, err := newDHKem(kemID)
154	if err != nil {
155		return nil, nil, err
156	}
157	pubRecipient, ok := pub.(*ecdh.PublicKey)
158	if !ok {
159		return nil, nil, errors.New("incorrect public key type")
160	}
161	sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient)
162	if err != nil {
163		return nil, nil, err
164	}
165
166	kdfInit, ok := SupportedKDFs[kdfID]
167	if !ok {
168		return nil, nil, errors.New("unsupported KDF id")
169	}
170	kdf := kdfInit()
171
172	aeadInfo, ok := SupportedAEADs[aeadID]
173	if !ok {
174		return nil, nil, errors.New("unsupported AEAD id")
175	}
176
177	pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil)
178	infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info)
179	ksContext := append([]byte{0}, pskIDHash...)
180	ksContext = append(ksContext, infoHash...)
181
182	secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil)
183
184	key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
185	baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
186	exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
187
188	aead, err := aeadInfo.aead(key)
189	if err != nil {
190		return nil, nil, err
191	}
192
193	return encapsulatedKey, &Sender{
194		kem:            kem,
195		aead:           aead,
196		sharedSecret:   sharedSecret,
197		suiteID:        suiteID,
198		key:            key,
199		baseNonce:      baseNonce,
200		exporterSecret: exporterSecret,
201	}, nil
202}
203
204func (s *Sender) nextNonce() []byte {
205	nonce := s.seqNum.bytes()[16-s.aead.NonceSize():]
206	for i := range s.baseNonce {
207		nonce[i] ^= s.baseNonce[i]
208	}
209	// Message limit is, according to the RFC, 2^95+1, which
210	// is somewhat confusing, but we do as we're told.
211	if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 {
212		panic("message limit reached")
213	}
214	s.seqNum = s.seqNum.addOne()
215	return nonce
216}
217
218func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
219
220	ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
221	return ciphertext, nil
222}
223
// SuiteID returns the HPKE suite_id for the given algorithm identifiers:
// "HPKE" || I2OSP(kem_id, 2) || I2OSP(kdf_id, 2) || I2OSP(aead_id, 2)
// (RFC 9180, Section 5.1). The result is always exactly 10 bytes.
func SuiteID(kemID, kdfID, aeadID uint16) []byte {
	id := make([]byte, 10)
	off := copy(id, "HPKE")
	for _, v := range [...]uint16{kemID, kdfID, aeadID} {
		binary.BigEndian.PutUint16(id[off:], v)
		off += 2
	}
	return id
}
232
233func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
234	kemInfo, ok := SupportedKEMs[kemID]
235	if !ok {
236		return nil, errors.New("unsupported KEM id")
237	}
238	return kemInfo.curve.NewPublicKey(bytes)
239}
240
// uint128 is an unsigned 128-bit integer used as the HPKE sequence number.
// hi holds the most significant 64 bits and lo the least significant 64.
type uint128 struct {
	hi, lo uint64
}

// addOne returns u+1, wrapping around to zero if all 128 bits overflow.
func (u uint128) addOne() uint128 {
	lo, carry := bits.Add64(u.lo, 1, 0)
	return uint128{u.hi + carry, lo}
}

// bitLen returns the minimum number of bits needed to represent u; it is 0
// for u == 0.
func (u uint128) bitLen() int {
	// When hi is non-zero it holds the most significant set bit and all 64
	// bits of lo lie below it, so the halves' lengths must not be summed.
	if u.hi != 0 {
		return 64 + bits.Len64(u.hi)
	}
	return bits.Len64(u.lo)
}

// bytes returns the 16-byte big-endian encoding of u in a newly allocated
// slice.
func (u uint128) bytes() []byte {
	b := make([]byte, 16)
	binary.BigEndian.PutUint64(b[0:], u.hi)
	binary.BigEndian.PutUint64(b[8:], u.lo)
	return b
}
260