1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"errors"
9	"fmt"
10	"slices"
11	"strings"
12
13	"golang.org/x/crypto/cryptobyte"
14)
15
16// The marshalingFunction type is an adapter to allow the use of ordinary
17// functions as cryptobyte.MarshalingValue.
18type marshalingFunction func(b *cryptobyte.Builder) error
19
20func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
21	return f(b)
22}
23
24// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
25// the length of the sequence is not the value specified, it produces an error.
26func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
27	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
28		if len(v) != n {
29			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
30		}
31		b.AddBytes(v)
32		return nil
33	}))
34}
35
36// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
37func addUint64(b *cryptobyte.Builder, v uint64) {
38	b.AddUint32(uint32(v >> 32))
39	b.AddUint32(uint32(v))
40}
41
42// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
43// It reports whether the read was successful.
44func readUint64(s *cryptobyte.String, out *uint64) bool {
45	var hi, lo uint32
46	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
47		return false
48	}
49	*out = uint64(hi)<<32 | uint64(lo)
50	return true
51}
52
53// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
54// []byte instead of a cryptobyte.String.
55func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
56	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
57}
58
59// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
60// []byte instead of a cryptobyte.String.
61func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
62	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
63}
64
65// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
66// []byte instead of a cryptobyte.String.
67func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
68	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
69}
70
// clientHelloMsg is the parsed / to-be-marshaled form of a ClientHello
// handshake message. Most fields correspond to a single extension; see
// marshalMsg and unmarshal for the exact wire encoding of each.
type clientHelloMsg struct {
	// original holds the raw bytes passed to unmarshal. When non-nil,
	// marshalWithoutBinders uses it instead of re-marshaling.
	original                         []byte
	vers                             uint16
	// random must be exactly 32 bytes for marshaling to succeed.
	random                           []byte
	// sessionId is written as empty when marshaling an ECH inner hello.
	sessionId                        []byte
	// cipherSuites is the offered suite list; unmarshal also sets
	// secureRenegotiationSupported when scsvRenegotiation appears in it.
	cipherSuites                     []uint16
	compressionMethods               []uint8
	// serverName is the SNI host_name; unmarshal rejects a trailing dot.
	serverName                       string
	// ocspStapling is set by unmarshal only for status_type == statusTypeOCSP.
	ocspStapling                     bool
	supportedCurves                  []CurveID
	supportedPoints                  []uint8
	// ticketSupported marks presence of the session_ticket extension;
	// sessionTicket is its raw (possibly empty) body.
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []SignatureScheme
	supportedSignatureAlgorithmsCert []SignatureScheme
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	extendedMasterSecret             bool
	alpnProtocols                    []string
	scts                             bool
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []keyShare
	earlyData                        bool
	pskModes                         []uint8
	// pskIdentities and pskBinders form the pre_shared_key extension, which
	// is always marshaled last.
	pskIdentities                    []pskIdentity
	pskBinders                       [][]byte
	// quicTransportParameters distinguishes nil (extension absent) from a
	// non-nil empty slice (extension present with an empty body).
	quicTransportParameters          []byte
	// encryptedClientHello is the raw encrypted_client_hello extension body.
	encryptedClientHello             []byte
}
101
// marshalMsg encodes m as a complete ClientHello handshake message: the
// typeClientHello byte, a uint24 length, and the message body.
//
// When echInner is false the regular (outer / non-ECH) encoding is produced.
// When echInner is true the message is encoded as an ECH inner ClientHello:
//   - legacy_session_id is written empty,
//   - supported_points, session_ticket, renegotiation_info and
//     extended_master_secret are omitted, and
//   - each present extension in the contiguous "compressible" block below is
//     not written inline; its type is instead collected into echOuterExts and
//     emitted once in an ech_outer_extensions extension.
//
// pre_shared_key, when present, is always written last. The extensions block
// is omitted entirely if no extension was written. Errors recorded on either
// builder (for example m.random not being exactly 32 bytes, see
// addBytesWithLength) are returned.
func (m *clientHelloMsg) marshalMsg(echInner bool) ([]byte, error) {
	var exts cryptobyte.Builder
	if len(m.serverName) > 0 {
		// RFC 6066, Section 3
		exts.AddUint16(extensionServerName)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8(0) // name_type = host_name
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes([]byte(m.serverName))
				})
			})
		})
	}
	// The next four extensions are never written in an ECH inner hello.
	if len(m.supportedPoints) > 0 && !echInner {
		// RFC 4492, Section 5.1.2
		exts.AddUint16(extensionSupportedPoints)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.supportedPoints)
			})
		})
	}
	if m.ticketSupported && !echInner {
		// RFC 5077, Section 3.2
		exts.AddUint16(extensionSessionTicket)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddBytes(m.sessionTicket)
		})
	}
	if m.secureRenegotiationSupported && !echInner {
		// RFC 5746, Section 3.2
		exts.AddUint16(extensionRenegotiationInfo)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.secureRenegotiation)
			})
		})
	}
	if m.extendedMasterSecret && !echInner {
		// RFC 7627
		exts.AddUint16(extensionExtendedMasterSecret)
		exts.AddUint16(0) // empty extension_data
	}
	if m.scts {
		// RFC 6962, Section 3.3.1
		exts.AddUint16(extensionSCT)
		exts.AddUint16(0) // empty extension_data
	}
	if m.earlyData {
		// RFC 8446, Section 4.2.10
		exts.AddUint16(extensionEarlyData)
		exts.AddUint16(0) // empty extension_data
	}
	if m.quicTransportParameters != nil { // marshal zero-length parameters when present
		// RFC 9001, Section 8.2
		exts.AddUint16(extensionQUICTransportParameters)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddBytes(m.quicTransportParameters)
		})
	}
	if len(m.encryptedClientHello) > 0 {
		exts.AddUint16(extensionEncryptedClientHello)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddBytes(m.encryptedClientHello)
		})
	}
	// Note that any extension that can be compressed during ECH must be
	// contiguous. If any additional extensions are to be compressed they must
	// be added to the following block, so that they can be properly
	// decompressed on the other side.
	var echOuterExts []uint16
	if m.ocspStapling {
		// RFC 4366, Section 3.6
		if echInner {
			echOuterExts = append(echOuterExts, extensionStatusRequest)
		} else {
			exts.AddUint16(extensionStatusRequest)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8(1)  // status_type = ocsp
				exts.AddUint16(0) // empty responder_id_list
				exts.AddUint16(0) // empty request_extensions
			})
		}
	}
	if len(m.supportedCurves) > 0 {
		// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
		if echInner {
			echOuterExts = append(echOuterExts, extensionSupportedCurves)
		} else {
			exts.AddUint16(extensionSupportedCurves)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					for _, curve := range m.supportedCurves {
						exts.AddUint16(uint16(curve))
					}
				})
			})
		}
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// RFC 5246, Section 7.4.1.4.1
		if echInner {
			echOuterExts = append(echOuterExts, extensionSignatureAlgorithms)
		} else {
			exts.AddUint16(extensionSignatureAlgorithms)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					for _, sigAlgo := range m.supportedSignatureAlgorithms {
						exts.AddUint16(uint16(sigAlgo))
					}
				})
			})
		}
	}
	if len(m.supportedSignatureAlgorithmsCert) > 0 {
		// RFC 8446, Section 4.2.3
		if echInner {
			echOuterExts = append(echOuterExts, extensionSignatureAlgorithmsCert)
		} else {
			exts.AddUint16(extensionSignatureAlgorithmsCert)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
						exts.AddUint16(uint16(sigAlgo))
					}
				})
			})
		}
	}
	if len(m.alpnProtocols) > 0 {
		// RFC 7301, Section 3.1
		if echInner {
			echOuterExts = append(echOuterExts, extensionALPN)
		} else {
			exts.AddUint16(extensionALPN)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					for _, proto := range m.alpnProtocols {
						exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
							exts.AddBytes([]byte(proto))
						})
					}
				})
			})
		}
	}
	if len(m.supportedVersions) > 0 {
		// RFC 8446, Section 4.2.1
		if echInner {
			echOuterExts = append(echOuterExts, extensionSupportedVersions)
		} else {
			exts.AddUint16(extensionSupportedVersions)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
					for _, vers := range m.supportedVersions {
						exts.AddUint16(vers)
					}
				})
			})
		}
	}
	if len(m.cookie) > 0 {
		// RFC 8446, Section 4.2.2
		if echInner {
			echOuterExts = append(echOuterExts, extensionCookie)
		} else {
			exts.AddUint16(extensionCookie)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes(m.cookie)
				})
			})
		}
	}
	if len(m.keyShares) > 0 {
		// RFC 8446, Section 4.2.8
		if echInner {
			echOuterExts = append(echOuterExts, extensionKeyShare)
		} else {
			exts.AddUint16(extensionKeyShare)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					for _, ks := range m.keyShares {
						exts.AddUint16(uint16(ks.group))
						exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
							exts.AddBytes(ks.data)
						})
					}
				})
			})
		}
	}
	if len(m.pskModes) > 0 {
		// RFC 8446, Section 4.2.9
		if echInner {
			echOuterExts = append(echOuterExts, extensionPSKModes)
		} else {
			exts.AddUint16(extensionPSKModes)
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes(m.pskModes)
				})
			})
		}
	}
	// In the inner hello, reference every compressed extension above by type,
	// in the order it was collected.
	if len(echOuterExts) > 0 && echInner {
		exts.AddUint16(extensionECHOuterExtensions)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, e := range echOuterExts {
					exts.AddUint16(e)
				}
			})
		})
	}
	if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
		// RFC 8446, Section 4.2.11
		exts.AddUint16(extensionPreSharedKey)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, psk := range m.pskIdentities {
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(psk.label)
					})
					exts.AddUint32(psk.obfuscatedTicketAge)
				}
			})
			// The binders list is the very tail of the message;
			// marshalWithoutBinders relies on this to trim it off.
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, binder := range m.pskBinders {
					exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(binder)
					})
				}
			})
		})
	}
	extBytes, err := exts.Bytes()
	if err != nil {
		return nil, err
	}

	var b cryptobyte.Builder
	b.AddUint8(typeClientHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			if !echInner {
				b.AddBytes(m.sessionId)
			}
		})
		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			for _, suite := range m.cipherSuites {
				b.AddUint16(suite)
			}
		})
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.compressionMethods)
		})

		// The extensions block, including its length, is omitted when empty.
		if len(extBytes) > 0 {
			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
				b.AddBytes(extBytes)
			})
		}
	})

	return b.Bytes()
}
372
373func (m *clientHelloMsg) marshal() ([]byte, error) {
374	return m.marshalMsg(false)
375}
376
377// marshalWithoutBinders returns the ClientHello through the
378// PreSharedKeyExtension.identities field, according to RFC 8446, Section
379// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
380func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
381	bindersLen := 2 // uint16 length prefix
382	for _, binder := range m.pskBinders {
383		bindersLen += 1 // uint8 length prefix
384		bindersLen += len(binder)
385	}
386
387	var fullMessage []byte
388	if m.original != nil {
389		fullMessage = m.original
390	} else {
391		var err error
392		fullMessage, err = m.marshal()
393		if err != nil {
394			return nil, err
395		}
396	}
397	return fullMessage[:len(fullMessage)-bindersLen], nil
398}
399
400// updateBinders updates the m.pskBinders field. The supplied binders must have
401// the same length as the current m.pskBinders.
402func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
403	if len(pskBinders) != len(m.pskBinders) {
404		return errors.New("tls: internal error: pskBinders length mismatch")
405	}
406	for i := range m.pskBinders {
407		if len(pskBinders[i]) != len(m.pskBinders[i]) {
408			return errors.New("tls: internal error: pskBinders length mismatch")
409		}
410	}
411	m.pskBinders = pskBinders
412
413	return nil
414}
415
// unmarshal parses data, a complete ClientHello handshake message including
// its 4-byte header, into m and reports whether parsing succeeded. The header
// bytes are skipped without being validated.
//
// m is reset first, and m.original is set to data itself (not a copy). Only
// quicTransportParameters is explicitly copied below. Other byte fields are
// read through cryptobyte and presumably alias data; confirm this before
// mutating data after a successful parse.
//
// Parsing rejects duplicate extensions, a pre_shared_key that is not the last
// extension, trailing bytes after the extensions block, and any recognized
// extension whose body is not consumed exactly. Unknown extensions are
// skipped. A message that ends right after compression_methods (no extensions
// block) is accepted.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	*m = clientHelloMsg{original: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) {
		return false
	}

	var cipherSuites cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
		return false
	}
	// cipherSuites is non-nil even when the list is empty.
	m.cipherSuites = []uint16{}
	m.secureRenegotiationSupported = false
	for !cipherSuites.Empty() {
		var suite uint16
		if !cipherSuites.ReadUint16(&suite) {
			return false
		}
		// The renegotiation SCSV signals secure renegotiation support just
		// like the renegotiation_info extension handled below.
		if suite == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
		m.cipherSuites = append(m.cipherSuites, suite)
	}

	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
		return false
	}

	if s.Empty() {
		// ClientHello is optionally followed by extension data
		return true
	}

	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// Each extension type may appear at most once.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		case extensionServerName:
			// RFC 6066, Section 3
			var nameList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
				return false
			}
			for !nameList.Empty() {
				var nameType uint8
				var serverName cryptobyte.String
				if !nameList.ReadUint8(&nameType) ||
					!nameList.ReadUint16LengthPrefixed(&serverName) ||
					serverName.Empty() {
					return false
				}
				if nameType != 0 {
					continue
				}
				if len(m.serverName) != 0 {
					// Multiple names of the same name_type are prohibited.
					return false
				}
				m.serverName = string(serverName)
				// An SNI value may not include a trailing dot.
				if strings.HasSuffix(m.serverName, ".") {
					return false
				}
			}
		case extensionStatusRequest:
			// RFC 4366, Section 3.6
			var statusType uint8
			var ignored cryptobyte.String
			if !extData.ReadUint8(&statusType) ||
				!extData.ReadUint16LengthPrefixed(&ignored) ||
				!extData.ReadUint16LengthPrefixed(&ignored) {
				return false
			}
			m.ocspStapling = statusType == statusTypeOCSP
		case extensionSupportedCurves:
			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
			var curves cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
				return false
			}
			for !curves.Empty() {
				var curve uint16
				if !curves.ReadUint16(&curve) {
					return false
				}
				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		case extensionSessionTicket:
			// RFC 5077, Section 3.2
			// The whole (possibly empty) extension body is the ticket.
			m.ticketSupported = true
			extData.ReadBytes(&m.sessionTicket, len(extData))
		case extensionSignatureAlgorithms:
			// RFC 5246, Section 7.4.1.4.1
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithms = append(
					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
			}
		case extensionSignatureAlgorithmsCert:
			// RFC 8446, Section 4.2.3
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithmsCert = append(
					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
			}
		case extensionRenegotiationInfo:
			// RFC 5746, Section 3.2
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionExtendedMasterSecret:
			// RFC 7627
			m.extendedMasterSecret = true
		case extensionALPN:
			// RFC 7301, Section 3.1
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			for !protoList.Empty() {
				var proto cryptobyte.String
				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
					return false
				}
				m.alpnProtocols = append(m.alpnProtocols, string(proto))
			}
		case extensionSCT:
			// RFC 6962, Section 3.3.1
			m.scts = true
		case extensionSupportedVersions:
			// RFC 8446, Section 4.2.1
			var versList cryptobyte.String
			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
				return false
			}
			for !versList.Empty() {
				var vers uint16
				if !versList.ReadUint16(&vers) {
					return false
				}
				m.supportedVersions = append(m.supportedVersions, vers)
			}
		case extensionCookie:
			// RFC 8446, Section 4.2.2
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// RFC 8446, Section 4.2.8
			// An empty client_shares list is accepted; individual shares
			// must be non-empty.
			var clientShares cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&clientShares) {
				return false
			}
			for !clientShares.Empty() {
				var ks keyShare
				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
					len(ks.data) == 0 {
					return false
				}
				m.keyShares = append(m.keyShares, ks)
			}
		case extensionEarlyData:
			// RFC 8446, Section 4.2.10
			m.earlyData = true
		case extensionPSKModes:
			// RFC 8446, Section 4.2.9
			// NOTE(review): an empty ke_modes list is accepted here.
			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
				return false
			}
		case extensionQUICTransportParameters:
			// Copied (not aliased) and always non-nil once the extension is
			// seen, so an empty body stays distinguishable from absence.
			m.quicTransportParameters = make([]byte, len(extData))
			if !extData.CopyBytes(m.quicTransportParameters) {
				return false
			}
		case extensionPreSharedKey:
			// RFC 8446, Section 4.2.11
			if !extensions.Empty() {
				return false // pre_shared_key must be the last extension
			}
			var identities cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
				return false
			}
			for !identities.Empty() {
				var psk pskIdentity
				if !readUint16LengthPrefixed(&identities, &psk.label) ||
					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
					len(psk.label) == 0 {
					return false
				}
				m.pskIdentities = append(m.pskIdentities, psk)
			}
			var binders cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
				return false
			}
			for !binders.Empty() {
				var binder []byte
				if !readUint8LengthPrefixed(&binders, &binder) ||
					len(binder) == 0 {
					return false
				}
				m.pskBinders = append(m.pskBinders, binder)
			}
		default:
			// Ignore unknown extensions.
			continue
		}

		// Every recognized extension must be consumed exactly.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
674
675func (m *clientHelloMsg) originalBytes() []byte {
676	return m.original
677}
678
679func (m *clientHelloMsg) clone() *clientHelloMsg {
680	return &clientHelloMsg{
681		original:                         slices.Clone(m.original),
682		vers:                             m.vers,
683		random:                           slices.Clone(m.random),
684		sessionId:                        slices.Clone(m.sessionId),
685		cipherSuites:                     slices.Clone(m.cipherSuites),
686		compressionMethods:               slices.Clone(m.compressionMethods),
687		serverName:                       m.serverName,
688		ocspStapling:                     m.ocspStapling,
689		supportedCurves:                  slices.Clone(m.supportedCurves),
690		supportedPoints:                  slices.Clone(m.supportedPoints),
691		ticketSupported:                  m.ticketSupported,
692		sessionTicket:                    slices.Clone(m.sessionTicket),
693		supportedSignatureAlgorithms:     slices.Clone(m.supportedSignatureAlgorithms),
694		supportedSignatureAlgorithmsCert: slices.Clone(m.supportedSignatureAlgorithmsCert),
695		secureRenegotiationSupported:     m.secureRenegotiationSupported,
696		secureRenegotiation:              slices.Clone(m.secureRenegotiation),
697		extendedMasterSecret:             m.extendedMasterSecret,
698		alpnProtocols:                    slices.Clone(m.alpnProtocols),
699		scts:                             m.scts,
700		supportedVersions:                slices.Clone(m.supportedVersions),
701		cookie:                           slices.Clone(m.cookie),
702		keyShares:                        slices.Clone(m.keyShares),
703		earlyData:                        m.earlyData,
704		pskModes:                         slices.Clone(m.pskModes),
705		pskIdentities:                    slices.Clone(m.pskIdentities),
706		pskBinders:                       slices.Clone(m.pskBinders),
707		quicTransportParameters:          slices.Clone(m.quicTransportParameters),
708		encryptedClientHello:             slices.Clone(m.encryptedClientHello),
709	}
710}
711
// serverHelloMsg is the parsed / to-be-marshaled form of a ServerHello
// handshake message. The same type also carries the HelloRetryRequest-only
// fields (cookie, selectedGroup); see marshal and unmarshal for the exact wire
// encoding of each field.
type serverHelloMsg struct {
	// original holds the raw bytes passed to unmarshal.
	original                     []byte
	vers                         uint16
	// random must be exactly 32 bytes for marshaling to succeed.
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	ocspStapling                 bool
	ticketSupported              bool
	secureRenegotiationSupported bool
	secureRenegotiation          []byte
	extendedMasterSecret         bool
	// alpnProtocol is the single selected protocol; empty means the ALPN
	// extension is omitted.
	alpnProtocol                 string
	// scts holds individually uint16-length-prefixed SCT entries.
	scts                         [][]byte
	// supportedVersion is written as a bare uint16; 0 means absent.
	supportedVersion             uint16
	// serverShare is the key_share entry; a zero group means absent.
	serverShare                  keyShare
	// selectedIdentityPresent gates whether selectedIdentity (the
	// pre_shared_key extension body) is marshaled.
	selectedIdentityPresent      bool
	selectedIdentity             uint16
	supportedPoints              []uint8
	encryptedClientHello         []byte
	// serverNameAck is encoded as an empty server_name extension.
	serverNameAck                bool

	// HelloRetryRequest extensions
	cookie        []byte
	// selectedGroup is encoded as a bare group ID in the key_share
	// extension; 0 means absent.
	selectedGroup CurveID
}
738
// marshal encodes m as a complete ServerHello (or HelloRetryRequest) handshake
// message: the typeServerHello byte, a uint24 length, and the body.
//
// An extension is written only when its field is set (true, non-zero or
// non-empty). Both serverShare and selectedGroup are encoded under
// extensionKeyShare, so setting both produces a duplicate extension, which
// unmarshal rejects. The extensions block is omitted entirely when empty.
// Builder errors (for example m.random not being exactly 32 bytes) are
// returned.
func (m *serverHelloMsg) marshal() ([]byte, error) {
	var exts cryptobyte.Builder
	if m.ocspStapling {
		exts.AddUint16(extensionStatusRequest)
		exts.AddUint16(0) // empty extension_data
	}
	if m.ticketSupported {
		exts.AddUint16(extensionSessionTicket)
		exts.AddUint16(0) // empty extension_data
	}
	if m.secureRenegotiationSupported {
		exts.AddUint16(extensionRenegotiationInfo)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.secureRenegotiation)
			})
		})
	}
	if m.extendedMasterSecret {
		exts.AddUint16(extensionExtendedMasterSecret)
		exts.AddUint16(0) // empty extension_data
	}
	if len(m.alpnProtocol) > 0 {
		// A protocol list containing exactly the one selected protocol.
		exts.AddUint16(extensionALPN)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes([]byte(m.alpnProtocol))
				})
			})
		})
	}
	if len(m.scts) > 0 {
		exts.AddUint16(extensionSCT)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, sct := range m.scts {
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(sct)
					})
				}
			})
		})
	}
	if m.supportedVersion != 0 {
		// Server form: a single version, not a list.
		exts.AddUint16(extensionSupportedVersions)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(m.supportedVersion)
		})
	}
	if m.serverShare.group != 0 {
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(uint16(m.serverShare.group))
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.serverShare.data)
			})
		})
	}
	if m.selectedIdentityPresent {
		exts.AddUint16(extensionPreSharedKey)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(m.selectedIdentity)
		})
	}

	if len(m.cookie) > 0 {
		exts.AddUint16(extensionCookie)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.cookie)
			})
		})
	}
	if m.selectedGroup != 0 {
		// HelloRetryRequest form of key_share: just the group ID.
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(uint16(m.selectedGroup))
		})
	}
	if len(m.supportedPoints) > 0 {
		exts.AddUint16(extensionSupportedPoints)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.supportedPoints)
			})
		})
	}
	if len(m.encryptedClientHello) > 0 {
		exts.AddUint16(extensionEncryptedClientHello)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddBytes(m.encryptedClientHello)
		})
	}
	if m.serverNameAck {
		exts.AddUint16(extensionServerName)
		exts.AddUint16(0)
	}

	extBytes, err := exts.Bytes()
	if err != nil {
		return nil, err
	}

	var b cryptobyte.Builder
	b.AddUint8(typeServerHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16(m.cipherSuite)
		b.AddUint8(m.compressionMethod)

		// The extensions block, including its length, is omitted when empty.
		if len(extBytes) > 0 {
			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
				b.AddBytes(extBytes)
			})
		}
	})

	return b.Bytes()
}
863
// unmarshal parses data, a complete ServerHello (or HelloRetryRequest)
// handshake message including its 4-byte header, into m and reports whether
// parsing succeeded. The header bytes are skipped without being validated.
//
// m is reset first, and m.original is set to data itself (not a copy).
// Duplicate extensions, trailing bytes after the extensions block, and
// recognized extensions whose body is not consumed exactly are rejected.
// Unknown extensions are skipped. A message with no extensions block at all
// is accepted.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	*m = serverHelloMsg{original: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) ||
		!s.ReadUint16(&m.cipherSuite) ||
		!s.ReadUint8(&m.compressionMethod) {
		return false
	}

	if s.Empty() {
		// ServerHello is optionally followed by extension data
		return true
	}

	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// Each extension type may appear at most once.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		case extensionStatusRequest:
			m.ocspStapling = true
		case extensionSessionTicket:
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionExtendedMasterSecret:
			m.extendedMasterSecret = true
		case extensionALPN:
			// Exactly one non-empty protocol must be selected.
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			var proto cryptobyte.String
			if !protoList.ReadUint8LengthPrefixed(&proto) ||
				proto.Empty() || !protoList.Empty() {
				return false
			}
			m.alpnProtocol = string(proto)
		case extensionSCT:
			var sctList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
				return false
			}
			for !sctList.Empty() {
				var sct []byte
				if !readUint16LengthPrefixed(&sctList, &sct) ||
					len(sct) == 0 {
					return false
				}
				m.scts = append(m.scts, sct)
			}
		case extensionSupportedVersions:
			if !extData.ReadUint16(&m.supportedVersion) {
				return false
			}
		case extensionCookie:
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// This extension has different formats in SH and HRR, accept either
			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
			if len(extData) == 2 {
				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
					return false
				}
			} else {
				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
					return false
				}
			}
		case extensionPreSharedKey:
			m.selectedIdentityPresent = true
			if !extData.ReadUint16(&m.selectedIdentity) {
				return false
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		case extensionEncryptedClientHello: // encrypted_client_hello
			// Copied rather than aliased into data.
			m.encryptedClientHello = make([]byte, len(extData))
			if !extData.CopyBytes(m.encryptedClientHello) {
				return false
			}
		case extensionServerName:
			if len(extData) != 0 {
				return false
			}
			m.serverNameAck = true
		default:
			// Ignore unknown extensions.
			continue
		}

		// Every recognized extension must be consumed exactly.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
991
992func (m *serverHelloMsg) originalBytes() []byte {
993	return m.original
994}
995
996type encryptedExtensionsMsg struct {
997	alpnProtocol            string
998	quicTransportParameters []byte
999	earlyData               bool
1000	echRetryConfigs         []byte
1001}
1002
1003func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
1004	var b cryptobyte.Builder
1005	b.AddUint8(typeEncryptedExtensions)
1006	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1007		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1008			if len(m.alpnProtocol) > 0 {
1009				b.AddUint16(extensionALPN)
1010				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1011					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1012						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
1013							b.AddBytes([]byte(m.alpnProtocol))
1014						})
1015					})
1016				})
1017			}
1018			if m.quicTransportParameters != nil { // marshal zero-length parameters when present
1019				// draft-ietf-quic-tls-32, Section 8.2
1020				b.AddUint16(extensionQUICTransportParameters)
1021				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1022					b.AddBytes(m.quicTransportParameters)
1023				})
1024			}
1025			if m.earlyData {
1026				// RFC 8446, Section 4.2.10
1027				b.AddUint16(extensionEarlyData)
1028				b.AddUint16(0) // empty extension_data
1029			}
1030			if len(m.echRetryConfigs) > 0 {
1031				b.AddUint16(extensionEncryptedClientHello)
1032				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1033					b.AddBytes(m.echRetryConfigs)
1034				})
1035			}
1036		})
1037	})
1038
1039	return b.Bytes()
1040}
1041
1042func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
1043	*m = encryptedExtensionsMsg{}
1044	s := cryptobyte.String(data)
1045
1046	var extensions cryptobyte.String
1047	if !s.Skip(4) || // message type and uint24 length field
1048		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
1049		return false
1050	}
1051
1052	for !extensions.Empty() {
1053		var extension uint16
1054		var extData cryptobyte.String
1055		if !extensions.ReadUint16(&extension) ||
1056			!extensions.ReadUint16LengthPrefixed(&extData) {
1057			return false
1058		}
1059
1060		switch extension {
1061		case extensionALPN:
1062			var protoList cryptobyte.String
1063			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
1064				return false
1065			}
1066			var proto cryptobyte.String
1067			if !protoList.ReadUint8LengthPrefixed(&proto) ||
1068				proto.Empty() || !protoList.Empty() {
1069				return false
1070			}
1071			m.alpnProtocol = string(proto)
1072		case extensionQUICTransportParameters:
1073			m.quicTransportParameters = make([]byte, len(extData))
1074			if !extData.CopyBytes(m.quicTransportParameters) {
1075				return false
1076			}
1077		case extensionEarlyData:
1078			// RFC 8446, Section 4.2.10
1079			m.earlyData = true
1080		case extensionEncryptedClientHello:
1081			m.echRetryConfigs = make([]byte, len(extData))
1082			if !extData.CopyBytes(m.echRetryConfigs) {
1083				return false
1084			}
1085		default:
1086			// Ignore unknown extensions.
1087			continue
1088		}
1089
1090		if !extData.Empty() {
1091			return false
1092		}
1093	}
1094
1095	return true
1096}
1097
1098type endOfEarlyDataMsg struct{}
1099
1100func (m *endOfEarlyDataMsg) marshal() ([]byte, error) {
1101	x := make([]byte, 4)
1102	x[0] = typeEndOfEarlyData
1103	return x, nil
1104}
1105
1106func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
1107	return len(data) == 4
1108}
1109
1110type keyUpdateMsg struct {
1111	updateRequested bool
1112}
1113
1114func (m *keyUpdateMsg) marshal() ([]byte, error) {
1115	var b cryptobyte.Builder
1116	b.AddUint8(typeKeyUpdate)
1117	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1118		if m.updateRequested {
1119			b.AddUint8(1)
1120		} else {
1121			b.AddUint8(0)
1122		}
1123	})
1124
1125	return b.Bytes()
1126}
1127
1128func (m *keyUpdateMsg) unmarshal(data []byte) bool {
1129	s := cryptobyte.String(data)
1130
1131	var updateRequested uint8
1132	if !s.Skip(4) || // message type and uint24 length field
1133		!s.ReadUint8(&updateRequested) || !s.Empty() {
1134		return false
1135	}
1136	switch updateRequested {
1137	case 0:
1138		m.updateRequested = false
1139	case 1:
1140		m.updateRequested = true
1141	default:
1142		return false
1143	}
1144	return true
1145}
1146
1147type newSessionTicketMsgTLS13 struct {
1148	lifetime     uint32
1149	ageAdd       uint32
1150	nonce        []byte
1151	label        []byte
1152	maxEarlyData uint32
1153}
1154
1155func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
1156	var b cryptobyte.Builder
1157	b.AddUint8(typeNewSessionTicket)
1158	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1159		b.AddUint32(m.lifetime)
1160		b.AddUint32(m.ageAdd)
1161		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
1162			b.AddBytes(m.nonce)
1163		})
1164		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1165			b.AddBytes(m.label)
1166		})
1167
1168		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1169			if m.maxEarlyData > 0 {
1170				b.AddUint16(extensionEarlyData)
1171				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1172					b.AddUint32(m.maxEarlyData)
1173				})
1174			}
1175		})
1176	})
1177
1178	return b.Bytes()
1179}
1180
1181func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
1182	*m = newSessionTicketMsgTLS13{}
1183	s := cryptobyte.String(data)
1184
1185	var extensions cryptobyte.String
1186	if !s.Skip(4) || // message type and uint24 length field
1187		!s.ReadUint32(&m.lifetime) ||
1188		!s.ReadUint32(&m.ageAdd) ||
1189		!readUint8LengthPrefixed(&s, &m.nonce) ||
1190		!readUint16LengthPrefixed(&s, &m.label) ||
1191		!s.ReadUint16LengthPrefixed(&extensions) ||
1192		!s.Empty() {
1193		return false
1194	}
1195
1196	for !extensions.Empty() {
1197		var extension uint16
1198		var extData cryptobyte.String
1199		if !extensions.ReadUint16(&extension) ||
1200			!extensions.ReadUint16LengthPrefixed(&extData) {
1201			return false
1202		}
1203
1204		switch extension {
1205		case extensionEarlyData:
1206			if !extData.ReadUint32(&m.maxEarlyData) {
1207				return false
1208			}
1209		default:
1210			// Ignore unknown extensions.
1211			continue
1212		}
1213
1214		if !extData.Empty() {
1215			return false
1216		}
1217	}
1218
1219	return true
1220}
1221
1222type certificateRequestMsgTLS13 struct {
1223	ocspStapling                     bool
1224	scts                             bool
1225	supportedSignatureAlgorithms     []SignatureScheme
1226	supportedSignatureAlgorithmsCert []SignatureScheme
1227	certificateAuthorities           [][]byte
1228}
1229
1230func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
1231	var b cryptobyte.Builder
1232	b.AddUint8(typeCertificateRequest)
1233	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1234		// certificate_request_context (SHALL be zero length unless used for
1235		// post-handshake authentication)
1236		b.AddUint8(0)
1237
1238		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1239			if m.ocspStapling {
1240				b.AddUint16(extensionStatusRequest)
1241				b.AddUint16(0) // empty extension_data
1242			}
1243			if m.scts {
1244				// RFC 8446, Section 4.4.2.1 makes no mention of
1245				// signed_certificate_timestamp in CertificateRequest, but
1246				// "Extensions in the Certificate message from the client MUST
1247				// correspond to extensions in the CertificateRequest message
1248				// from the server." and it appears in the table in Section 4.2.
1249				b.AddUint16(extensionSCT)
1250				b.AddUint16(0) // empty extension_data
1251			}
1252			if len(m.supportedSignatureAlgorithms) > 0 {
1253				b.AddUint16(extensionSignatureAlgorithms)
1254				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1255					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1256						for _, sigAlgo := range m.supportedSignatureAlgorithms {
1257							b.AddUint16(uint16(sigAlgo))
1258						}
1259					})
1260				})
1261			}
1262			if len(m.supportedSignatureAlgorithmsCert) > 0 {
1263				b.AddUint16(extensionSignatureAlgorithmsCert)
1264				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1265					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1266						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
1267							b.AddUint16(uint16(sigAlgo))
1268						}
1269					})
1270				})
1271			}
1272			if len(m.certificateAuthorities) > 0 {
1273				b.AddUint16(extensionCertificateAuthorities)
1274				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1275					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1276						for _, ca := range m.certificateAuthorities {
1277							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1278								b.AddBytes(ca)
1279							})
1280						}
1281					})
1282				})
1283			}
1284		})
1285	})
1286
1287	return b.Bytes()
1288}
1289
1290func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
1291	*m = certificateRequestMsgTLS13{}
1292	s := cryptobyte.String(data)
1293
1294	var context, extensions cryptobyte.String
1295	if !s.Skip(4) || // message type and uint24 length field
1296		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
1297		!s.ReadUint16LengthPrefixed(&extensions) ||
1298		!s.Empty() {
1299		return false
1300	}
1301
1302	for !extensions.Empty() {
1303		var extension uint16
1304		var extData cryptobyte.String
1305		if !extensions.ReadUint16(&extension) ||
1306			!extensions.ReadUint16LengthPrefixed(&extData) {
1307			return false
1308		}
1309
1310		switch extension {
1311		case extensionStatusRequest:
1312			m.ocspStapling = true
1313		case extensionSCT:
1314			m.scts = true
1315		case extensionSignatureAlgorithms:
1316			var sigAndAlgs cryptobyte.String
1317			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
1318				return false
1319			}
1320			for !sigAndAlgs.Empty() {
1321				var sigAndAlg uint16
1322				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
1323					return false
1324				}
1325				m.supportedSignatureAlgorithms = append(
1326					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
1327			}
1328		case extensionSignatureAlgorithmsCert:
1329			var sigAndAlgs cryptobyte.String
1330			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
1331				return false
1332			}
1333			for !sigAndAlgs.Empty() {
1334				var sigAndAlg uint16
1335				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
1336					return false
1337				}
1338				m.supportedSignatureAlgorithmsCert = append(
1339					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
1340			}
1341		case extensionCertificateAuthorities:
1342			var auths cryptobyte.String
1343			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
1344				return false
1345			}
1346			for !auths.Empty() {
1347				var ca []byte
1348				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
1349					return false
1350				}
1351				m.certificateAuthorities = append(m.certificateAuthorities, ca)
1352			}
1353		default:
1354			// Ignore unknown extensions.
1355			continue
1356		}
1357
1358		if !extData.Empty() {
1359			return false
1360		}
1361	}
1362
1363	return true
1364}
1365
1366type certificateMsg struct {
1367	certificates [][]byte
1368}
1369
1370func (m *certificateMsg) marshal() ([]byte, error) {
1371	var i int
1372	for _, slice := range m.certificates {
1373		i += len(slice)
1374	}
1375
1376	length := 3 + 3*len(m.certificates) + i
1377	x := make([]byte, 4+length)
1378	x[0] = typeCertificate
1379	x[1] = uint8(length >> 16)
1380	x[2] = uint8(length >> 8)
1381	x[3] = uint8(length)
1382
1383	certificateOctets := length - 3
1384	x[4] = uint8(certificateOctets >> 16)
1385	x[5] = uint8(certificateOctets >> 8)
1386	x[6] = uint8(certificateOctets)
1387
1388	y := x[7:]
1389	for _, slice := range m.certificates {
1390		y[0] = uint8(len(slice) >> 16)
1391		y[1] = uint8(len(slice) >> 8)
1392		y[2] = uint8(len(slice))
1393		copy(y[3:], slice)
1394		y = y[3+len(slice):]
1395	}
1396
1397	return x, nil
1398}
1399
1400func (m *certificateMsg) unmarshal(data []byte) bool {
1401	if len(data) < 7 {
1402		return false
1403	}
1404
1405	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
1406	if uint32(len(data)) != certsLen+7 {
1407		return false
1408	}
1409
1410	numCerts := 0
1411	d := data[7:]
1412	for certsLen > 0 {
1413		if len(d) < 4 {
1414			return false
1415		}
1416		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
1417		if uint32(len(d)) < 3+certLen {
1418			return false
1419		}
1420		d = d[3+certLen:]
1421		certsLen -= 3 + certLen
1422		numCerts++
1423	}
1424
1425	m.certificates = make([][]byte, numCerts)
1426	d = data[7:]
1427	for i := 0; i < numCerts; i++ {
1428		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
1429		m.certificates[i] = d[3 : 3+certLen]
1430		d = d[3+certLen:]
1431	}
1432
1433	return true
1434}
1435
1436type certificateMsgTLS13 struct {
1437	certificate  Certificate
1438	ocspStapling bool
1439	scts         bool
1440}
1441
1442func (m *certificateMsgTLS13) marshal() ([]byte, error) {
1443	var b cryptobyte.Builder
1444	b.AddUint8(typeCertificate)
1445	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1446		b.AddUint8(0) // certificate_request_context
1447
1448		certificate := m.certificate
1449		if !m.ocspStapling {
1450			certificate.OCSPStaple = nil
1451		}
1452		if !m.scts {
1453			certificate.SignedCertificateTimestamps = nil
1454		}
1455		marshalCertificate(b, certificate)
1456	})
1457
1458	return b.Bytes()
1459}
1460
1461func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
1462	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1463		for i, cert := range certificate.Certificate {
1464			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1465				b.AddBytes(cert)
1466			})
1467			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1468				if i > 0 {
1469					// This library only supports OCSP and SCT for leaf certificates.
1470					return
1471				}
1472				if certificate.OCSPStaple != nil {
1473					b.AddUint16(extensionStatusRequest)
1474					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1475						b.AddUint8(statusTypeOCSP)
1476						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1477							b.AddBytes(certificate.OCSPStaple)
1478						})
1479					})
1480				}
1481				if certificate.SignedCertificateTimestamps != nil {
1482					b.AddUint16(extensionSCT)
1483					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1484						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1485							for _, sct := range certificate.SignedCertificateTimestamps {
1486								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1487									b.AddBytes(sct)
1488								})
1489							}
1490						})
1491					})
1492				}
1493			})
1494		}
1495	})
1496}
1497
1498func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
1499	*m = certificateMsgTLS13{}
1500	s := cryptobyte.String(data)
1501
1502	var context cryptobyte.String
1503	if !s.Skip(4) || // message type and uint24 length field
1504		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
1505		!unmarshalCertificate(&s, &m.certificate) ||
1506		!s.Empty() {
1507		return false
1508	}
1509
1510	m.scts = m.certificate.SignedCertificateTimestamps != nil
1511	m.ocspStapling = m.certificate.OCSPStaple != nil
1512
1513	return true
1514}
1515
// unmarshalCertificate reads a uint24-prefixed TLS 1.3 certificate_list
// (RFC 8446, Section 4.4.2) from s and appends each certificate to
// certificate.Certificate. It reports false if the list is malformed.
//
// Each entry is a uint24-prefixed certificate followed by a uint16-prefixed
// extensions block. Extensions are only interpreted while
// certificate.Certificate holds a single entry, which is the leaf when
// certificate starts out empty:
//   - status_request must hold status type OCSP and a non-empty
//     uint24-prefixed response, stored in certificate.OCSPStaple;
//   - the SCT extension must hold a non-empty list of non-empty SCTs,
//     appended to certificate.SignedCertificateTimestamps.
//
// Unknown extensions, and all extensions of later entries, are skipped
// without inspecting their contents.
func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
	var certList cryptobyte.String
	if !s.ReadUint24LengthPrefixed(&certList) {
		return false
	}
	for !certList.Empty() {
		var cert []byte
		var extensions cryptobyte.String
		if !readUint24LengthPrefixed(&certList, &cert) ||
			!certList.ReadUint16LengthPrefixed(&extensions) {
			return false
		}
		// Appended before walking the extensions: the length check below
		// relies on this entry already being counted.
		certificate.Certificate = append(certificate.Certificate, cert)
		for !extensions.Empty() {
			var extension uint16
			var extData cryptobyte.String
			if !extensions.ReadUint16(&extension) ||
				!extensions.ReadUint16LengthPrefixed(&extData) {
				return false
			}
			if len(certificate.Certificate) > 1 {
				// This library only supports OCSP and SCT for leaf certificates.
				// The extension body is skipped unvalidated, which also
				// bypasses the trailing-bytes check below.
				continue
			}

			switch extension {
			case extensionStatusRequest:
				var statusType uint8
				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
					len(certificate.OCSPStaple) == 0 {
					return false
				}
			case extensionSCT:
				var sctList cryptobyte.String
				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
					return false
				}
				for !sctList.Empty() {
					var sct []byte
					if !readUint16LengthPrefixed(&sctList, &sct) ||
						len(sct) == 0 {
						return false
					}
					certificate.SignedCertificateTimestamps = append(
						certificate.SignedCertificateTimestamps, sct)
				}
			default:
				// Ignore unknown extensions.
				continue
			}

			// A recognized extension must be consumed exactly.
			if !extData.Empty() {
				return false
			}
		}
	}
	return true
}
1575
1576type serverKeyExchangeMsg struct {
1577	key []byte
1578}
1579
1580func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
1581	length := len(m.key)
1582	x := make([]byte, length+4)
1583	x[0] = typeServerKeyExchange
1584	x[1] = uint8(length >> 16)
1585	x[2] = uint8(length >> 8)
1586	x[3] = uint8(length)
1587	copy(x[4:], m.key)
1588
1589	return x, nil
1590}
1591
1592func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
1593	if len(data) < 4 {
1594		return false
1595	}
1596	m.key = data[4:]
1597	return true
1598}
1599
1600type certificateStatusMsg struct {
1601	response []byte
1602}
1603
1604func (m *certificateStatusMsg) marshal() ([]byte, error) {
1605	var b cryptobyte.Builder
1606	b.AddUint8(typeCertificateStatus)
1607	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1608		b.AddUint8(statusTypeOCSP)
1609		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1610			b.AddBytes(m.response)
1611		})
1612	})
1613
1614	return b.Bytes()
1615}
1616
1617func (m *certificateStatusMsg) unmarshal(data []byte) bool {
1618	s := cryptobyte.String(data)
1619
1620	var statusType uint8
1621	if !s.Skip(4) || // message type and uint24 length field
1622		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
1623		!readUint24LengthPrefixed(&s, &m.response) ||
1624		len(m.response) == 0 || !s.Empty() {
1625		return false
1626	}
1627	return true
1628}
1629
1630type serverHelloDoneMsg struct{}
1631
1632func (m *serverHelloDoneMsg) marshal() ([]byte, error) {
1633	x := make([]byte, 4)
1634	x[0] = typeServerHelloDone
1635	return x, nil
1636}
1637
1638func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1639	return len(data) == 4
1640}
1641
1642type clientKeyExchangeMsg struct {
1643	ciphertext []byte
1644}
1645
1646func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
1647	length := len(m.ciphertext)
1648	x := make([]byte, length+4)
1649	x[0] = typeClientKeyExchange
1650	x[1] = uint8(length >> 16)
1651	x[2] = uint8(length >> 8)
1652	x[3] = uint8(length)
1653	copy(x[4:], m.ciphertext)
1654
1655	return x, nil
1656}
1657
1658func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1659	if len(data) < 4 {
1660		return false
1661	}
1662	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1663	if l != len(data)-4 {
1664		return false
1665	}
1666	m.ciphertext = data[4:]
1667	return true
1668}
1669
1670type finishedMsg struct {
1671	verifyData []byte
1672}
1673
1674func (m *finishedMsg) marshal() ([]byte, error) {
1675	var b cryptobyte.Builder
1676	b.AddUint8(typeFinished)
1677	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1678		b.AddBytes(m.verifyData)
1679	})
1680
1681	return b.Bytes()
1682}
1683
1684func (m *finishedMsg) unmarshal(data []byte) bool {
1685	s := cryptobyte.String(data)
1686	return s.Skip(1) &&
1687		readUint24LengthPrefixed(&s, &m.verifyData) &&
1688		s.Empty()
1689}
1690
1691type certificateRequestMsg struct {
1692	// hasSignatureAlgorithm indicates whether this message includes a list of
1693	// supported signature algorithms. This change was introduced with TLS 1.2.
1694	hasSignatureAlgorithm bool
1695
1696	certificateTypes             []byte
1697	supportedSignatureAlgorithms []SignatureScheme
1698	certificateAuthorities       [][]byte
1699}
1700
1701func (m *certificateRequestMsg) marshal() ([]byte, error) {
1702	// See RFC 4346, Section 7.4.4.
1703	length := 1 + len(m.certificateTypes) + 2
1704	casLength := 0
1705	for _, ca := range m.certificateAuthorities {
1706		casLength += 2 + len(ca)
1707	}
1708	length += casLength
1709
1710	if m.hasSignatureAlgorithm {
1711		length += 2 + 2*len(m.supportedSignatureAlgorithms)
1712	}
1713
1714	x := make([]byte, 4+length)
1715	x[0] = typeCertificateRequest
1716	x[1] = uint8(length >> 16)
1717	x[2] = uint8(length >> 8)
1718	x[3] = uint8(length)
1719
1720	x[4] = uint8(len(m.certificateTypes))
1721
1722	copy(x[5:], m.certificateTypes)
1723	y := x[5+len(m.certificateTypes):]
1724
1725	if m.hasSignatureAlgorithm {
1726		n := len(m.supportedSignatureAlgorithms) * 2
1727		y[0] = uint8(n >> 8)
1728		y[1] = uint8(n)
1729		y = y[2:]
1730		for _, sigAlgo := range m.supportedSignatureAlgorithms {
1731			y[0] = uint8(sigAlgo >> 8)
1732			y[1] = uint8(sigAlgo)
1733			y = y[2:]
1734		}
1735	}
1736
1737	y[0] = uint8(casLength >> 8)
1738	y[1] = uint8(casLength)
1739	y = y[2:]
1740	for _, ca := range m.certificateAuthorities {
1741		y[0] = uint8(len(ca) >> 8)
1742		y[1] = uint8(len(ca))
1743		y = y[2:]
1744		copy(y, ca)
1745		y = y[len(ca):]
1746	}
1747
1748	return x, nil
1749}
1750
1751func (m *certificateRequestMsg) unmarshal(data []byte) bool {
1752	if len(data) < 5 {
1753		return false
1754	}
1755
1756	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1757	if uint32(len(data))-4 != length {
1758		return false
1759	}
1760
1761	numCertTypes := int(data[4])
1762	data = data[5:]
1763	if numCertTypes == 0 || len(data) <= numCertTypes {
1764		return false
1765	}
1766
1767	m.certificateTypes = make([]byte, numCertTypes)
1768	if copy(m.certificateTypes, data) != numCertTypes {
1769		return false
1770	}
1771
1772	data = data[numCertTypes:]
1773
1774	if m.hasSignatureAlgorithm {
1775		if len(data) < 2 {
1776			return false
1777		}
1778		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
1779		data = data[2:]
1780		if sigAndHashLen&1 != 0 {
1781			return false
1782		}
1783		if len(data) < int(sigAndHashLen) {
1784			return false
1785		}
1786		numSigAlgos := sigAndHashLen / 2
1787		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
1788		for i := range m.supportedSignatureAlgorithms {
1789			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
1790			data = data[2:]
1791		}
1792	}
1793
1794	if len(data) < 2 {
1795		return false
1796	}
1797	casLength := uint16(data[0])<<8 | uint16(data[1])
1798	data = data[2:]
1799	if len(data) < int(casLength) {
1800		return false
1801	}
1802	cas := make([]byte, casLength)
1803	copy(cas, data)
1804	data = data[casLength:]
1805
1806	m.certificateAuthorities = nil
1807	for len(cas) > 0 {
1808		if len(cas) < 2 {
1809			return false
1810		}
1811		caLen := uint16(cas[0])<<8 | uint16(cas[1])
1812		cas = cas[2:]
1813
1814		if len(cas) < int(caLen) {
1815			return false
1816		}
1817
1818		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
1819		cas = cas[caLen:]
1820	}
1821
1822	return len(data) == 0
1823}
1824
1825type certificateVerifyMsg struct {
1826	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
1827	signatureAlgorithm    SignatureScheme
1828	signature             []byte
1829}
1830
1831func (m *certificateVerifyMsg) marshal() ([]byte, error) {
1832	var b cryptobyte.Builder
1833	b.AddUint8(typeCertificateVerify)
1834	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1835		if m.hasSignatureAlgorithm {
1836			b.AddUint16(uint16(m.signatureAlgorithm))
1837		}
1838		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1839			b.AddBytes(m.signature)
1840		})
1841	})
1842
1843	return b.Bytes()
1844}
1845
1846func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1847	s := cryptobyte.String(data)
1848
1849	if !s.Skip(4) { // message type and uint24 length field
1850		return false
1851	}
1852	if m.hasSignatureAlgorithm {
1853		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
1854			return false
1855		}
1856	}
1857	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
1858}
1859
1860type newSessionTicketMsg struct {
1861	ticket []byte
1862}
1863
1864func (m *newSessionTicketMsg) marshal() ([]byte, error) {
1865	// See RFC 5077, Section 3.3.
1866	ticketLen := len(m.ticket)
1867	length := 2 + 4 + ticketLen
1868	x := make([]byte, 4+length)
1869	x[0] = typeNewSessionTicket
1870	x[1] = uint8(length >> 16)
1871	x[2] = uint8(length >> 8)
1872	x[3] = uint8(length)
1873	x[8] = uint8(ticketLen >> 8)
1874	x[9] = uint8(ticketLen)
1875	copy(x[10:], m.ticket)
1876
1877	return x, nil
1878}
1879
1880func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1881	if len(data) < 10 {
1882		return false
1883	}
1884
1885	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1886	if uint32(len(data))-4 != length {
1887		return false
1888	}
1889
1890	ticketLen := int(data[8])<<8 + int(data[9])
1891	if len(data)-10 != ticketLen {
1892		return false
1893	}
1894
1895	m.ticket = data[10:]
1896
1897	return true
1898}
1899
1900type helloRequestMsg struct {
1901}
1902
1903func (*helloRequestMsg) marshal() ([]byte, error) {
1904	return []byte{typeHelloRequest, 0, 0, 0}, nil
1905}
1906
1907func (*helloRequestMsg) unmarshal(data []byte) bool {
1908	return len(data) == 4
1909}
1910
1911type transcriptHash interface {
1912	Write([]byte) (int, error)
1913}
1914
1915// transcriptMsg is a helper used to hash messages which are not hashed when
1916// they are read from, or written to, the wire. This is typically the case for
1917// messages which are either not sent, or need to be hashed out of order from
1918// when they are read/written.
1919//
1920// For most messages, the message is marshalled using their marshal method,
1921// since their wire representation is idempotent. For clientHelloMsg and
1922// serverHelloMsg, we store the original wire representation of the message and
1923// use that for hashing, since unmarshal/marshal are not idempotent due to
1924// extension ordering and other malleable fields, which may cause differences
1925// between what was received and what we marshal.
1926func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
1927	if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok {
1928		if orig := msgWithOrig.originalBytes(); orig != nil {
1929			h.Write(msgWithOrig.originalBytes())
1930			return nil
1931		}
1932	}
1933
1934	data, err := msg.marshal()
1935	if err != nil {
1936		return err
1937	}
1938	h.Write(data)
1939	return nil
1940}
1941