1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"crypto/internal/hpke"
9	"errors"
10	"strings"
11
12	"golang.org/x/crypto/cryptobyte"
13)
14
// echCipher is one HPKE symmetric cipher suite advertised by an ECHConfig,
// encoded on the wire as the KDF identifier followed by the AEAD identifier.
type echCipher struct {
	KDFID  uint16
	AEADID uint16
}
19
// echExtension is an extension carried in an ECHConfig: a type code and its
// opaque data. An extension whose Type has the high-order bit set is treated
// as mandatory when picking a config.
type echExtension struct {
	Type uint16
	Data []byte
}
24
// echConfig is a single ECHConfig parsed out of an ECHConfigList by
// parseECHConfigList.
type echConfig struct {
	// raw is the complete encoding of this ECHConfig, including its version
	// and length fields.
	raw []byte

	// Version is compared against extensionEncryptedClientHello during
	// parsing; Length is the byte length of the contents that follow it.
	Version uint16
	Length  uint16

	// HPKE key configuration: config identifier, KEM, public key, and the
	// symmetric cipher suites the config offers.
	ConfigID             uint8
	KemID                uint16
	PublicKey            []byte
	SymmetricCipherSuite []echCipher

	// MaxNameLength sizes the padding of the encoded inner ClientHello.
	// PublicName must pass validDNSName for the config to be picked.
	MaxNameLength uint8
	PublicName    []byte
	Extensions    []echExtension
}
40
// errMalformedECHConfig is returned by parseECHConfigList for any encoding
// error in the list or in one of its configs.
var errMalformedECHConfig = errors.New("tls: malformed ECHConfigList")
42
43// parseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
44// slice of parsed ECHConfigs, in the same order they were parsed, or an error
45// if the list is malformed.
46func parseECHConfigList(data []byte) ([]echConfig, error) {
47	s := cryptobyte.String(data)
48	// Skip the length prefix
49	var length uint16
50	if !s.ReadUint16(&length) {
51		return nil, errMalformedECHConfig
52	}
53	if length != uint16(len(data)-2) {
54		return nil, errMalformedECHConfig
55	}
56	var configs []echConfig
57	for len(s) > 0 {
58		var ec echConfig
59		ec.raw = []byte(s)
60		if !s.ReadUint16(&ec.Version) {
61			return nil, errMalformedECHConfig
62		}
63		if !s.ReadUint16(&ec.Length) {
64			return nil, errMalformedECHConfig
65		}
66		if len(ec.raw) < int(ec.Length)+4 {
67			return nil, errMalformedECHConfig
68		}
69		ec.raw = ec.raw[:ec.Length+4]
70		if ec.Version != extensionEncryptedClientHello {
71			s.Skip(int(ec.Length))
72			continue
73		}
74		if !s.ReadUint8(&ec.ConfigID) {
75			return nil, errMalformedECHConfig
76		}
77		if !s.ReadUint16(&ec.KemID) {
78			return nil, errMalformedECHConfig
79		}
80		if !s.ReadUint16LengthPrefixed((*cryptobyte.String)(&ec.PublicKey)) {
81			return nil, errMalformedECHConfig
82		}
83		var cipherSuites cryptobyte.String
84		if !s.ReadUint16LengthPrefixed(&cipherSuites) {
85			return nil, errMalformedECHConfig
86		}
87		for !cipherSuites.Empty() {
88			var c echCipher
89			if !cipherSuites.ReadUint16(&c.KDFID) {
90				return nil, errMalformedECHConfig
91			}
92			if !cipherSuites.ReadUint16(&c.AEADID) {
93				return nil, errMalformedECHConfig
94			}
95			ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
96		}
97		if !s.ReadUint8(&ec.MaxNameLength) {
98			return nil, errMalformedECHConfig
99		}
100		var publicName cryptobyte.String
101		if !s.ReadUint8LengthPrefixed(&publicName) {
102			return nil, errMalformedECHConfig
103		}
104		ec.PublicName = publicName
105		var extensions cryptobyte.String
106		if !s.ReadUint16LengthPrefixed(&extensions) {
107			return nil, errMalformedECHConfig
108		}
109		for !extensions.Empty() {
110			var e echExtension
111			if !extensions.ReadUint16(&e.Type) {
112				return nil, errMalformedECHConfig
113			}
114			if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
115				return nil, errMalformedECHConfig
116			}
117			ec.Extensions = append(ec.Extensions, e)
118		}
119
120		configs = append(configs, ec)
121	}
122	return configs, nil
123}
124
125func pickECHConfig(list []echConfig) *echConfig {
126	for _, ec := range list {
127		if _, ok := hpke.SupportedKEMs[ec.KemID]; !ok {
128			continue
129		}
130		var validSCS bool
131		for _, cs := range ec.SymmetricCipherSuite {
132			if _, ok := hpke.SupportedAEADs[cs.AEADID]; !ok {
133				continue
134			}
135			if _, ok := hpke.SupportedKDFs[cs.KDFID]; !ok {
136				continue
137			}
138			validSCS = true
139			break
140		}
141		if !validSCS {
142			continue
143		}
144		if !validDNSName(string(ec.PublicName)) {
145			continue
146		}
147		var unsupportedExt bool
148		for _, ext := range ec.Extensions {
149			// If high order bit is set to 1 the extension is mandatory.
150			// Since we don't support any extensions, if we see a mandatory
151			// bit, we skip the config.
152			if ext.Type&uint16(1<<15) != 0 {
153				unsupportedExt = true
154			}
155		}
156		if unsupportedExt {
157			continue
158		}
159		return &ec
160	}
161	return nil
162}
163
164func pickECHCipherSuite(suites []echCipher) (echCipher, error) {
165	for _, s := range suites {
166		// NOTE: all of the supported AEADs and KDFs are fine, rather than
167		// imposing some sort of preference here, we just pick the first valid
168		// suite.
169		if _, ok := hpke.SupportedAEADs[s.AEADID]; !ok {
170			continue
171		}
172		if _, ok := hpke.SupportedKDFs[s.KDFID]; !ok {
173			continue
174		}
175		return s, nil
176	}
177	return echCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH")
178}
179
180func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) {
181	h, err := inner.marshalMsg(true)
182	if err != nil {
183		return nil, err
184	}
185	h = h[4:] // strip four byte prefix
186
187	var paddingLen int
188	if inner.serverName != "" {
189		paddingLen = max(0, maxNameLength-len(inner.serverName))
190	} else {
191		paddingLen = maxNameLength + 9
192	}
193	paddingLen = 31 - ((len(h) + paddingLen - 1) % 32)
194
195	return append(h, make([]byte, paddingLen)...), nil
196}
197
198func generateOuterECHExt(id uint8, kdfID, aeadID uint16, encodedKey []byte, payload []byte) ([]byte, error) {
199	var b cryptobyte.Builder
200	b.AddUint8(0) // outer
201	b.AddUint16(kdfID)
202	b.AddUint16(aeadID)
203	b.AddUint8(id)
204	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(encodedKey) })
205	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(payload) })
206	return b.Bytes()
207}
208
209func computeAndUpdateOuterECHExtension(outer, inner *clientHelloMsg, ech *echContext, useKey bool) error {
210	var encapKey []byte
211	if useKey {
212		encapKey = ech.encapsulatedKey
213	}
214	encodedInner, err := encodeInnerClientHello(inner, int(ech.config.MaxNameLength))
215	if err != nil {
216		return err
217	}
218	// NOTE: the tag lengths for all of the supported AEADs are the same (16
219	// bytes), so we have hardcoded it here. If we add support for another AEAD
220	// with a different tag length, we will need to change this.
221	encryptedLen := len(encodedInner) + 16 // AEAD tag length
222	outer.encryptedClientHello, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, make([]byte, encryptedLen))
223	if err != nil {
224		return err
225	}
226	serializedOuter, err := outer.marshal()
227	if err != nil {
228		return err
229	}
230	serializedOuter = serializedOuter[4:] // strip the four byte prefix
231	encryptedInner, err := ech.hpkeContext.Seal(serializedOuter, encodedInner)
232	if err != nil {
233		return err
234	}
235	outer.encryptedClientHello, err = generateOuterECHExt(ech.config.ConfigID, ech.kdfID, ech.aeadID, encapKey, encryptedInner)
236	if err != nil {
237		return err
238	}
239	return nil
240}
241
// validDNSName is a rather rudimentary check for the validity of a DNS name.
// This is used to check if the public_name in a ECHConfig is valid when we are
// picking a config. This can be somewhat lax because even if we pick a
// valid-looking name, the DNS layer will later reject it anyway.
//
// The name must meet all of the following:
//   - at most 253 bytes;
//   - at least two dot-separated labels, with no empty labels (so no leading
//     or trailing dot);
//   - each label 1-63 ASCII letters, digits or hyphens, and not starting or
//     ending with a hyphen;
//   - a final label that does not look like an IPv4 number (all decimal
//     digits, or "0x"/"0X" followed by hex digits), so that IPv4 literals
//     such as "192.0.2.1" are rejected.
func validDNSName(name string) bool {
	if len(name) > 253 {
		return false
	}
	labels := strings.Split(name, ".")
	if len(labels) <= 1 {
		return false
	}
	for _, l := range labels {
		labelLen := len(l)
		// LDH labels are limited to 63 octets (RFC 1035, RFC 5890).
		if labelLen == 0 || labelLen > 63 {
			return false
		}
		for i, r := range l {
			if r == '-' && (i == 0 || i == labelLen-1) {
				return false
			}
			if (r < '0' || r > '9') && (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && r != '-' {
				return false
			}
		}
	}
	// The ECH spec says clients SHOULD ignore public names that encode an
	// IPv4 address. A final label that parses as a number marks one.
	last := labels[len(labels)-1]
	if strings.TrimLeft(last, "0123456789") == "" {
		return false
	}
	if len(last) >= 2 && (last[:2] == "0x" || last[:2] == "0X") &&
		strings.TrimLeft(last[2:], "0123456789abcdefABCDEF") == "" {
		return false
	}
	return true
}
270
// ECHRejectionError is the error returned when a remote server rejects ECH.
// If the server supplied an ECHConfigList to use for retries, it is stored in
// RetryConfigList.
//
// The client may treat an ECHRejectionError with an empty RetryConfigList as
// a secure signal from the server.
type ECHRejectionError struct {
	RetryConfigList []byte
}

// Error implements the error interface with a fixed message.
func (*ECHRejectionError) Error() string {
	return "tls: server rejected ECH"
}
284