1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"crypto/x509"
10	"encoding/hex"
11	"math"
12	"math/rand"
13	"reflect"
14	"strings"
15	"testing"
16	"testing/quick"
17	"time"
18)
19
20var tests = []handshakeMessage{
21	&clientHelloMsg{},
22	&serverHelloMsg{},
23	&finishedMsg{},
24
25	&certificateMsg{},
26	&certificateRequestMsg{},
27	&certificateVerifyMsg{
28		hasSignatureAlgorithm: true,
29	},
30	&certificateStatusMsg{},
31	&clientKeyExchangeMsg{},
32	&newSessionTicketMsg{},
33	&encryptedExtensionsMsg{},
34	&endOfEarlyDataMsg{},
35	&keyUpdateMsg{},
36	&newSessionTicketMsgTLS13{},
37	&certificateRequestMsgTLS13{},
38	&certificateMsgTLS13{},
39	&SessionState{},
40}
41
42func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
43	t.Helper()
44	b, err := msg.marshal()
45	if err != nil {
46		t.Fatal(err)
47	}
48	return b
49}
50
51func TestMarshalUnmarshal(t *testing.T) {
52	rand := rand.New(rand.NewSource(time.Now().UnixNano()))
53
54	for i, m := range tests {
55		ty := reflect.ValueOf(m).Type()
56		t.Run(ty.String(), func(t *testing.T) {
57			n := 100
58			if testing.Short() {
59				n = 5
60			}
61			for j := 0; j < n; j++ {
62				v, ok := quick.Value(ty, rand)
63				if !ok {
64					t.Errorf("#%d: failed to create value", i)
65					break
66				}
67
68				m1 := v.Interface().(handshakeMessage)
69				marshaled := mustMarshal(t, m1)
70				if !m.unmarshal(marshaled) {
71					t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
72					break
73				}
74
75				if m, ok := m.(*SessionState); ok {
76					m.activeCertHandles = nil
77				}
78
79				// clientHelloMsg and serverHelloMsg, when unmarshalled, store
80				// their original representation, for later use in the handshake
81				// transcript. In order to prevent DeepEqual from failing since
82				// we didn't create the original message via unmarshalling, nil
83				// the field.
84				switch t := m.(type) {
85				case *clientHelloMsg:
86					t.original = nil
87				case *serverHelloMsg:
88					t.original = nil
89				}
90
91				if !reflect.DeepEqual(m1, m) {
92					t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
93					break
94				}
95
96				if i >= 3 {
97					// The first three message types (ClientHello,
98					// ServerHello and Finished) are allowed to
99					// have parsable prefixes because the extension
100					// data is optional and the length of the
101					// Finished varies across versions.
102					for j := 0; j < len(marshaled); j++ {
103						if m.unmarshal(marshaled[0:j]) {
104							t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
105							break
106						}
107					}
108				}
109			}
110		})
111	}
112}
113
114func TestFuzz(t *testing.T) {
115	rand := rand.New(rand.NewSource(0))
116	for _, m := range tests {
117		for j := 0; j < 1000; j++ {
118			len := rand.Intn(1000)
119			bytes := randomBytes(len, rand)
120			// This just looks for crashes due to bounds errors etc.
121			m.unmarshal(bytes)
122		}
123	}
124}
125
// randomBytes returns a freshly allocated slice of n bytes filled from rand.
// It panics if rand.Read reports an error.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
133
134func randomString(n int, rand *rand.Rand) string {
135	b := randomBytes(n, rand)
136	return string(b)
137}
138
139func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
140	m := &clientHelloMsg{}
141	m.vers = uint16(rand.Intn(65536))
142	m.random = randomBytes(32, rand)
143	m.sessionId = randomBytes(rand.Intn(32), rand)
144	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
145	for i := 0; i < len(m.cipherSuites); i++ {
146		cs := uint16(rand.Int31())
147		if cs == scsvRenegotiation {
148			cs += 1
149		}
150		m.cipherSuites[i] = cs
151	}
152	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
153	if rand.Intn(10) > 5 {
154		m.serverName = randomString(rand.Intn(255), rand)
155		for strings.HasSuffix(m.serverName, ".") {
156			m.serverName = m.serverName[:len(m.serverName)-1]
157		}
158	}
159	m.ocspStapling = rand.Intn(10) > 5
160	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
161	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
162	for i := range m.supportedCurves {
163		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
164	}
165	if rand.Intn(10) > 5 {
166		m.ticketSupported = true
167		if rand.Intn(10) > 5 {
168			m.sessionTicket = randomBytes(rand.Intn(300), rand)
169		} else {
170			m.sessionTicket = make([]byte, 0)
171		}
172	}
173	if rand.Intn(10) > 5 {
174		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
175	}
176	if rand.Intn(10) > 5 {
177		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
178	}
179	for i := 0; i < rand.Intn(5); i++ {
180		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
181	}
182	if rand.Intn(10) > 5 {
183		m.scts = true
184	}
185	if rand.Intn(10) > 5 {
186		m.secureRenegotiationSupported = true
187		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
188	}
189	if rand.Intn(10) > 5 {
190		m.extendedMasterSecret = true
191	}
192	for i := 0; i < rand.Intn(5); i++ {
193		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
194	}
195	if rand.Intn(10) > 5 {
196		m.cookie = randomBytes(rand.Intn(500)+1, rand)
197	}
198	for i := 0; i < rand.Intn(5); i++ {
199		var ks keyShare
200		ks.group = CurveID(rand.Intn(30000) + 1)
201		ks.data = randomBytes(rand.Intn(200)+1, rand)
202		m.keyShares = append(m.keyShares, ks)
203	}
204	switch rand.Intn(3) {
205	case 1:
206		m.pskModes = []uint8{pskModeDHE}
207	case 2:
208		m.pskModes = []uint8{pskModeDHE, pskModePlain}
209	}
210	for i := 0; i < rand.Intn(5); i++ {
211		var psk pskIdentity
212		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
213		psk.label = randomBytes(rand.Intn(500)+1, rand)
214		m.pskIdentities = append(m.pskIdentities, psk)
215		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
216	}
217	if rand.Intn(10) > 5 {
218		m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
219	}
220	if rand.Intn(10) > 5 {
221		m.earlyData = true
222	}
223
224	return reflect.ValueOf(m)
225}
226
227func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
228	m := &serverHelloMsg{}
229	m.vers = uint16(rand.Intn(65536))
230	m.random = randomBytes(32, rand)
231	m.sessionId = randomBytes(rand.Intn(32), rand)
232	m.cipherSuite = uint16(rand.Int31())
233	m.compressionMethod = uint8(rand.Intn(256))
234	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
235
236	if rand.Intn(10) > 5 {
237		m.ocspStapling = true
238	}
239	if rand.Intn(10) > 5 {
240		m.ticketSupported = true
241	}
242	if rand.Intn(10) > 5 {
243		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
244	}
245
246	for i := 0; i < rand.Intn(4); i++ {
247		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
248	}
249
250	if rand.Intn(10) > 5 {
251		m.secureRenegotiationSupported = true
252		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
253	}
254	if rand.Intn(10) > 5 {
255		m.extendedMasterSecret = true
256	}
257	if rand.Intn(10) > 5 {
258		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
259	}
260	if rand.Intn(10) > 5 {
261		m.cookie = randomBytes(rand.Intn(500)+1, rand)
262	}
263	if rand.Intn(10) > 5 {
264		for i := 0; i < rand.Intn(5); i++ {
265			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
266			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
267		}
268	} else if rand.Intn(10) > 5 {
269		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
270	}
271	if rand.Intn(10) > 5 {
272		m.selectedIdentityPresent = true
273		m.selectedIdentity = uint16(rand.Intn(0xffff))
274	}
275	if rand.Intn(10) > 5 {
276		m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
277	}
278	if rand.Intn(10) > 5 {
279		m.serverNameAck = rand.Intn(2) == 1
280	}
281
282	return reflect.ValueOf(m)
283}
284
285func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
286	m := &encryptedExtensionsMsg{}
287
288	if rand.Intn(10) > 5 {
289		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
290	}
291	if rand.Intn(10) > 5 {
292		m.earlyData = true
293	}
294
295	return reflect.ValueOf(m)
296}
297
298func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
299	m := &certificateMsg{}
300	numCerts := rand.Intn(20)
301	m.certificates = make([][]byte, numCerts)
302	for i := 0; i < numCerts; i++ {
303		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
304	}
305	return reflect.ValueOf(m)
306}
307
308func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
309	m := &certificateRequestMsg{}
310	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
311	for i := 0; i < rand.Intn(100); i++ {
312		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
313	}
314	return reflect.ValueOf(m)
315}
316
317func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
318	m := &certificateVerifyMsg{}
319	m.hasSignatureAlgorithm = true
320	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
321	m.signature = randomBytes(rand.Intn(15)+1, rand)
322	return reflect.ValueOf(m)
323}
324
325func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
326	m := &certificateStatusMsg{}
327	m.response = randomBytes(rand.Intn(10)+1, rand)
328	return reflect.ValueOf(m)
329}
330
331func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
332	m := &clientKeyExchangeMsg{}
333	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
334	return reflect.ValueOf(m)
335}
336
337func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
338	m := &finishedMsg{}
339	m.verifyData = randomBytes(12, rand)
340	return reflect.ValueOf(m)
341}
342
343func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
344	m := &newSessionTicketMsg{}
345	m.ticket = randomBytes(rand.Intn(4), rand)
346	return reflect.ValueOf(m)
347}
348
349var sessionTestCerts []*x509.Certificate
350
351func init() {
352	cert, err := x509.ParseCertificate(testRSACertificate)
353	if err != nil {
354		panic(err)
355	}
356	sessionTestCerts = append(sessionTestCerts, cert)
357	cert, err = x509.ParseCertificate(testRSACertificateIssuer)
358	if err != nil {
359		panic(err)
360	}
361	sessionTestCerts = append(sessionTestCerts, cert)
362}
363
364func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
365	s := &SessionState{}
366	isTLS13 := rand.Intn(10) > 5
367	if isTLS13 {
368		s.version = VersionTLS13
369	} else {
370		s.version = uint16(rand.Intn(VersionTLS13))
371	}
372	s.isClient = rand.Intn(10) > 5
373	s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
374	s.createdAt = uint64(rand.Int63())
375	s.secret = randomBytes(rand.Intn(100)+1, rand)
376	for n, i := rand.Intn(3), 0; i < n; i++ {
377		s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
378	}
379	if rand.Intn(10) > 5 {
380		s.EarlyData = true
381	}
382	if rand.Intn(10) > 5 {
383		s.extMasterSecret = true
384	}
385	if s.isClient || rand.Intn(10) > 5 {
386		if rand.Intn(10) > 5 {
387			s.peerCertificates = sessionTestCerts
388		} else {
389			s.peerCertificates = sessionTestCerts[:1]
390		}
391	}
392	if rand.Intn(10) > 5 && s.peerCertificates != nil {
393		s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
394	}
395	if rand.Intn(10) > 5 && s.peerCertificates != nil {
396		for i := 0; i < rand.Intn(2)+1; i++ {
397			s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
398		}
399	}
400	if len(s.peerCertificates) > 0 {
401		for i := 0; i < rand.Intn(3); i++ {
402			if rand.Intn(10) > 5 {
403				s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
404			} else {
405				s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
406			}
407		}
408	}
409	if rand.Intn(10) > 5 && s.EarlyData {
410		s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
411	}
412	if s.isClient {
413		if isTLS13 {
414			s.useBy = uint64(rand.Int63())
415			s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
416		}
417	}
418	return reflect.ValueOf(s)
419}
420
421func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
422func (s *SessionState) unmarshal(b []byte) bool {
423	ss, err := ParseSessionState(b)
424	if err != nil {
425		return false
426	}
427	*s = *ss
428	return true
429}
430
431func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
432	m := &endOfEarlyDataMsg{}
433	return reflect.ValueOf(m)
434}
435
436func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
437	m := &keyUpdateMsg{}
438	m.updateRequested = rand.Intn(10) > 5
439	return reflect.ValueOf(m)
440}
441
442func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
443	m := &newSessionTicketMsgTLS13{}
444	m.lifetime = uint32(rand.Intn(500000))
445	m.ageAdd = uint32(rand.Intn(500000))
446	m.nonce = randomBytes(rand.Intn(100), rand)
447	m.label = randomBytes(rand.Intn(1000), rand)
448	if rand.Intn(10) > 5 {
449		m.maxEarlyData = uint32(rand.Intn(500000))
450	}
451	return reflect.ValueOf(m)
452}
453
454func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
455	m := &certificateRequestMsgTLS13{}
456	if rand.Intn(10) > 5 {
457		m.ocspStapling = true
458	}
459	if rand.Intn(10) > 5 {
460		m.scts = true
461	}
462	if rand.Intn(10) > 5 {
463		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
464	}
465	if rand.Intn(10) > 5 {
466		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
467	}
468	if rand.Intn(10) > 5 {
469		m.certificateAuthorities = make([][]byte, 3)
470		for i := 0; i < 3; i++ {
471			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
472		}
473	}
474	return reflect.ValueOf(m)
475}
476
477func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
478	m := &certificateMsgTLS13{}
479	for i := 0; i < rand.Intn(2)+1; i++ {
480		m.certificate.Certificate = append(
481			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
482	}
483	if rand.Intn(10) > 5 {
484		m.ocspStapling = true
485		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
486	}
487	if rand.Intn(10) > 5 {
488		m.scts = true
489		for i := 0; i < rand.Intn(2)+1; i++ {
490			m.certificate.SignedCertificateTimestamps = append(
491				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
492		}
493	}
494	return reflect.ValueOf(m)
495}
496
497func TestRejectEmptySCTList(t *testing.T) {
498	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
499
500	var random [32]byte
501	sct := []byte{0x42, 0x42, 0x42, 0x42}
502	serverHello := &serverHelloMsg{
503		vers:   VersionTLS12,
504		random: random[:],
505		scts:   [][]byte{sct},
506	}
507	serverHelloBytes := mustMarshal(t, serverHello)
508
509	var serverHelloCopy serverHelloMsg
510	if !serverHelloCopy.unmarshal(serverHelloBytes) {
511		t.Fatal("Failed to unmarshal initial message")
512	}
513
514	// Change serverHelloBytes so that the SCT list is empty
515	i := bytes.Index(serverHelloBytes, sct)
516	if i < 0 {
517		t.Fatal("Cannot find SCT in ServerHello")
518	}
519
520	var serverHelloEmptySCT []byte
521	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
522	// Append the extension length and SCT list length for an empty list.
523	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
524	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)
525
526	// Update the handshake message length.
527	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
528	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
529	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)
530
531	// Update the extensions length
532	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
533	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))
534
535	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
536		t.Fatal("Unmarshaled ServerHello with empty SCT list")
537	}
538}
539
540func TestRejectEmptySCT(t *testing.T) {
541	// Not only must the SCT list be non-empty, but the SCT elements must
542	// not be zero length.
543
544	var random [32]byte
545	serverHello := &serverHelloMsg{
546		vers:   VersionTLS12,
547		random: random[:],
548		scts:   [][]byte{nil},
549	}
550	serverHelloBytes := mustMarshal(t, serverHello)
551
552	var serverHelloCopy serverHelloMsg
553	if serverHelloCopy.unmarshal(serverHelloBytes) {
554		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
555	}
556}
557
558func TestRejectDuplicateExtensions(t *testing.T) {
559	clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
560	if err != nil {
561		t.Fatalf("failed to decode test ClientHello: %s", err)
562	}
563	var clientHelloCopy clientHelloMsg
564	if clientHelloCopy.unmarshal(clientHelloBytes) {
565		t.Error("Unmarshaled ClientHello with duplicate extensions")
566	}
567
568	serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
569	if err != nil {
570		t.Fatalf("failed to decode test ServerHello: %s", err)
571	}
572	var serverHelloCopy serverHelloMsg
573	if serverHelloCopy.unmarshal(serverHelloBytes) {
574		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
575	}
576}
577